Source code for genesis.vis.viewer

import importlib
import os
import threading
from traceback import TracebackException
from typing import TYPE_CHECKING

import numpy as np
import OpenGL.error
import OpenGL.platform

import genesis as gs
import genesis.utils.geom as gu
from genesis.ext import pyrender
from genesis.repr_base import RBC
from genesis.utils.misc import redirect_libc_stderr, tensor_to_array
from genesis.utils.tools import Rate
from genesis.vis.keybindings import Key, KeyAction, Keybind, KeyMod
from genesis.vis.viewer_plugins import DefaultControlsPlugin

if TYPE_CHECKING:
    from genesis.options.vis import ViewerOptions
    from genesis.vis.viewer_plugins import ViewerPlugin


class ViewerLock:
    """
    Context manager that holds the wrapped viewer's `render_lock` for the duration of a `with` block.

    Parameters
    ----------
    pyrender_viewer : object
        Object exposing a `render_lock` attribute with `acquire()` and `release()` methods.
    """

    def __init__(self, pyrender_viewer):
        # Only the viewer is stored: its `render_lock` is looked up again on every enter / exit.
        self._pyrender_viewer = pyrender_viewer

    def __enter__(self):
        # Nothing is returned, so `with lock as x` binds `x` to None.
        render_lock = self._pyrender_viewer.render_lock
        render_lock.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None lets any exception raised inside the `with` body propagate.
        render_lock = self._pyrender_viewer.render_lock
        render_lock.release()


