from functools import wraps
from pathlib import Path
import igl
import numpy as np
import gstaichi as ti
import torch
import genesis as gs
import genesis.utils.element as eu
import genesis.utils.geom as gu
import genesis.utils.mesh as mu
from genesis.engine.entities.rigid_entity import RigidLink
from genesis.engine.couplers import SAPCoupler
from genesis.engine.states.cache import QueriedStates
from genesis.engine.states.entities import FEMEntityState
from genesis.utils.misc import to_gs_tensor, tensor_to_array, broadcast_tensor
from .base_entity import Entity
def assert_muscle(method):
    """
    Decorator restricting a method to entities whose material is `FEM.Muscle`.

    The wrapped method is invoked unchanged when `self.material` is an instance of
    `gs.materials.FEM.Muscle`; otherwise `gs.raise_exception` is called first.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        is_muscle = isinstance(self.material, gs.materials.FEM.Muscle)
        if not is_muscle:
            gs.raise_exception("This method is only supported by entities with 'FEM.Muscle' material.")
        return method(self, *args, **kwargs)

    return wrapper
[docs]@ti.data_oriented
class FEMEntity(Entity):
    """
    A finite element method (FEM)-based entity for deformable simulation.

    This class represents a deformable object using tetrahedral elements (or surface triangles when the
    material is cloth). It interfaces with the physics solver to handle state updates, checkpointing,
    gradients, and actuation for physics-based simulation in batched environments.

    Parameters
    ----------
    scene : Scene
        The simulation scene that this entity belongs to.
    solver : Solver
        The physics solver instance used for simulation.
    material : Material
        The material properties defining elasticity, density, etc.
    morph : Morph
        The morph specification that defines the entity's shape.
    surface : Surface
        The surface mesh associated with the entity (for rendering or collision).
    idx : int
        Unique identifier of the entity within the scene.
    v_start : int, optional
        Starting index of this entity's vertices in the global vertex array (default is 0).
    el_start : int, optional
        Starting index of this entity's elements in the global element array (default is 0).
    s_start : int, optional
        Starting index of this entity's surface triangles in the global surface array (default is 0).
    name : str, optional
        Name forwarded to the base `Entity` constructor (default is None).
    """
def __init__(
    self, scene, solver, material, morph, surface, idx, v_start=0, el_start=0, s_start=0, name: str | None = None
):
    """
    Build the entity: sample its mesh, extract the boundary surface and set up bookkeeping.

    See the class docstring for the meaning of the parameters. After construction the entity holds
    `_surface_tri_np` (surface triangles as vertex indices), `_surface_el_np` (the element owning each
    surface triangle), `_n_surfaces` and `_n_surface_vertices`. The entity starts inactive.
    """
    super().__init__(idx, scene, morph, solver, material, surface, name=name)

    self._v_start = v_start  # offset for vertex index of elements
    self._el_start = el_start  # offset for element index
    self._s_start = s_start  # offset for surface triangles
    self._step_global_added = None

    self._surface.update_texture()
    self.sample()

    # Check if this is cloth (elements are already triangles)
    from genesis.engine.materials.FEM.cloth import Cloth as ClothMaterial

    if isinstance(self.material, ClothMaterial):
        # For cloth, elements are already surface triangles and each triangle is its own "element"
        self._surface_tri_np = self.elems
        self._surface_el_np = np.arange(self.elems.shape[0], dtype=gs.np_int)
    else:
        # For volumetric FEM, extract surface triangles from tetrahedral elements.
        # Local vertex order of the four faces of a tetrahedron (order chosen for correct normals).
        tet_faces = np.array([[0, 2, 1], [1, 2, 3], [0, 1, 3], [0, 3, 2]])
        tets = np.asarray(self.elems).reshape((-1, 4))
        # (n_elements, 4, 3) -> (4 * n_elements, 3); row r belongs to element r // 4
        all_tri = tets[:, tet_faces].astype(gs.np_int, copy=False).reshape((-1, 3))

        # A face shared by two tets appears twice once its vertices are sorted; boundary faces appear once
        all_tri_sorted = np.sort(all_tri, axis=1)
        _, unique_idcs, cnt = np.unique(all_tri_sorted, axis=0, return_counts=True, return_index=True)
        boundary_rows = unique_idcs[cnt == 1]

        self._surface_tri_np = all_tri[boundary_rows]
        self._surface_el_np = (boundary_rows // 4).astype(gs.np_int, copy=False)

    self._n_surfaces = len(self._surface_tri_np)
    if self._n_surfaces > 0:
        self._n_surface_vertices = len(np.unique(self._surface_tri_np))
    else:
        self._n_surface_vertices = 0

    if isinstance(self.sim.coupler, SAPCoupler):
        self.compute_pressure_field()

    self.init_tgt_vars()
    self.init_ckpt()
    self._queried_states = QueriedStates()

    self.active = False  # This attribute is only used in forward pass. It should NOT be used during backward pass.
# ------------------------------------------------------------------------------------
# ----------------------------------- basic entity ops -------------------------------
# ------------------------------------------------------------------------------------
def _sanitize_verts_idx_local(self, verts_idx_local=None, envs_idx=None):
    """
    Convert local vertex indices into a contiguous integer tensor.

    When `verts_idx_local` is None, every vertex of the entity is selected. Without `envs_idx` the
    result is broadcast to shape (-1,); otherwise to (len(envs_idx), -1).
    """
    if verts_idx_local is None:
        verts_idx_local = range(self.n_vertices)

    if envs_idx is not None:
        target_shape = (len(envs_idx), -1)
        dim_names = ("envs_idx", "verts_idx")
    else:
        target_shape = (-1,)
        dim_names = ("verts_idx",)
    verts_idx_local_ = broadcast_tensor(verts_idx_local, gs.tc_int, target_shape, dim_names)

    # FIXME: Range validation of the indices is intentionally skipped because it is too expensive.
    return verts_idx_local_.contiguous()
def _sanitize_verts_tensor(self, tensor, dtype, verts_idx=None, envs_idx=None, element_shape=(), *, batched=True):
    """
    Broadcast a per-vertex quantity to its full shape and return it contiguous.

    The vertex count is the last dimension of `verts_idx` when given, else `self.n_vertices`.
    The resulting shape is (len(envs_idx), n_vertices, *element_shape) when `batched`, otherwise
    (n_vertices, *element_shape). `envs_idx` must not be None when `batched` is True.
    """
    n_vertices = self.n_vertices if verts_idx is None else verts_idx.shape[-1]
    unnamed_dims = tuple("" for _ in element_shape)

    if batched:
        assert envs_idx is not None
        leading_shape = (len(envs_idx), n_vertices)
        leading_names = ("envs_idx", "verts_idx")
    else:
        leading_shape = (n_vertices,)
        leading_names = ("verts_idx",)

    full_shape = leading_shape + tuple(element_shape)
    full_names = leading_names + unnamed_dims
    return broadcast_tensor(tensor, dtype, full_shape, full_names).contiguous()
def set_position(self, pos):
    """
    Set the target position(s) for the FEM entity.

    Parameters
    ----------
    pos : torch.Tensor or array-like
        The desired position(s). Accepted shapes:
        - (3,): a single COM position; vertices keep their initial offsets from the COM.
        - (n_vertices, 3): per-vertex positions, replicated over all environments.
        - (n_envs, 3): one COM position per environment.
        - (n_envs, n_vertices, 3): full batched per-vertex positions.

    Raises
    ------
    Exception
        If the entity is inactive or the tensor shape is not supported.
    """
    self._assert_active()
    gs.logger.warning("Manually setting element positions. This is not recommended and could break gradient flow.")

    pos = to_gs_tensor(pos)
    n_envs = self._sim._B

    target = None
    if pos.ndim == 1 and pos.shape == (3,):
        shifted = self.init_positions_COM_offset + pos
        target = shifted[None].tile((n_envs, 1, 1))
    elif pos.ndim == 2:
        # The per-vertex interpretation takes precedence over the per-environment one.
        if pos.shape == (self.n_vertices, 3):
            target = pos[None].tile((n_envs, 1, 1))
        elif pos.shape == (n_envs, 3):
            target = self.init_positions_COM_offset[None] + pos[:, None]
    elif pos.ndim == 3 and pos.shape == (n_envs, self.n_vertices, 3):
        target = pos

    if target is None:
        gs.raise_exception("Tensor shape not supported.")
    else:
        self._tgt["pos"] = target
def set_velocity(self, vel):
    """
    Set the target velocity(ies) for the FEM entity.

    Parameters
    ----------
    vel : torch.Tensor or array-like
        The desired velocity(ies). Accepted shapes:
        - (3,): one velocity applied to every vertex of every environment.
        - (n_vertices, 3): per-vertex velocities, replicated over all environments.
        - (n_envs, 3): one velocity per environment, applied to all of its vertices.
        - (n_envs, n_vertices, 3): full batched per-vertex velocities.

    Raises
    ------
    Exception
        If the entity is inactive or the tensor shape is not supported.
    """
    self._assert_active()
    gs.logger.warning("Manually setting element velocities. This is not recommended and could break gradient flow.")

    vel = to_gs_tensor(vel)
    n_envs = self._sim._B

    target = None
    if vel.ndim == 1 and vel.shape == (3,):
        target = vel.tile((n_envs, self.n_vertices, 1))
    elif vel.ndim == 2:
        # The per-vertex interpretation takes precedence over the per-environment one.
        if vel.shape == (self.n_vertices, 3):
            target = vel[None].tile((n_envs, 1, 1))
        elif vel.shape == (n_envs, 3):
            target = vel[:, None].tile((1, self.n_vertices, 1))
    elif vel.ndim == 3 and vel.shape == (n_envs, self.n_vertices, 3):
        target = vel

    if target is None:
        gs.raise_exception("Tensor shape not supported.")
    else:
        self._tgt["vel"] = target
@assert_muscle
def set_actuation(self, actu):
    """
    Set the actuation signal for the FEM entity.

    Parameters
    ----------
    actu : torch.Tensor or array-like
        The actuation tensor. Accepted shapes:
        - (): a single scalar used for all groups in all environments.
        - (n_groups,): group-level actuation, replicated over all environments.
        - (n_envs, n_groups): batch of group-level actuation signals.

    Raises
    ------
    Exception
        If the entity is inactive, the tensor shape is not supported, or per-element
        actuation (shape (n_elements,)) is attempted.
    """
    self._assert_active()

    actu = to_gs_tensor(actu)
    n_envs = self._sim._B
    n_groups = self.material.n_groups

    target = None
    if actu.ndim == 0:
        target = actu.tile((n_envs, n_groups))
    elif actu.ndim == 1:
        if actu.shape == (n_groups,):
            target = actu[None].tile((n_envs, 1))
        elif actu.shape == (self.n_elements,):
            gs.raise_exception("Cannot set per-element actuation.")
    elif actu.ndim == 2 and actu.shape == (n_envs, n_groups):
        target = actu

    if target is None:
        gs.raise_exception("Tensor shape not supported.")
    else:
        self._tgt["actu"] = target
def set_muscle(self, muscle_group=None, muscle_direction=None):
    """
    Set the muscle group and/or muscle direction for the FEM entity.

    Either argument may be omitted, in which case the corresponding property is left untouched.

    Parameters
    ----------
    muscle_group : torch.Tensor or array-like, optional
        Tensor of shape (n_elements,) specifying the muscle group ID for each element.
        Every ID must be an integer strictly smaller than `material.n_groups`.
    muscle_direction : torch.Tensor or array-like, optional
        Tensor of shape (n_elements, 3) specifying unit direction vectors for muscle forces.

    Raises
    ------
    AssertionError
        If tensor shapes are incorrect, a group ID is out of range, or directions are not normalized.
    """
    self._assert_active()

    if muscle_group is not None:
        # Convert first so that plain array-likes (e.g. lists) are supported by `.max()` below.
        muscle_group = to_gs_tensor(muscle_group)
        assert muscle_group.shape == (self.n_elements,)

        n_groups = self.material.n_groups
        max_group_id = muscle_group.max().item()
        assert isinstance(max_group_id, int) and max_group_id < n_groups

        self.set_muscle_group(muscle_group)

    if muscle_direction is not None:
        muscle_direction = to_gs_tensor(muscle_direction)
        assert muscle_direction.shape == (self.n_elements, 3)
        # Directions must be unit vectors
        assert ((1.0 - muscle_direction.norm(dim=-1)).abs() < gs.EPS).all()

        self.set_muscle_direction(muscle_direction)
def get_state(self):
    """
    Capture the entity state at the current substep.

    Returns
    -------
    state : FEMEntityState
        State object whose `pos`, `vel` and `active` buffers are filled by `get_frame`.
    """
    sim = self._sim
    snapshot = FEMEntityState(self, sim.cur_step_global)
    self.get_frame(sim.cur_substep_local, snapshot.pos, snapshot.vel, snapshot.active)

    # Every queried state is recorded so that gradient flow can be tracked.
    self._queried_states.append(snapshot)
    return snapshot
def deactivate(self):
    """Log the deactivation, queue `gs.INACTIVE` as the next activity target and clear `self.active`."""
    gs.logger.info(f"{self.__class__.__name__} <{self.id}> deactivated.")
    self._tgt["act"], self.active = gs.INACTIVE, False
def activate(self):
    """Log the activation, queue `gs.ACTIVE` as the next activity target and set `self.active`."""
    gs.logger.info(f"{self.__class__.__name__} <{self.id}> activated.")
    self._tgt["act"], self.active = gs.ACTIVE, True
# ------------------------------------------------------------------------------------
# ----------------------------------- instantiation ----------------------------------
# ------------------------------------------------------------------------------------
def instantiate(self, verts, elems):
    """
    Initialize FEM entity with given vertices and elements.

    The vertices are rotated about their centroid by the morph quaternion and stored in
    `init_positions`; `init_positions_COM_offset` holds the rotated positions relative to the centroid.

    Parameters
    ----------
    verts : np.ndarray
        Array of vertex positions with shape (n_vertices, 3).
    elems : np.ndarray
        Array of elements indexing into verts: shape (n_elements, 4) for tetrahedra
        (surface triangles are passed instead for cloth).

    Raises
    ------
    Exception
        If no vertices are provided.
    """
    verts = verts.astype(gs.np_float, copy=False)
    elems = elems.astype(gs.np_int, copy=False)

    # Validate before any arithmetic: the mean of an empty array is NaN and emits a RuntimeWarning.
    if verts.shape[0] == 0:
        gs.raise_exception("Entity has zero vertices.")

    # rotate about the centroid
    R = gu.quat_to_R(np.array(self.morph.quat, dtype=gs.np_float))
    verts_COM = verts.mean(axis=0)
    init_positions = (verts - verts_COM) @ R.T + verts_COM

    self.init_positions = gs.tensor(init_positions)
    self.init_positions_COM_offset = self.init_positions - gs.tensor(verts_COM)

    self.elems = elems
def sample(self):
    """
    Sample mesh and elements based on the entity's morph type.

    For Cloth material, the surface mesh is loaded directly and its faces are used as elements.
    For every other FEM material, the morph is converted to tetrahedral elements.

    Raises
    ------
    Exception
        If the morph type is unsupported for the material.
    """
    from genesis.engine.materials.FEM.cloth import Cloth as ClothMaterial

    morph_types = gs.options.morphs
    self._uvs = None

    if isinstance(self.material, ClothMaterial):
        # Cloth: load surface mesh directly (no tetrahedralization)
        if not isinstance(self.morph, morph_types.Mesh):
            gs.raise_exception(f"Cloth material only supports Mesh morph. Got: {self.morph}.")
        else:
            import trimesh

            mesh = trimesh.load_mesh(self._morph.file)
            verts = mesh.vertices * self._morph.scale + np.array(self._morph.pos)

            # For cloth, the faces are stored as "elements" (treated as surface elements)
            self.instantiate(verts, mesh.faces)

            # Load UVs from the mesh (1:1 mapping for cloth). UVs are not always present in a 3D file;
            # when they are missing the entity UVs stay None and the solver uses zero UVs for rendering.
            # With zero UVs no tangent directions can be recomputed, so texture mapping and anisotropic
            # surfaces will not work properly.
            self._uvs = None
            visual = mesh.visual
            if isinstance(visual, trimesh.visual.texture.TextureVisuals) and visual.uv is not None:
                self._uvs = visual.uv.astype(gs.np_float, copy=False)
    else:
        # Regular FEM: tetrahedralize mesh
        if isinstance(self.morph, morph_types.Sphere):
            verts, elems = eu.sphere_to_elements(
                pos=self._morph.pos, radius=self._morph.radius, tet_cfg=self.tet_cfg
            )
        elif isinstance(self.morph, morph_types.Box):
            verts, elems = eu.box_to_elements(pos=self._morph.pos, size=self._morph.size, tet_cfg=self.tet_cfg)
        elif isinstance(self.morph, morph_types.Cylinder):
            verts, elems = eu.cylinder_to_elements()
        elif isinstance(self.morph, morph_types.Mesh):
            # UVs need no processing here: tetrahedralization appends new vertices and faces at the end
            # of the vertex list, so the original UVs stay valid at the beginning. Newly created internal
            # vertices have no meaningful UVs, but they are not used for rendering.
            verts, elems, self._uvs = eu.mesh_to_elements(
                file=self._morph.file, pos=self._morph.pos, scale=self._morph.scale, tet_cfg=self.tet_cfg
            )
        else:
            gs.raise_exception(f"Unsupported morph: {self.morph}.")

        split_verts, split_elems = eu.split_all_surface_tets(verts, elems)
        self.instantiate(split_verts, split_elems)
def _add_to_solver(self, in_backward=False):
    """
    Register this entity's vertices, elements and surface with the solver and mark it active.

    Cloth entities are added through `_kernel_add_cloth_for_rendering`; all other materials go through
    `_kernel_add_elements`. When `in_backward` is False, the global step at which the entity was added
    is recorded and an info message is logged.
    """
    from genesis.engine.materials.FEM.cloth import Cloth as ClothMaterial

    if not in_backward:
        self._step_global_added = self._sim.cur_step_global
        gs.logger.info(
            f"Entity {self.uid} added. class: {self.__class__.__name__}, morph: {self.morph.__class__.__name__}, size: ({self.n_elements}, {self.n_vertices}), material: {self.material}."
        )

    # Convert to appropriate numpy array types; an empty (0, 2) array stands in for missing UVs
    verts_numpy = tensor_to_array(self.init_positions, dtype=gs.np_float)
    uvs_np = np.zeros((0, 2), dtype=gs.np_float) if self._uvs is None else self._uvs

    if isinstance(self.material, ClothMaterial):
        # Cloth: add only vertices and surfaces for rendering (no physics computation)
        gs.logger.info(
            f"Entity {self.uid} is cloth - adding to FEM solver for rendering only (physics managed by IPC)"
        )
        cloth_args = dict(
            f=self._sim.cur_substep_local,
            n_surfaces=self._n_surfaces,
            v_start=self._v_start,
            s_start=self._s_start,
            verts=verts_numpy,
            tri2v=self._surface_tri_np,
            uvs=uvs_np,
        )
        self._solver._kernel_add_cloth_for_rendering(**cloth_args)
    else:
        # Regular FEM: add vertices, elements, and surfaces for physics and rendering
        material = self._material
        fem_args = dict(
            f=self._sim.cur_substep_local,
            mat_idx=material.idx,
            mat_mu=material.mu,
            mat_lam=material.lam,
            mat_rho=material.rho,
            mat_friction_mu=material.friction_mu,
            n_surfaces=self._n_surfaces,
            v_start=self._v_start,
            el_start=self._el_start,
            s_start=self._s_start,
            verts=verts_numpy,
            elems=self.elems.astype(gs.np_int, copy=False),
            tri2v=self._surface_tri_np,
            tri2el=self._surface_el_np,
            uvs=uvs_np,
        )
        self._solver._kernel_add_elements(**fem_args)

    self.active = True
def compute_pressure_field(self):
    """
    Compute the per-vertex pressure field used for hydroelastic contact.

    The field is the unsigned distance of every vertex to the entity's surface triangles, normalized by
    the largest such distance and scaled by the material's hydroelastic modulus. The result is stored in
    `self.pressure_field_np`.

    For hydroelastic contact: https://drake.mit.edu/doxygen_cxx/group__hydroelastic__user__guide.html

    Notes
    -----
    https://github.com/RobotLocomotion/drake/blob/master/geometry/proximity/make_mesh_field.cc
    TODO: Add margin support
    Drake's implementation of margin seems buggy.

    Raises
    ------
    Exception
        If the largest vertex-to-surface distance is below `gs.EPS`.
    """
    rest_positions = tensor_to_array(self.init_positions)
    # Only the first value returned by igl (the signed distances) is needed.
    sdf = igl.signed_distance(rest_positions, rest_positions, self._surface_tri_np)[0]
    depth = np.abs(sdf.astype(gs.np_float, copy=False))

    max_distance = np.max(depth)
    if max_distance < gs.EPS:
        gs.raise_exception(
            f"Pressure field max distance is too small: {max_distance}. "
            "This might be due to a mesh having no internal vertices."
        )

    normalized_depth = depth / max_distance
    self.pressure_field_np = normalized_depth * self.material._hydroelastic_modulus
# ------------------------------------------------------------------------------------
# ---------------------------- checkpoint and buffer ---------------------------------
# ------------------------------------------------------------------------------------
def init_tgt_keys(self):
    """
    Define the names of the target quantities tracked for buffering and checkpointing.

    The keys are velocity ("vel"), position ("pos"), activity ("act") and actuation ("actu").
    """
    self._tgt_keys = list(("vel", "pos", "act", "actu"))
def init_tgt_vars(self):
    """
    Create the target variables and their per-step buffers.

    `_tgt` maps every target key to the value queued for the next step (initially None) and
    `_tgt_buffer` maps every target key to its own empty list.
    """
    # temp variables holding the targets for the next step and their history
    self._tgt = {}
    self._tgt_buffer = {}

    self.init_tgt_keys()

    self._tgt.update(dict.fromkeys(self._tgt_keys))
    self._tgt_buffer.update((key, []) for key in self._tgt_keys)
def init_ckpt(self):
    """Reset the checkpoint storage to an empty dictionary."""
    self._ckpt = {}
def save_ckpt(self, ckpt_name):
    """
    Store a copy of the current target buffers under a named checkpoint.

    Parameters
    ----------
    ckpt_name : str
        The name to identify the checkpoint. An existing checkpoint of that name is overwritten
        key by key.

    Notes
    -----
    After saving, the live target buffers are emptied in place to prepare for new input.
    """
    saved = self._ckpt.setdefault(ckpt_name, {"_tgt_buffer": dict()})["_tgt_buffer"]

    for key in self._tgt_keys:
        live = self._tgt_buffer[key]
        saved[key] = list(live)
        live.clear()
def load_ckpt(self, ckpt_name):
    """
    Restore the target buffers from a named checkpoint.

    Each buffer is replaced by a fresh copy, so the stored checkpoint is not aliased.

    Parameters
    ----------
    ckpt_name : str
        The name of the checkpoint to load.

    Raises
    ------
    KeyError
        If the checkpoint name is not found.
    """
    for key in self._tgt_keys:
        stored = self._ckpt[ckpt_name]["_tgt_buffer"][key]
        self._tgt_buffer[key] = [*stored]
def reset_grad(self):
    """
    Drop everything recorded for gradient tracking.

    Empties every target buffer in place and clears the list of queried states.
    """
    buffers = self._tgt_buffer
    for name in self._tgt_keys:
        buffers[name].clear()

    self._queried_states.clear()
def _assert_active(self):
if not self.active:
gs.raise_exception(f"{self.__class__.__name__} is inactive. Call `entity.activate()` first.")
# ------------------------------------------------------------------------------------
# ---------------------------- interfacing with solver -------------------------------
# ------------------------------------------------------------------------------------
def set_pos(self, f, pos):
    """
    Write vertex positions of this entity into the solver.

    Parameters
    ----------
    f : int
        Current substep/frame index.
    pos : gs.Tensor
        Tensor of shape (n_envs, n_vertices, 3) containing new positions.
    """
    kernel = self._solver._kernel_set_elements_pos
    kernel(f=f, element_v_start=self._v_start, n_vertices=self.n_vertices, pos=pos)
def set_pos_grad(self, f, pos_grad):
    """
    Write gradients of this entity's vertex positions into the solver.

    Parameters
    ----------
    f : int
        Current substep/frame index.
    pos_grad : gs.Tensor
        Tensor of shape (n_envs, n_vertices, 3) containing gradients of positions.
    """
    kernel = self._solver._kernel_set_elements_pos_grad
    kernel(f=f, element_v_start=self._v_start, n_vertices=self.n_vertices, pos_grad=pos_grad)
def set_vel(self, f, vel):
    """
    Write vertex velocities of this entity into the solver.

    Parameters
    ----------
    f : int
        Current substep/frame index.
    vel : gs.Tensor
        Tensor of shape (n_envs, n_vertices, 3) containing velocities.
    """
    kernel = self._solver._kernel_set_elements_vel
    kernel(f=f, element_v_start=self._v_start, n_vertices=self.n_vertices, vel=vel)
def set_vel_grad(self, f, vel_grad):
    """
    Write gradients of this entity's vertex velocities into the solver.

    Parameters
    ----------
    f : int
        Current substep/frame index.
    vel_grad : gs.Tensor
        Tensor of shape (n_envs, n_vertices, 3) containing gradients of velocities.
    """
    kernel = self._solver._kernel_set_elements_vel_grad
    kernel(f=f, element_v_start=self._v_start, n_vertices=self.n_vertices, vel_grad=vel_grad)
def set_actu(self, f, actu):
    """
    Write actuation values for this entity's elements into the solver.

    Parameters
    ----------
    f : int
        Current substep/frame index.
    actu : gs.Tensor
        Tensor of shape (n_envs, n_groups) specifying actuation values.
    """
    kernel_args = dict(
        f=f,
        element_el_start=self._el_start,
        n_elements=self.n_elements,
        n_groups=self.material.n_groups,
        actu=actu,
    )
    self._solver._kernel_set_elements_actu(**kernel_args)
def set_actu_grad(self, f, actu_grad):
    """
    Set gradient of actuation values in the solver.

    Dispatches to the solver's dedicated gradient kernel, mirroring `set_pos_grad` and
    `set_vel_grad`, rather than to the value-setting actuation kernel.

    Parameters
    ----------
    f : int
        Current substep/frame index.
    actu_grad : gs.Tensor
        Tensor of shape (n_envs, n_groups) specifying gradients of actuation.
    """
    self._solver._kernel_set_elements_actu_grad(
        f=f,
        element_el_start=self._el_start,
        n_elements=self.n_elements,
        actu_grad=actu_grad,
    )
def set_active(self, f, active):
    """
    Write the activity flag for this entity's elements into the solver.

    Parameters
    ----------
    f : int
        Current substep/frame index.
    active : int
        Activity flag (gs.ACTIVE or gs.INACTIVE).
    """
    kernel = self._solver._kernel_set_active
    kernel(f=f, element_el_start=self._el_start, n_elements=self.n_elements, active=active)
@assert_muscle
def set_muscle_group(self, muscle_group):
    """
    Write the muscle group index of each element into the solver.

    Parameters
    ----------
    muscle_group : torch.Tensor
        Tensor of shape (n_elements,) specifying muscle group IDs.
    """
    kernel = self._solver._kernel_set_muscle_group
    kernel(element_el_start=self._el_start, n_elements=self.n_elements, muscle_group=muscle_group)
@assert_muscle
def set_muscle_direction(self, muscle_direction):
    """
    Write the muscle force direction of each element into the solver.

    Parameters
    ----------
    muscle_direction : torch.Tensor
        Tensor of shape (n_elements, 3) with unit direction vectors.
    """
    kernel = self._solver._kernel_set_muscle_direction
    kernel(element_el_start=self._el_start, n_elements=self.n_elements, muscle_direction=muscle_direction)
def set_vertex_constraints(
    self, verts_idx_local, target_poss=None, link=None, is_soft_constraint=False, stiffness=0.0, envs_idx=None
):
    """
    Set vertex constraints for specified vertices.

    Parameters
    ----------
    verts_idx_local : array_like
        List of local vertex indices to constrain.
    target_poss : array_like, shape (len(verts_idx), 3), optional
        List of target positions [x, y, z] for each vertex. If not provided, the positions of the
        vertices at the current substep are used.
    link : RigidLink
        Optional rigid link for the vertices to follow, maintaining relative position.
    is_soft_constraint: bool
        By default, use a hard constraint directly sets position and zero velocity.
        A soft constraint uses a spring force to pull the vertex towards the target position.
    stiffness : float
        Specify a spring stiffness for a soft constraint. Critical damping is applied.
    envs_idx : array_like, optional
        List of environment indices to apply the constraints to. If None, applies to all environments.

    Raises
    ------
    Exception
        If the implicit solver is used without vertex constraints enabled.
    AssertionError
        If `link` is given but is not a `RigidLink`.
    """
    if self._solver._use_implicit_solver and not self._solver._enable_vertex_constraints:
        gs.raise_exception(
            "This feature is disabled. Please set 'enable_vertex_constraints=True' when using FEM implicit solver."
        )

    # Constraint storage is allocated lazily on first use
    if not self._solver._constraints_initialized:
        self._solver.init_constraints()

    use_current_poss = target_poss is None
    envs_idx = self._scene._sanitize_envs_idx(envs_idx)
    verts_idx_local = self._sanitize_verts_idx_local(verts_idx_local, envs_idx)
    verts_idx = verts_idx_local + self._v_start
    target_poss = self._sanitize_verts_tensor(target_poss, gs.tc_float, verts_idx, envs_idx, (3,))
    if use_current_poss:
        # The kernel adds `v_start` itself, so it must receive LOCAL indices; passing the global
        # `verts_idx` would apply the offset twice.
        self._kernel_get_verts_pos(self._sim.cur_substep_local, target_poss, verts_idx_local)

    if link is None:
        # -1 marks "no link"; the initial pose buffers are placeholders in that case
        link_idx = -1
        link_init_pos = torch.zeros((self._sim._B, 3), dtype=gs.tc_float, device=gs.device)
        link_init_quat = torch.zeros((self._sim._B, 4), dtype=gs.tc_float, device=gs.device)
    else:
        assert isinstance(link, RigidLink), "Only RigidLink is supported for vertex constraints."
        link_idx = link.idx
        link_init_pos = link.get_pos()
        link_init_quat = link.get_quat()
        if self._scene.n_envs == 0:
            # Add the leading batch dimension for non-batched scenes
            link_init_pos = link_init_pos[None]
            link_init_quat = link_init_quat[None]

    self._solver._kernel_set_vertex_constraints(
        self._sim.cur_substep_local,
        verts_idx,
        target_poss,
        is_soft_constraint,
        stiffness,
        link_idx,
        link_init_pos,
        link_init_quat,
        envs_idx,
    )
def update_constraint_targets(self, verts_idx_local, target_poss, envs_idx=None):
    """
    Update target positions for existing constraints.

    Does nothing except log a warning when the solver's constraints have not been initialized.
    `target_poss` must not be None.
    """
    solver = self._solver
    if not solver._constraints_initialized:
        gs.logger.warning("Ignoring update_constraint_targets; constraints have not been initialized.")
        return

    assert target_poss is not None

    envs_idx = self._scene._sanitize_envs_idx(envs_idx)
    local_idx = self._sanitize_verts_idx_local(verts_idx_local, envs_idx)
    global_idx = local_idx + self._v_start
    targets = self._sanitize_verts_tensor(target_poss, gs.tc_float, global_idx, envs_idx, (3,))
    solver._kernel_update_constraint_targets(global_idx, targets, envs_idx)
def remove_vertex_constraints(self, verts_idx_local=None, envs_idx=None):
    """
    Remove constraints from specified vertices, or all if None.

    When `verts_idx_local` is None the solver's whole `is_constrained` field is zeroed and
    `envs_idx` is not consulted. Does nothing except log a warning when the solver's constraints
    have not been initialized.
    """
    solver = self._solver
    if not solver._constraints_initialized:
        gs.logger.warning("Ignoring remove_vertex_constraints; constraints have not been initialized.")
        return

    if verts_idx_local is not None:
        envs_idx = self._scene._sanitize_envs_idx(envs_idx)
        local_idx = self._sanitize_verts_idx_local(verts_idx_local, envs_idx)
        solver._kernel_remove_specific_constraints(local_idx + self._v_start, envs_idx)
        return

    # FIXME: GsTaichi 'fill' method is very inefficient. Try using zero-copy if possible.
    solver.vertex_constraints.is_constrained.fill(0)
@ti.kernel
def _kernel_get_verts_pos(self, f: ti.i32, pos: ti.types.ndarray(), verts_idx: ti.types.ndarray()):
    """
    Copy the positions of selected vertices at substep `f` into `pos`.

    Parameters
    ----------
    f : int
        The substep/frame index to read from.
    pos : ndarray
        Output array indexed as [env, vertex, xyz].
    verts_idx : ndarray
        Entity-local vertex indices laid out as [env, vertex], i.e. the layout produced by
        `_sanitize_verts_idx_local` when `envs_idx` is given. `v_start` is added here.
    """
    # Dimension 0 of `verts_idx` is the environment and dimension 1 the vertex, matching `pos`.
    # NOTE(review): the row number is used directly as the solver batch index, which presumes the rows
    # enumerate environments 0..n-1 in order; a subset `envs_idx` would need to be passed in - confirm.
    for i_b, i_v in ti.ndrange(verts_idx.shape[0], verts_idx.shape[1]):
        i_global = verts_idx[i_b, i_v] + self.v_start
        for j in ti.static(range(3)):
            pos[i_b, i_v, j] = self._solver.elements_v[f, i_global, i_b].pos[j]
def get_el2v(self):
    """
    Fetch the element-to-vertex mapping of this entity from the solver.

    Returns
    -------
    el2v : gs.Tensor
        Integer tensor of shape (n_elements, 4) filled by the solver with the vertex indices of
        each element.
    """
    n_elements = self.n_elements
    mapping = gs.zeros((n_elements, 4), dtype=int, requires_grad=False, scene=self.scene)
    self._solver._kernel_get_el2v(element_el_start=self._el_start, n_elements=n_elements, el2v=mapping)
    return mapping
@ti.kernel
def get_frame(self, f: ti.i32, pos: ti.types.ndarray(), vel: ti.types.ndarray(), active: ti.types.ndarray()):
    """
    Copy positions, velocities and activity flags of this entity at substep `f` into output arrays.

    Parameters
    ----------
    f : int
        The substep/frame index to fetch the state from.
    pos : np.ndarray
        Output array of shape (n_envs, n_vertices, 3) to store positions.
    vel : np.ndarray
        Output array of shape (n_envs, n_vertices, 3) to store velocities.
    active : np.ndarray
        Output array of shape (n_envs, n_elements) to store active flags.
    """
    # Per-vertex quantities: entity-local vertex index -> global index through `v_start`
    for i_v, i_b in ti.ndrange(self.n_vertices, self._sim._B):
        i_vertex = i_v + self.v_start
        for j in ti.static(range(3)):
            pos[i_b, i_v, j] = self._solver.elements_v[f, i_vertex, i_b].pos[j]
            vel[i_b, i_v, j] = self._solver.elements_v[f, i_vertex, i_b].vel[j]

    # Per-element quantities: entity-local element index -> global index through `el_start`
    for i_e, i_b in ti.ndrange(self.n_elements, self._sim._B):
        i_element = i_e + self.el_start
        active[i_b, i_e] = self._solver.elements_el_ng[f, i_element, i_b].active
@ti.kernel
def clear_grad(self, f: ti.i32):
    """
    Reset the position, velocity and actuation gradients of this entity at substep `f` to zero.

    Parameters
    ----------
    f : int
        The substep/frame index for which to clear gradients.

    Notes
    -----
    This method is primarily used during backward passes to manually reset gradients
    that may be corrupted by explicit state setting.
    """
    # TODO: not well-tested
    # Vertex gradients (position and velocity)
    for i_v, i_b in ti.ndrange(self.n_vertices, self._sim._B):
        i_vertex = i_v + self.v_start
        self._solver.elements_v.grad[f, i_vertex, i_b].pos = 0
        self._solver.elements_v.grad[f, i_vertex, i_b].vel = 0

    # Element gradients (actuation)
    for i_e, i_b in ti.ndrange(self.n_elements, self._sim._B):
        i_element = i_e + self.el_start
        self._solver.elements_el.grad[f, i_element, i_b].actu = 0
# ------------------------------------------------------------------------------------
# --------------------------------- naming methods -----------------------------------
# ------------------------------------------------------------------------------------
def _get_morph_identifier(self) -> str:
    """
    Return a short name derived from the morph type.

    Primitive morphs map to a fixed label, a mesh morph to "fem_" followed by the stem of its file
    name, and any other morph to "fem_entity".
    """
    morph = self._morph

    primitive_labels = (
        (gs.morphs.Box, "fem_box"),
        (gs.morphs.Sphere, "fem_sphere"),
        (gs.morphs.Cylinder, "fem_cylinder"),
    )
    for morph_cls, label in primitive_labels:
        if isinstance(morph, morph_cls):
            return label

    if isinstance(morph, gs.morphs.Mesh):
        return f"fem_{Path(morph.file).stem}"
    return "fem_entity"
# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
@property
def n_vertices(self):
    """Number of vertices, i.e. the length of `init_positions`."""
    return len(self.init_positions)

@property
def n_elements(self):
    """Number of elements (tetrahedra, or surface triangles for cloth), i.e. the length of `elems`."""
    return len(self.elems)

@property
def n_surfaces(self):
    """Number of surface triangles of the entity."""
    return self._n_surfaces

@property
def v_start(self):
    """Offset of this entity's vertices in the global vertex array."""
    return self._v_start

@property
def el_start(self):
    """Offset of this entity's elements in the global element array."""
    return self._el_start

@property
def s_start(self):
    """Offset of this entity's surface triangles in the global surface array."""
    return self._s_start

@property
def morph(self):
    """Morph specification the entity was created from."""
    return self._morph

@property
def material(self):
    """Material assigned to the entity."""
    return self._material

@property
def surface(self):
    """Surface used for rendering."""
    return self._surface

@property
def n_surface_vertices(self):
    """Number of distinct vertices referenced by the surface triangles."""
    return self._n_surface_vertices

@property
def surface_triangles(self):
    """Surface triangles as an array of vertex indices."""
    return self._surface_tri_np

@property
def uvs(self):
    """UV coordinates for this entity's vertices, or None if not available."""
    return self._uvs

@property
def tet_cfg(self):
    """Tetrahedralization configuration generated from the morph on every access."""
    return mu.generate_tetgen_config_from_morph(self.morph)