import math
import numpy as np
import torch
import genesis as gs
from genesis.repr_base import RBC
from genesis.constants import IMAGE_TYPE
from genesis.utils.misc import ti_to_torch
from .rasterizer_context import SegmentationColorMap
# Optional imports for platform-specific functionality
try:
from gs_madrona.renderer_gs import MadronaBatchRendererAdapter
_MADRONA_AVAILABLE = True
except ImportError:
MadronaBatchRendererAdapter = None
_MADRONA_AVAILABLE = False
def _transform_camera_quat(quat):
# quat for Madrona needs to be transformed to y-forward
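    # This right-multiplies the input by the quaternion (1, -1, 0, 0) / sqrt(2)
    # (wxyz order), i.e. a -90 degree rotation about the camera's local x-axis,
    # mapping a -z-forward camera frame to Madrona's y-forward convention.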
w, x, y, z = torch.unbind(quat, dim=-1)
return torch.stack([x + w, x - w, y - z, y + z], dim=-1) / math.sqrt(2.0)
def _make_tensor(data, *, dtype: torch.dtype = torch.float32):
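    # Small helper: build a tensor directly on the global Genesis device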
return torch.tensor(data, dtype=dtype, device=gs.device)
class GenesisGeomRetriever:
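    """
    Flattens Genesis rigid-body visual geometry (meshes, materials, textures)
    into the packed arrays consumed by the Madrona batch renderer, and
    maintains the segmentation color map.
    """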
def __init__(self, rigid_solver, seg_level):
self.rigid_solver = rigid_solver
self.seg_color_map = SegmentationColorMap(to_torch=True)
self.seg_level = seg_level
self.geom_idxc = None
self.default_geom_group = 2
self.default_enabled_geom_groups = np.array([self.default_geom_group], dtype=np.int32)
def build(self):
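        # Assign each visual geom a segmentation index according to the
        # configured segmentation level (geom / link / entity)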
self.n_vgeoms = self.rigid_solver.n_vgeoms
self.geom_idxc = []
vgeoms = self.rigid_solver.vgeoms
for vgeom in vgeoms:
seg_key = self.get_seg_key(vgeom)
seg_idxc = self.seg_color_map.seg_key_to_idxc(seg_key)
self.geom_idxc.append(seg_idxc)
self.geom_idxc = torch.tensor(self.geom_idxc, dtype=torch.int32, device=gs.device)
self.seg_color_map.generate_seg_colors()
def get_seg_key(self, vgeom):
if self.seg_level == "geom":
return (vgeom.entity.idx, vgeom.link.idx, vgeom.idx)
elif self.seg_level == "link":
return (vgeom.entity.idx, vgeom.link.idx)
elif self.seg_level == "entity":
return vgeom.entity.idx
else:
gs.raise_exception(f"Unsupported segmentation level: {self.seg_level}")
# FIXME: Use a kernel to do it efficiently
def retrieve_rigid_meshes_static(self):
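        # Flatten all visual geoms into packed vertex/face/UV arrays plus
        # per-geom metadata (types, groups, sizes, materials) for Madrona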
args = {}
vgeoms = self.rigid_solver.vgeoms
# Retrieve geom data
mesh_vertices = self.rigid_solver.vverts_info.init_pos.to_numpy()
mesh_faces = self.rigid_solver.vfaces_info.vverts_idx.to_numpy()
mesh_vertex_offsets = self.rigid_solver.vgeoms_info.vvert_start.to_numpy()
mesh_face_starts = self.rigid_solver.vgeoms_info.vface_start.to_numpy()
mesh_face_ends = self.rigid_solver.vgeoms_info.vface_end.to_numpy()
total_uv_size = 0
mesh_uvs = []
mesh_uv_offsets = []
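        # Face indices are stored globally; rebase them so that each mesh's
        # faces index into its own vertex block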
for i in range(self.n_vgeoms):
mesh_faces[mesh_face_starts[i] : mesh_face_ends[i]] -= mesh_vertex_offsets[i]
geom_data_ids = []
for vgeom in vgeoms:
seg_key = self.get_seg_key(vgeom)
seg_id = self.seg_color_map.seg_key_to_idxc(seg_key)
geom_data_ids.append(seg_id)
if vgeom.uvs is not None:
mesh_uvs.append(vgeom.uvs.astype(np.float32))
mesh_uv_offsets.append(total_uv_size)
total_uv_size += vgeom.uvs.shape[0]
else:
mesh_uv_offsets.append(-1)
args["mesh_vertices"] = mesh_vertices
args["mesh_vertex_offsets"] = mesh_vertex_offsets
args["mesh_faces"] = mesh_faces
args["mesh_face_offsets"] = mesh_face_starts
args["mesh_texcoords"] = np.concatenate(mesh_uvs, axis=0) if mesh_uvs else np.empty((0, 2), np.float32)
args["mesh_texcoord_offsets"] = np.array(mesh_uv_offsets, np.int32)
args["geom_types"] = np.full((self.n_vgeoms,), 7, dtype=np.int32) # 7 stands for mesh
args["geom_groups"] = np.full((self.n_vgeoms,), self.default_geom_group, dtype=np.int32)
args["geom_data_ids"] = np.arange(self.n_vgeoms, dtype=np.int32)
args["geom_sizes"] = np.ones((self.n_vgeoms, 3), dtype=np.float32)
args["enabled_geom_groups"] = self.default_enabled_geom_groups
# Retrieve material data
geom_mat_ids = []
num_materials = 0
materials_indices = {}
mat_rgbas = []
mat_texture_indices = []
mat_texture_offsets = []
total_mat_textures = 0
num_textures = 0
texture_indices = {}
texture_widths = []
texture_heights = []
texture_nchans = []
texture_offsets = []
texture_data = []
total_texture_size = 0
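        # Deduplicate image textures by file path and untextured materials by
        # packed RGBA key, accumulating flat texture/material tables for Madrona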
for vgeom in vgeoms:
geom_surface = vgeom.surface
geom_textures = geom_surface.get_rgba(batch=True).textures
geom_texture_indices = []
for geom_texture in geom_textures:
if isinstance(geom_texture, gs.textures.ImageTexture) and geom_texture.image_array is not None:
texture_id = geom_texture.image_path
if texture_id not in texture_indices:
texture_idx = num_textures
if texture_id is not None:
texture_indices[texture_id] = texture_idx
texture_widths.append(geom_texture.image_array.shape[1])
texture_heights.append(geom_texture.image_array.shape[0])
assert geom_texture.channel() == 4
texture_nchans.append(geom_texture.channel())
texture_offsets.append(total_texture_size)
texture_data.append(geom_texture.image_array.flat)
num_textures += 1
total_texture_size += geom_texture.image_array.size
else:
texture_idx = texture_indices[texture_id]
geom_texture_indices.append(texture_idx)
# TODO: support batch rgba
geom_rgbas = [
geom_texture.image_color if isinstance(geom_texture, gs.textures.ImageTexture) else geom_texture.color
for geom_texture in geom_textures
]
for i in range(1, len(geom_rgbas)):
if not np.allclose(geom_rgbas[0], geom_rgbas[i], atol=gs.EPS):
                    gs.logger.warning("Batched colors are not yet supported; using the first texture's color instead.")
break
geom_rgba = geom_rgbas[0]
mat_id = None
if len(geom_texture_indices) == 0:
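                # Pack the RGBA color into a single 32-bit key so that geoms
                # sharing the same color also share a single material entry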
geom_rgba_int = (np.array(geom_rgba) * 255.0).astype(np.uint32)
mat_id = geom_rgba_int[0] << 24 | geom_rgba_int[1] << 16 | geom_rgba_int[2] << 8 | geom_rgba_int[3]
if mat_id not in materials_indices:
material_idx = num_materials
if mat_id is not None:
materials_indices[mat_id] = material_idx
mat_rgbas.append(geom_rgba)
mat_texture_indices.extend(geom_texture_indices)
mat_texture_offsets.append(total_mat_textures)
num_materials += 1
total_mat_textures += len(geom_texture_indices)
else:
material_idx = materials_indices[mat_id]
geom_mat_ids.append(material_idx)
args["geom_mat_ids"] = np.array(geom_mat_ids, np.int32)
args["tex_widths"] = np.array(texture_widths, np.int32)
args["tex_heights"] = np.array(texture_heights, np.int32)
args["tex_nchans"] = np.array(texture_nchans, np.int32)
args["tex_data"] = np.concatenate(texture_data, axis=0) if texture_data else np.array([], np.uint8)
args["tex_offsets"] = np.array(texture_offsets, np.int64)
args["mat_rgba"] = np.array(mat_rgbas, np.float32)
args["mat_tex_ids"] = np.array(mat_texture_indices, np.int32)
args["mat_tex_offsets"] = np.array(mat_texture_offsets, np.int32)
return args
# FIXME: Use a kernel to do it efficiently
def retrieve_rigid_property_torch(self, num_worlds):
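        # Per-world property tensors are currently constant placeholders:
        # no per-world colors, -1 material ids, and unit sizes for all worlds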
geom_rgb = torch.empty((0, self.n_vgeoms), dtype=torch.uint32, device=gs.device)
geom_mat_ids = torch.full((num_worlds, self.n_vgeoms), -1, dtype=torch.int32, device=gs.device)
geom_sizes = torch.ones((self.n_vgeoms, 3), dtype=torch.float32, device=gs.device)
geom_sizes = geom_sizes[None].repeat(num_worlds, 1, 1)
return geom_mat_ids, geom_rgb, geom_sizes
# FIXME: Use a kernel to do it efficiently
def retrieve_rigid_state_torch(self):
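        # Fetch per-geom world-frame poses from the Taichi solver state and
        # move the batch (world) dimension first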
geom_pos = ti_to_torch(self.rigid_solver.vgeoms_state.pos)
geom_rot = ti_to_torch(self.rigid_solver.vgeoms_state.quat)
geom_pos = geom_pos.transpose(0, 1).contiguous()
geom_rot = geom_rot.transpose(0, 1).contiguous()
return geom_pos, geom_rot
class Light:
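    """
    Container for the parameters of a single light source. `dir` is normalized
    on construction and `cutoff` is stored in degrees.
    """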
def __init__(self, pos, dir, color, intensity, directional, castshadow, cutoff, attenuation):
self._pos = pos
self._dir = tuple(dir / np.linalg.norm(dir))
self._color = color
self._intensity = intensity
self._directional = directional
self._castshadow = castshadow
self._cutoff = cutoff
self._attenuation = attenuation
@property
def pos(self):
return self._pos
@property
def dir(self):
return self._dir
@property
def color(self):
return self._color
@property
def intensity(self):
return self._intensity
@property
def directional(self):
return self._directional
@property
def castshadow(self):
return self._castshadow
@property
def cutoffRad(self):
return math.radians(self._cutoff)
@property
def cutoffDeg(self):
return self._cutoff
@property
def attenuation(self):
return self._attenuation
class BatchRenderer(RBC):
    """
    Manages batched rendering of all the non-debug cameras of a scene through
    the Madrona renderer.
    """
def __init__(self, visualizer, renderer_options, vis_options):
self._visualizer = visualizer
self._lights = gs.List()
self._use_rasterizer = renderer_options.use_rasterizer
self._renderer = None
self._geom_retriever = GenesisGeomRetriever(self._visualizer.scene.rigid_solver, vis_options.segmentation_level)
self._data_cache = {}
self._t = -1
    def add_light(self, pos, dir, color, intensity, directional, castshadow, cutoff, attenuation):
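        """Register a light source; its parameters are uploaded to Madrona during `build`."""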
self._lights.append(Light(pos, dir, color, intensity, directional, castshadow, cutoff, attenuation))
    def build(self):
        """
        Build all cameras in the batch and initialize the Madrona renderer.
        """
if not _MADRONA_AVAILABLE:
gs.raise_exception("Madrona batch renderer is only supported on Linux x86-64.")
if gs.backend != gs.cuda:
gs.raise_exception("BatchRenderer requires CUDA backend.")
        # Resolve the CUDA device index that Madrona renders on (default to GPU 0)
        gpu_id = gs.device.index if gs.device.index is not None else 0
# Extract the complete list of non-debug cameras
self._cameras = gs.List([camera for camera in self._visualizer._cameras if not camera.debug])
if not self._cameras:
gs.raise_exception("Please add at least one camera when using BatchRender.")
# Build the geometry retriever
self._geom_retriever.build()
# Make sure that all cameras have identical resolution
try:
((camera_width, camera_height),) = set(camera.res for camera in self._cameras)
except ValueError as e:
            gs.raise_exception_from("All cameras must have exactly the same resolution when using BatchRenderer.", e)
self._renderer = MadronaBatchRendererAdapter(
geom_retriever=self._geom_retriever,
            gpu_id=gpu_id,
num_worlds=max(self._visualizer.scene.n_envs, 1),
num_lights=len(self._lights),
cam_fovs_tensor=_make_tensor([camera.fov for camera in self._cameras]),
cam_znears_tensor=_make_tensor([camera.near for camera in self._cameras]),
            cam_zfars_tensor=_make_tensor([camera.far for camera in self._cameras]),
cam_proj_types_tensor=_make_tensor(
[camera.model == "fisheye" for camera in self._cameras], dtype=torch.uint32
),
batch_render_view_width=camera_width,
batch_render_view_height=camera_height,
add_cam_debug_geo=False,
use_rasterizer=self._use_rasterizer,
)
self._renderer.init(
cam_pos_tensor=torch.stack([torch.atleast_2d(camera.get_pos()) for camera in self._cameras], dim=1),
cam_rot_tensor=_transform_camera_quat(
torch.stack([torch.atleast_2d(camera.get_quat()) for camera in self._cameras], dim=1)
),
lights_pos_tensor=_make_tensor([light.pos for light in self._lights]).reshape((-1, 3)),
lights_dir_tensor=_make_tensor([light.dir for light in self._lights]).reshape((-1, 3)),
lights_rgb_tensor=_make_tensor([light.color for light in self._lights]).reshape((-1, 3)),
lights_directional_tensor=_make_tensor([light.directional for light in self._lights], dtype=torch.bool),
lights_castshadow_tensor=_make_tensor([light.castshadow for light in self._lights], dtype=torch.bool),
lights_cutoff_tensor=_make_tensor([light.cutoffRad for light in self._lights]),
lights_attenuation_tensor=_make_tensor([light.attenuation for light in self._lights]),
lights_intensity_tensor=_make_tensor([light.intensity for light in self._lights]),
)
    def update_scene(self, force_render: bool = False):
self._visualizer._context.update(force_render)
    def render(self, rgb=True, depth=False, segmentation=False, normal=False, antialiasing=False, force_render=False):
"""
Render all cameras in the batch.
Parameters
----------
rgb : bool, optional
            Whether to render the RGB image.
depth : bool, optional
Whether to render the depth image.
segmentation : bool, optional
Whether to render the segmentation image.
normal : bool, optional
Whether to render the normal image.
antialiasing : bool, optional
Whether to apply anti-aliasing.
force_render : bool, optional
            Whether to force rendering the scene even if cached data is available.
Returns
-------
rgb_arr : tuple of arrays
            The sequence of RGB images associated with each camera.
depth_arr : tuple of arrays
The sequence of depth images associated with each camera.
segmentation_arr : tuple of arrays
The sequence of segmentation images associated with each camera.
normal_arr : tuple of arrays
The sequence of normal images associated with each camera.
"""
# Clear cache if requested or necessary
if force_render or self._t < self._visualizer.scene.t:
self._data_cache.clear()
# Fetch available cached data
request = (rgb, depth, segmentation, normal)
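        # Cached images are keyed by (image type, rendering options); currently
        # anti-aliasing is the only option affecting the output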
cache_key = (antialiasing,)
cached = [self._data_cache.get((img_type, cache_key), None) for img_type in IMAGE_TYPE]
        # Determine which of the requested image types still need to be rendered
needed = tuple(req and arr is None for req, arr in zip(request, cached))
# Early return if everything requested is already cached
if not any(needed):
return tuple(arr if req else None for req, arr in zip(request, cached))
# Update scene
self.update_scene(force_render)
# Render only what is needed (flags still passed to renderer)
cameras_pos = torch.stack([torch.atleast_2d(camera.get_pos()) for camera in self._cameras], dim=1)
cameras_quat = torch.stack([torch.atleast_2d(camera.get_quat()) for camera in self._cameras], dim=1)
cameras_quat = _transform_camera_quat(cameras_quat)
render_flags = np.array(
(
*(
needed[img_type]
for img_type in (IMAGE_TYPE.RGB, IMAGE_TYPE.DEPTH, IMAGE_TYPE.NORMAL, IMAGE_TYPE.SEGMENTATION)
),
antialiasing,
),
dtype=np.uint32,
)
rendered = list(self._renderer.render(cameras_pos, cameras_quat, render_flags))
        # Convert raw geom indices in the segmentation image to segmentation indices (seg_idxc)
if needed[IMAGE_TYPE.SEGMENTATION]:
seg_geoms = rendered[IMAGE_TYPE.SEGMENTATION]
mask = seg_geoms != -1
seg_geoms[mask] = self._geom_retriever.geom_idxc[seg_geoms[mask]]
seg_geoms[~mask] = 0
# Post-processing:
# * Remove alpha channel from RGBA
# * Squeeze env and channel dims if necessary
# * Split along camera dim
for img_type, data in enumerate(rendered):
if needed[img_type]:
data = data.swapaxes(0, 1)
if self._visualizer.scene.n_envs == 0:
data = data.squeeze(1)
rendered[img_type] = tuple(data[..., :3].squeeze(-1))
# Convert center distance depth to plane distance
if not self._use_rasterizer and needed[IMAGE_TYPE.DEPTH]:
rendered[IMAGE_TYPE.DEPTH] = tuple(
camera.distance_center_to_plane(depth_data)
for camera, depth_data in zip(self._cameras, rendered[IMAGE_TYPE.DEPTH])
)
# Update cache
self._t = self._visualizer.scene.t
for img_type, data in enumerate(rendered):
if needed[img_type]:
self._data_cache[(img_type, cache_key)] = rendered[img_type]
# Return in the required order, or None if not requested
return tuple(self._data_cache[(img_type, cache_key)] if needed[img_type] else None for img_type in IMAGE_TYPE)
    def colorize_seg_idxc_arr(self, seg_idxc_arr):
return self._geom_retriever.seg_color_map.colorize_seg_idxc_arr(seg_idxc_arr)
    def destroy(self):
self._lights.clear()
self._data_cache.clear()
if self._renderer is not None:
del self._renderer.madrona
self._renderer = None
    def reset(self):
self._t = -1
@property
def lights(self):
return self._lights
@property
def cameras(self):
return self._cameras
@property
def seg_idxc_map(self):
return self._geom_retriever.seg_color_map.idxc_map