class Viewer(RBC):
    """
    Interactive viewer wrapping a `pyrender.Viewer` window.

    Parameters
    ----------
    options : ViewerOptions
        Viewer configuration. The fields read here are `res`, `run_in_thread`, `refresh_rate`, `max_FPS`,
        `camera_pos`, `camera_lookat`, `camera_up`, `camera_fov`, `disable_help_text` and
        `disable_default_keybinds`.
    context : object
        Rendering context. It is forwarded to `pyrender.Viewer`, receives the viewer camera node through
        `add_node`, and is refreshed through `update` on every viewer update.
    """

    def __init__(self, options: "ViewerOptions", context):
        self._is_built = False

        self._res = options.res
        self._run_in_thread = options.run_in_thread
        self._refresh_rate = options.refresh_rate
        self._max_FPS = options.max_FPS
        self._camera_init_pos = np.asarray(options.camera_pos, dtype=gs.np_float)
        self._camera_init_lookat = np.asarray(options.camera_lookat, dtype=gs.np_float)
        self._camera_up = np.asarray(options.camera_up, dtype=gs.np_float)
        self._camera_fov = options.camera_fov
        self._disable_help_text = options.disable_help_text

        # Plugins collected here are handed over to the pyrender viewer when it is created in `build`
        self._viewer_plugins: list["ViewerPlugin"] = []
        if not options.disable_default_keybinds:
            self._viewer_plugins.append(DefaultControlsPlugin())

        # Validate viewer options
        if any(e.shape != (3,) for e in (self._camera_init_pos, self._camera_init_lookat, self._camera_up)):
            gs.raise_exception("ViewerOptions.camera_(pos|lookat|up) must be sequences of length 3.")

        # The actual window only exists once `build` has succeeded
        self._pyrender_viewer = None
        self.context = context

        # Entity-following state, configured by `follow_entity`
        self._followed_entity = None
        self._follow_fixed_axis = None
        self._follow_smoothing = None
        self._follow_fix_orientation = None
        self._follow_lookat = None

        # `self.rate` only exists when a FPS cap is requested; `update` checks `_max_FPS` before using it
        if self._max_FPS is not None:
            self.rate = Rate(self._max_FPS)

    def build(self, scene):
        """
        Create the underlying pyrender viewer for the given scene.

        When the `PYOPENGL_PLATFORM` environment variable is not set, several candidate OpenGL platforms are
        tried in order until one yields a working viewer. Otherwise only the requested platform is tried.
        The environment variable is restored to its original state after every attempt.

        Parameters
        ----------
        scene : object
            The scene this viewer is attached to. It is stored as `self.scene`.

        Raises
        ------
        OpenGL.error.Error, RuntimeError
            Re-raised from the last candidate platform if none of them could be initialized.
        """
        self.scene = scene

        # set viewer camera
        self.setup_camera()

        # Try all candidate onscreen OpenGL "platforms" if none is specifically requested
        opengl_platform_orig = os.environ.get("PYOPENGL_PLATFORM")
        if opengl_platform_orig is None:
            if gs.platform == "Windows":
                all_opengl_platforms = ("wgl",)  # same as "native"
            elif gs.platform == "Linux":
                # "native" is platform-specific ("egl" or "glx")
                all_opengl_platforms = ("native", "egl", "glx", "osmesa")
            else:
                all_opengl_platforms = ("native",)
        else:
            if opengl_platform_orig == "osmesa" and gs.platform != "Linux":
                gs.raise_exception("PYOPENGL_PLATFORM='osmesa' is only supported on Linux OS for now.")
            all_opengl_platforms = (opengl_platform_orig,)

        for i, platform in enumerate(all_opengl_platforms):
            # Force re-import OpenGL platform
            os.environ["PYOPENGL_PLATFORM"] = platform
            importlib.reload(OpenGL.platform)

            try:
                gs.logger.debug(f"Trying to create OpenGL Context for PYOPENGL_PLATFORM='{platform}'...")
                # stderr is redirected to os.devnull while the viewer is being created and initialized
                with open(os.devnull, "w") as stderr, redirect_libc_stderr(stderr):
                    self._pyrender_viewer = pyrender.Viewer(
                        context=self.context,
                        viewport_size=self._res,
                        run_in_thread=self._run_in_thread,
                        auto_start=False,
                        view_center=self._camera_init_lookat,
                        shadow=self.context.shadow,
                        plane_reflection=self.context.plane_reflection,
                        env_separate_rigid=self.context.env_separate_rigid,
                        disable_help_text=self._disable_help_text,
                        plugins=self._viewer_plugins,
                        viewer_flags={
                            "window_title": f"Genesis {gs.__version__}",
                            "refresh_rate": self._refresh_rate,
                        },
                    )
                    if not self._run_in_thread:
                        self._pyrender_viewer.start(auto_refresh=False)
                    self._pyrender_viewer.wait_until_initialized()
                break
            except (OpenGL.error.Error, RuntimeError) as e:
                # Invalid OpenGL context. Trying another platform if any...
                traceback = TracebackException.from_exception(e)
                gs.logger.debug("".join(traceback.format()))

                # Clear broken OpenGL context if it went this far
                if self._pyrender_viewer is not None:
                    self._pyrender_viewer.close()
                    self._pyrender_viewer = None

                # No candidate left: surface the last failure to the caller
                if i == len(all_opengl_platforms) - 1:
                    raise
            finally:
                # Restore original platform systematically
                del os.environ["PYOPENGL_PLATFORM"]
                if opengl_platform_orig is not None:
                    os.environ["PYOPENGL_PLATFORM"] = opengl_platform_orig

        self.lock = ViewerLock(self._pyrender_viewer)

        gs.logger.info(f"Viewer created. Resolution: ~<{self._res[0]}×{self._res[1]}>~, max_FPS: ~<{self._max_FPS}>~.")

        self._is_built = True

    def run(self):
        """
        Call `run()` on the underlying pyrender viewer.

        An exception is raised through `gs.raise_exception` if no pyrender viewer has been created yet.
        """
        if self._pyrender_viewer is None:
            gs.raise_exception("Viewer must be built successfully before calling this method.")
        self._pyrender_viewer.run()

    def stop(self):
        """
        Close the underlying pyrender viewer if it is currently alive. Does nothing otherwise.
        """
        if self.is_alive():
            self._pyrender_viewer.close()

    def is_alive(self):
        """
        Whether a pyrender viewer exists and reports itself as active.
        """
        return self._pyrender_viewer is not None and self._pyrender_viewer.is_active

    def setup_camera(self):
        """
        Add a perspective camera node to the context, placed at the initial camera pose.

        The stored up vector is replaced by the second column of the rotation part of the computed pose.
        """
        # `camera_fov` is converted from degrees to radians
        yfov = self._camera_fov / 180.0 * np.pi
        pose = gu.pos_lookat_up_to_T(self._camera_init_pos, self._camera_init_lookat, self._camera_up)
        self._camera_up = pose[:3, 1].copy()
        self._camera_node = self.context.add_node(pyrender.PerspectiveCamera(yfov=yfov), pose=pose)

    def update(self, auto_refresh=None, force=False):
        """
        Update the viewer after a simulation step.

        Parameters
        ----------
        auto_refresh : bool, optional
            Whether to refresh the viewer window. If None, refreshing is enabled only when the calling thread
            is the viewer thread (or the main thread if the viewer has no thread of its own). In any case, the
            window is never refreshed from here when the pyrender viewer runs in its own thread.
        force : bool, optional
            Forwarded to `context.update`.
        """
        if self._followed_entity is not None:
            self.update_following()

        self._pyrender_viewer.update_on_sim_step()

        with self.lock:
            # Update context
            self.context.update(force)

            # Refresh viewer by default if and if this is possible
            if auto_refresh is None:
                viewer_thread = self._pyrender_viewer._thread or threading.main_thread()
                auto_refresh = viewer_thread == threading.current_thread()
            if auto_refresh and not self._pyrender_viewer.run_in_thread:
                self._pyrender_viewer.refresh()

        # lock FPS
        if self._max_FPS is not None:
            self.rate.sleep()

    def close_offscreen(self, render_target):
        """
        Forward to `close_offscreen` of the underlying pyrender viewer and return its result.
        """
        return self._pyrender_viewer.close_offscreen(render_target)

    def render_offscreen(self, camera_node, render_target, rgb=True, depth=False, seg=False, normal=False):
        """
        Forward to `render_offscreen` of the underlying pyrender viewer and return its result.
        """
        return self._pyrender_viewer.render_offscreen(camera_node, render_target, rgb, depth, seg, normal)

    def set_camera_pose(self, pose=None, pos=None, lookat=None):
        """
        Set viewer camera pose.

        Parameters
        ----------
        pose : [4,4] float, optional
            Camera-to-world pose. If provided, `pos` and `lookat` will be ignored.
        pos : (3,) float, optional
            Camera position. Defaults to the initial camera position if `pose` is not provided.
        lookat : (3,) float, optional
            Camera lookat point. Defaults to the initial lookat point if `pose` is not provided.
        """
        if pose is None:
            if pos is None:
                pos = self._camera_init_pos
            if lookat is None:
                lookat = self._camera_init_lookat
            up = self._camera_up
            pose = gu.pos_lookat_up_to_T(pos, lookat, up)
            # Keep the stored up vector in sync with the pose that was just computed
            self._camera_up = pose[:3, 1].copy()
        elif np.array(pose).shape != (4, 4):
            gs.raise_exception("pose should be a 4x4 matrix.")
        self._pyrender_viewer._trackball.set_camera_pose(pose)

    def follow_entity(self, entity, fixed_axis=(None, None, None), smoothing=None, fix_orientation=False):
        """
        Set the viewer to follow a specified entity.

        Parameters
        ----------
        entity : genesis.Entity
            The entity to follow.
        fixed_axis : (float, float, float), optional
            The fixed axis for the viewer's movement. For each axis, if None, the viewer will move freely. If a float,
            the viewer will be fixed on at that value. For example, [None, None, None] will allow the viewer to move
            freely while following, [None, None, 0.5] will fix the viewer's z-axis at 0.5.
        smoothing : float, optional
            The smoothing factor in ]0,1[ for the viewer's movement. If None, no smoothing will be applied.
        fix_orientation : bool, optional
            If True, the viewer will maintain its orientation relative to the world. If False, the viewer will look at
            the base link of the entity.
        """
        self._followed_entity = entity
        self._follow_fixed_axis = fixed_axis
        self._follow_smoothing = smoothing
        self._follow_fix_orientation = fix_orientation
        # The lookat point starts from the initial one and is then re-assigned by `update_following`
        self._follow_lookat = self._camera_init_lookat

    def update_following(self):
        """
        Update the viewer position to follow the specified entity.
        """
        entity_pos = tensor_to_array(self._followed_entity.get_pos())
        if entity_pos.ndim > 1:  # check for multiple envs
            entity_pos = entity_pos[0]
        # numpy < 2.0 doesn't support the copy keyword argument in np.asarray()
        camera_transform = np.array(self._pyrender_viewer._trackball.pose, copy=True)
        camera_pos = np.array(self._pyrender_viewer._trackball.pose[:3, 3])

        if self._follow_smoothing is not None:
            # Smooth viewer movement with a low-pass filter
            camera_pos = self._follow_smoothing * camera_pos + (1 - self._follow_smoothing) * (
                entity_pos + self._camera_init_pos
            )
            self._follow_lookat = (
                self._follow_smoothing * self._follow_lookat + (1 - self._follow_smoothing) * entity_pos
            )
        else:
            camera_pos = entity_pos + self._camera_init_pos
            self._follow_lookat = entity_pos

        for i, fixed_axis in enumerate(self._follow_fixed_axis):
            # Fix the camera's position along the specified axis
            if fixed_axis is not None:
                camera_pos[i] = fixed_axis

        if self._follow_fix_orientation:
            # Keep the camera orientation fixed by overriding the lookat point
            camera_transform[:3, 3] = camera_pos
            self.set_camera_pose(pose=camera_transform)
        else:
            self.set_camera_pose(pos=camera_pos, lookat=self._follow_lookat)

    @gs.assert_built
    def register_keybinds(self, *keybinds: Keybind) -> None:
        """
        Register a callback function to be called when a key is pressed.

        Parameters
        ----------
        keybinds : Keybind
            One or more Keybind objects to register. See Keybind documentation for usage.
        """
        self._pyrender_viewer.register_keybinds(*keybinds)

    @gs.assert_built
    def remap_keybind(
        self,
        keybind_name: str,
        new_key: Key,
        new_key_mods: tuple[KeyMod] | None,
        new_key_action: KeyAction = KeyAction.PRESS,
    ) -> None:
        """
        Remap an existing keybind by name to a new key combination.

        Parameters
        ----------
        keybind_name : str
            The name of the keybind to remap.
        new_key : Key
            The new key.
        new_key_mods : tuple[KeyMod] | None
            The new modifier keys pressed.
        new_key_action : KeyAction, optional
            The new type of key action. Defaults to `KeyAction.PRESS`.
        """
        self._pyrender_viewer.remap_keybind(
            keybind_name,
            new_key,
            new_key_mods,
            new_key_action,
        )

    @gs.assert_built
    def remove_keybind(self, keybind_name: str) -> None:
        """
        Remove an existing keybind by name.

        Parameters
        ----------
        keybind_name : str
            The name of the keybind to remove.
        """
        self._pyrender_viewer.remove_keybind(keybind_name)

    def add_plugin(self, plugin: "ViewerPlugin") -> "ViewerPlugin":
        """
        Add a viewer plugin to the viewer.

        Before the viewer is built, the plugin is only queued and gets passed to the pyrender viewer when it is
        created. Once the viewer is built, the plugin is also registered on the existing pyrender viewer.

        Parameters
        ----------
        plugin : ViewerPlugin
            The viewer plugin to add.

        Returns
        -------
        plugin : ViewerPlugin
            The plugin that was passed in.
        """
        self._viewer_plugins.append(plugin)
        if self.is_built:
            # The window is stored as `_pyrender_viewer`; this class never defines a `_viewer` attribute
            self._pyrender_viewer.register_plugin(plugin)
        return plugin

    # ------------------------------------------------------------------------------------
    # ----------------------------------- properties -------------------------------------
    # ------------------------------------------------------------------------------------

    @property
    def is_built(self):
        """
        Whether `build` has completed successfully.
        """
        return self._is_built

    @property
    def res(self):
        """
        Viewer resolution, as given by `ViewerOptions.res`.
        """
        return self._res

    @property
    def refresh_rate(self):
        """
        Refresh rate, as given by `ViewerOptions.refresh_rate`.
        """
        return self._refresh_rate

    @property
    def max_FPS(self):
        """
        FPS cap, as given by `ViewerOptions.max_FPS`. None means no cap.
        """
        return self._max_FPS

    @property
    def camera_pos(self):
        """
        Get the camera's current position.
        """
        return np.array(self._pyrender_viewer._trackball._n_pose[:3, 3])

    @property
    def camera_lookat(self):
        """
        Get the camera's current lookat point.
        """
        pos = np.array(self._pyrender_viewer._trackball._n_pose[:3, 3])
        z = self._pyrender_viewer._trackball._n_pose[:3, 2]
        # Computed as the camera position minus the third column of the pose's rotation part
        return pos - z

    @property
    def camera_pose(self):
        """
        Get the camera's current pose represented by a 4x4 matrix.
        """
        return np.array(self._pyrender_viewer._trackball._n_pose)

    @property
    def camera_up(self):
        """
        Up vector currently stored for the viewer camera.
        """
        return self._camera_up

    @property
    def camera_fov(self):
        """
        Camera field of view, as given by `ViewerOptions.camera_fov`.
        """
        return self._camera_fov