import gstaichi as ti
import torch
import genesis as gs
from genesis.utils import array_class
from genesis.utils.misc import DeprecationError
from genesis.repr_base import RBC
class RigidJoint(RBC):
    """
    Joint class for rigid body entities. Each RigidLink is connected to its parent link via a RigidJoint.
    """

    def __init__(
        self,
        entity,
        name,
        idx,
        link_idx,
        q_start,
        dof_start,
        n_qs,
        n_dofs,
        type,
        pos,
        quat,
        init_qpos,
        sol_params,
        dofs_motion_ang,
        dofs_motion_vel,
        dofs_limit,
        dofs_invweight,
        dofs_frictionloss,
        dofs_stiffness,
        dofs_damping,
        dofs_armature,
        dofs_kp,
        dofs_kv,
        dofs_force_range,
    ):
        # The solver is taken from the owning entity; all runtime (post-build) queries are routed through it.
        self._name = name
        self._entity = entity
        self._solver = entity.solver
        self._uid = gs.UID()
        self._idx = idx
        self._link_idx = link_idx
        self._q_start = q_start
        self._dof_start = dof_start
        self._n_qs = n_qs
        self._n_dofs = n_dofs
        self._type = type
        self._pos = pos
        self._quat = quat
        self._init_qpos = init_qpos
        self._sol_params = sol_params
        self._dofs_motion_ang = dofs_motion_ang
        self._dofs_motion_vel = dofs_motion_vel
        self._dofs_limit = dofs_limit
        self._dofs_invweight = dofs_invweight
        self._dofs_frictionloss = dofs_frictionloss
        self._dofs_stiffness = dofs_stiffness
        self._dofs_damping = dofs_damping
        self._dofs_armature = dofs_armature
        self._dofs_kp = dofs_kp
        self._dofs_kv = dofs_kv
        self._dofs_force_range = dofs_force_range

    def __getattr__(self, name):
        """
        Resolve deprecated singular index properties, logging a deprecation warning.

        Python only calls this hook when normal attribute lookup fails, so existing attributes are never
        intercepted. Any other missing attribute raises `AttributeError`, which keeps introspection working.
        """
        # Old property name -> replacement public property. The old value is served by the private `_<old name>`
        # method defined on this class.
        deprecated = {
            "dof_idx": "dofs_idx",
            "dof_idx_local": "dofs_idx_local",
            "q_idx": "qs_idx",
            "q_idx_local": "qs_idx_local",
        }
        name_new = deprecated.get(name)
        if name_new is not None:
            gs.logger.warning(
                f"This property is deprecated and will be removed in future release. Please use '{name_new}' instead."
            )
            return getattr(self, f"_{name}")()
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    # ------------------------------------------------------------------------------------
    # -------------------------------- real-time state -----------------------------------
    # ------------------------------------------------------------------------------------

    def get_pos(self):
        """
        Get the position of the joint in the world frame.

        Raises
        ------
        DeprecationError
            Always; this method has been removed in favor of link-level queries or `get_anchor_pos`.
        """
        raise DeprecationError(
            "This method has been removed. Please consider operating at link-level to get the cartesian position in "
            "world frame. Alternatively, 'get_anchor_pos' returns the anchor position of the joint in the world frame."
        )

    def get_quat(self):
        """
        Get the quaternion of the joint in the world frame.

        Raises
        ------
        DeprecationError
            Always; this method has been removed in favor of link-level queries or `get_anchor_axis`.
        """
        raise DeprecationError(
            "This method has been removed. Please consider operating at link-level to get the cartesian orientation in "
            "world frame. Alternatively, 'get_anchor_axis' returns the anchor axis of the joint in the world frame."
        )

    def _query_joint_vec3(self, kernel):
        """
        Run `kernel` to copy a per-batch 3-vector of this joint into a freshly allocated torch tensor.

        The tensor has shape (B, 3), where B is the solver's batch size `_B`. The leading batch dimension is dropped
        when the solver has no parallel environments (`n_envs == 0`).
        """
        tensor = torch.empty((self._solver._B, 3), dtype=gs.tc_float, device=gs.device)
        kernel(self._idx, tensor, self._solver.joints_state)
        if self._solver.n_envs == 0:
            tensor = tensor[0]
        return tensor

    @gs.assert_built
    def get_anchor_pos(self):
        """
        Get the anchor position of the joint in the world frame.

        Mathematically, the anchor point corresponds to the point that is fixed wrt parent link and is coincident with
        the joint for the neutral configuration qpos0. This means that this point moves under the effect of the
        generalized coordinates corresponding to this joint (and all its ancestors in the kinematic tree). Physically,
        the anchor point is the "output" of the joint transmission, on which the child body is welded.
        """
        return self._query_joint_vec3(_kernel_get_anchor_pos)

    @gs.assert_built
    def get_anchor_axis(self):
        """
        Get the anchor axis of the joint in the world frame.

        See `RigidJoint.get_anchor_pos` documentation for details about the notion of anchor point.
        """
        return self._query_joint_vec3(_kernel_get_anchor_axis)

    def set_sol_params(self, sol_params):
        """
        Set the solver parameters of this joint.

        Once the solver is built, the values are forwarded to the solver for this joint (with `envs_idx=None`).
        Before that, they are stored on the joint itself.
        """
        if self._solver.is_built:
            self._solver.set_sol_params(sol_params, joints_idx=self._idx, envs_idx=None)
        else:
            self._sol_params = sol_params

    @property
    def sol_params(self):
        """
        Returns the solver parameters of the joint.

        Once built, the values are read back from the solver (with `envs_idx=None`). Otherwise, the locally stored
        values are returned.
        """
        if self._solver.is_built:
            return self._solver.get_sol_params(joints_idx=self._idx, envs_idx=None)[..., 0, :]
        return self._sol_params

    # ------------------------------------------------------------------------------------
    # ----------------------------------- properties -------------------------------------
    # ------------------------------------------------------------------------------------

    @property
    def uid(self):
        """
        Returns the unique id of the joint.
        """
        return self._uid

    @property
    def name(self):
        """
        Returns the name of the joint.
        """
        return self._name

    @property
    def entity(self):
        """
        Returns the entity that the joint belongs to.
        """
        return self._entity

    @property
    def solver(self):
        """
        The RigidSolver object that the joint belongs to.
        """
        return self._solver

    @property
    def link(self):
        """
        Returns the child link of the joint.
        """
        return self._solver.links[self._link_idx]

    @property
    def idx(self):
        """
        Returns the global index of the joint in the rigid solver.
        """
        return self._idx

    @property
    def idx_local(self):
        """
        Returns the local index of the joint in the entity.
        """
        return self._idx - self._entity.joint_start

    @property
    def init_qpos(self):
        """
        Returns the initial joint position.
        """
        return self._init_qpos

    @property
    def n_qs(self):
        """
        Returns the number of `q` (generalized coordinate) variables that the joint has.
        """
        return self._n_qs

    @property
    def n_dofs(self):
        """
        Returns the number of dofs that the joint has.
        """
        return self._n_dofs

    @property
    def type(self):
        """
        Returns the type of the joint.
        """
        return self._type

    @property
    def pos(self):
        """
        Returns the initial position of the joint in the world frame.
        """
        return self._pos

    @property
    def quat(self):
        """
        Returns the initial quaternion of the joint in the world frame.
        """
        return self._quat

    @property
    def q_start(self):
        """
        Returns the starting index of the `q` variables of the joint in the rigid solver.
        """
        return self._q_start

    @property
    def dof_start(self):
        """
        Returns the starting index of the dofs of the joint in the rigid solver.
        """
        return self._dof_start

    @property
    def q_end(self):
        """
        Returns the (exclusive) ending index of the `q` variables of the joint in the rigid solver.
        """
        return self._n_qs + self.q_start

    @property
    def dof_end(self):
        """
        Returns the (exclusive) ending index of the dofs of the joint in the rigid solver.
        """
        return self._n_dofs + self.dof_start

    def _dof_idx(self):
        """
        Returns all the Degree of Freedom (DoF) indices of the joint in the rigid solver.

        Backs the deprecated `dof_idx` property. Returns an integer, None, or a list when the joint has a single
        DoF, no DoF, or multiple DoFs, respectively.
        """
        if self.n_dofs == 1:
            return self.dof_start
        if self.n_dofs == 0:
            return None
        return self.dofs_idx

    @property
    def dofs_idx(self):
        """
        Returns all the Degree of Freedom (DoF) indices of the joint in the rigid solver as a list.
        """
        return list(range(self.dof_start, self.dof_end))

    def _dof_idx_local(self):
        """
        Returns the local DoF index (or indices) of the joint in the entity.

        Backs the deprecated `dof_idx_local` property. Returns an integer, None, or a list when the joint has a
        single DoF, no DoF, or multiple DoFs, respectively.
        """
        if self.n_dofs == 1:
            return self.dof_start - self._entity.dof_start
        if self.n_dofs == 0:
            return None
        return self.dofs_idx_local

    @property
    def dofs_idx_local(self):
        """
        Returns the local Degree of Freedom indices of the joint in the entity as a list.
        """
        offset = self._entity.dof_start
        return list(range(self.dof_start - offset, self.dof_end - offset))

    def _q_idx(self):
        """
        Returns all the position indices of the joint in the rigid solver.

        Backs the deprecated `q_idx` property. Returns an integer, None, or a list when the joint has a single
        position index, none, or multiple ones, respectively.
        """
        if self.n_qs == 1:
            return self.q_start
        if self.n_qs == 0:
            return None
        return self.qs_idx

    @property
    def qs_idx(self):
        """
        Returns all the position indices of the joint in the rigid solver as a list.
        """
        return list(range(self.q_start, self.q_end))

    def _q_idx_local(self):
        """
        Returns all the local `q` indices of the joint in the entity.

        Backs the deprecated `q_idx_local` property. Returns an integer, None, or a list when the joint has a
        single `q` variable, none, or multiple ones, respectively.
        """
        if self.n_qs == 1:
            return self.q_start - self._entity.q_start
        if self.n_qs == 0:
            return None
        return self.qs_idx_local

    @property
    def qs_idx_local(self):
        """
        Returns all the local `q` indices of the joint in the entity as a list.
        """
        offset = self._entity.q_start
        return list(range(self.q_start - offset, self.q_end - offset))

    @property
    def dofs_motion_ang(self):
        """
        Returns the `dofs_motion_ang` data of the dofs of the joint, as passed at construction.
        """
        return self._dofs_motion_ang

    @property
    def dofs_motion_vel(self):
        """
        Returns the `dofs_motion_vel` data of the dofs of the joint, as passed at construction.
        """
        return self._dofs_motion_vel

    @property
    def dofs_limit(self):
        """
        Returns the range limit of the dofs of the joint.
        """
        return self._dofs_limit

    @property
    def dofs_invweight(self):
        """
        Returns the invweight of the dofs of the joint.
        """
        return self._dofs_invweight

    @property
    def dofs_frictionloss(self):
        """
        Returns the frictionloss of the dofs of the joint.
        """
        return self._dofs_frictionloss

    @property
    def dofs_stiffness(self):
        """
        Returns the stiffness of the dofs of the joint.
        """
        return self._dofs_stiffness

    @property
    def dofs_damping(self):
        """
        Returns the damping of the dofs of the joint.
        """
        return self._dofs_damping

    @property
    def dofs_armature(self):
        """
        Returns the armature of the dofs of the joint.
        """
        return self._dofs_armature

    @property
    def dofs_kp(self):
        """
        Returns the kp (positional gain) of the dofs of the joint.
        """
        return self._dofs_kp

    @property
    def dofs_kv(self):
        """
        Returns the kv (velocity gain) of the dofs of the joint.
        """
        return self._dofs_kv

    @property
    def dofs_force_range(self):
        """
        Returns the force range of the dofs of the joint.
        """
        return self._dofs_force_range

    @property
    def is_built(self):
        """
        Returns whether the entity the joint belongs to is built.
        """
        return self.entity.is_built

    # ------------------------------------------------------------------------------------
    # -------------------------------------- repr ----------------------------------------
    # ------------------------------------------------------------------------------------

    def _repr_brief(self):
        return f"{self._repr_type()}: {self._uid}, name: '{self._name}', idx: {self._idx}, type: {self._type}"
@ti.kernel
def _kernel_get_anchor_pos(joint_idx: ti.i32, tensor: ti.types.ndarray(), joints_state: array_class.JointsState):
    # Copy the anchor position of joint `joint_idx` into row `b` of `tensor`, for every batch entry `b` along the
    # second axis of `joints_state.xanchor`. Assumes `tensor` has at least that many rows and 3 columns.
    n_batch = joints_state.xanchor.shape[1]
    for b in range(n_batch):
        anchor = joints_state.xanchor[joint_idx, b]
        # Compile-time unrolled copy of the three components.
        for k in ti.static(range(3)):
            tensor[b, k] = anchor[k]
@ti.kernel
def _kernel_get_anchor_axis(joint_idx: ti.i32, tensor: ti.types.ndarray(), joints_state: array_class.JointsState):
    # Copy the anchor axis of joint `joint_idx` into row `b` of `tensor`, for every batch entry `b` along the
    # second axis of `joints_state.xaxis`. Assumes `tensor` has at least that many rows and 3 columns.
    n_batch = joints_state.xaxis.shape[1]
    for b in range(n_batch):
        axis = joints_state.xaxis[joint_idx, b]
        # Compile-time unrolled copy of the three components.
        for k in ti.static(range(3)):
            tensor[b, k] = axis[k]