import importlib
import os
import threading
from traceback import TracebackException
from typing import TYPE_CHECKING
import numpy as np
import OpenGL.error
import OpenGL.platform
import genesis as gs
import genesis.utils.geom as gu
from genesis.ext import pyrender
from genesis.repr_base import RBC
from genesis.utils.misc import redirect_libc_stderr, tensor_to_array
from genesis.utils.tools import Rate
from genesis.vis.keybindings import Key, KeyAction, Keybind, KeyMod
from genesis.vis.viewer_plugins import DefaultControlsPlugin
if TYPE_CHECKING:
from genesis.options.vis import ViewerOptions
from genesis.vis.viewer_plugins import ViewerPlugin
class ViewerLock:
    """
    Context manager that holds the render lock of a pyrender viewer.

    Entering the context acquires `pyrender_viewer.render_lock`; leaving it
    releases the lock again, whether or not an exception was raised.

    Parameters
    ----------
    pyrender_viewer : object
        Any object exposing a `render_lock` attribute with `acquire()` and
        `release()` methods.
    """

    def __init__(self, pyrender_viewer):
        self._pyrender_viewer = pyrender_viewer

    def __enter__(self):
        self._pyrender_viewer.render_lock.acquire()
        # Return the wrapper so that `with viewer.lock as lock:` binds a usable object.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None on purpose: exceptions raised inside the `with` body must propagate.
        self._pyrender_viewer.render_lock.release()
[docs]class Viewer(RBC):
def __init__(self, options: "ViewerOptions", context):
    """
    Record the viewer configuration. No window is created here; see `build`.

    Parameters
    ----------
    options : ViewerOptions
        Source of the resolution, threading, refresh/FPS, camera and
        keybinding settings read below.
    context : object
        Rendering context later handed to the pyrender viewer.
    """
    self._is_built = False

    # Window and timing settings
    self._res = options.res
    self._run_in_thread = options.run_in_thread
    self._refresh_rate = options.refresh_rate
    self._max_FPS = options.max_FPS

    # Camera settings, normalized to float arrays
    self._camera_init_pos, self._camera_init_lookat, self._camera_up = (
        np.asarray(vec, dtype=gs.np_float)
        for vec in (options.camera_pos, options.camera_lookat, options.camera_up)
    )
    self._camera_fov = options.camera_fov
    self._disable_help_text = options.disable_help_text

    # The default controls plugin is installed unless explicitly disabled
    self._viewer_plugins: list["ViewerPlugin"] = (
        [] if options.disable_default_keybinds else [DefaultControlsPlugin()]
    )

    # Validate viewer options
    camera_vectors = (self._camera_init_pos, self._camera_init_lookat, self._camera_up)
    if not all(vec.shape == (3,) for vec in camera_vectors):
        gs.raise_exception("ViewerOptions.camera_(pos|lookat|up) must be sequences of length 3.")

    self._pyrender_viewer = None
    self.context = context

    # Entity-following state, filled in by `follow_entity`
    self._followed_entity = None
    self._follow_fixed_axis = None
    self._follow_smoothing = None
    self._follow_fix_orientation = None
    self._follow_lookat = None

    # The rate limiter only exists when an FPS cap was requested
    if self._max_FPS is not None:
        self.rate = Rate(self._max_FPS)
[docs] def build(self, scene):
"""
Create the pyrender viewer for `scene`.
Stores the scene, sets up the viewer camera, then tries each candidate
PYOPENGL_PLATFORM in turn until a `pyrender.Viewer` can be created and
initialized. The PYOPENGL_PLATFORM environment variable is put back to its
original state after every attempt.
Parameters
----------
scene : object
The scene this viewer belongs to; stored as `self.scene`.
Raises
------
OpenGL.error.Error, RuntimeError
Re-raised from the last candidate platform when none of them works.
"""
self.scene = scene
# set viewer camera
self.setup_camera()
# Try all candidate onscreen OpenGL "platforms" if none is specifically requested
opengl_platform_orig = os.environ.get("PYOPENGL_PLATFORM")
if opengl_platform_orig is None:
if gs.platform == "Windows":
all_opengl_platforms = ("wgl",) # same as "native"
elif gs.platform == "Linux":
# "native" is platform-specific ("egl" or "glx")
all_opengl_platforms = ("native", "egl", "glx", "osmesa")
else:
all_opengl_platforms = ("native",)
else:
# A platform was explicitly requested: it is the only candidate
if opengl_platform_orig == "osmesa" and gs.platform != "Linux":
gs.raise_exception("PYOPENGL_PLATFORM='osmesa' is only supported on Linux OS for now.")
all_opengl_platforms = (opengl_platform_orig,)
for i, platform in enumerate(all_opengl_platforms):
# Force re-import OpenGL platform
os.environ["PYOPENGL_PLATFORM"] = platform
importlib.reload(OpenGL.platform)
try:
gs.logger.debug(f"Trying to create OpenGL Context for PYOPENGL_PLATFORM='{platform}'...")
# libc-level stderr is redirected to os.devnull while the viewer is created
# NOTE(review): indentation is not visible here - confirm whether start() and
# wait_until_initialized() below also run inside this redirection.
with open(os.devnull, "w") as stderr, redirect_libc_stderr(stderr):
self._pyrender_viewer = pyrender.Viewer(
context=self.context,
viewport_size=self._res,
run_in_thread=self._run_in_thread,
auto_start=False,
view_center=self._camera_init_lookat,
shadow=self.context.shadow,
plane_reflection=self.context.plane_reflection,
env_separate_rigid=self.context.env_separate_rigid,
disable_help_text=self._disable_help_text,
plugins=self._viewer_plugins,
viewer_flags={
"window_title": f"Genesis {gs.__version__}",
"refresh_rate": self._refresh_rate,
},
)
# The viewer was created with auto_start=False, so it is started explicitly here
if not self._run_in_thread:
self._pyrender_viewer.start(auto_refresh=False)
self._pyrender_viewer.wait_until_initialized()
# First working platform wins
break
except (OpenGL.error.Error, RuntimeError) as e:
# Invalid OpenGL context. Trying another platform if any...
traceback = TracebackException.from_exception(e)
gs.logger.debug("".join(traceback.format()))
# Clear broken OpenGL context if it went this far
if self._pyrender_viewer is not None:
self._pyrender_viewer.close()
self._pyrender_viewer = None
# No candidate left: propagate the error of the last attempt
if i == len(all_opengl_platforms) - 1:
raise
finally:
# Restore original platform systematically
del os.environ["PYOPENGL_PLATFORM"]
if opengl_platform_orig is not None:
os.environ["PYOPENGL_PLATFORM"] = opengl_platform_orig
# Expose the viewer's render lock as a context manager
self.lock = ViewerLock(self._pyrender_viewer)
gs.logger.info(f"Viewer created. Resolution: ~<{self._res[0]}×{self._res[1]}>~, max_FPS: ~<{self._max_FPS}>~.")
self._is_built = True
def run(self):
    """
    Run the underlying pyrender viewer.

    Reports an error through `gs.raise_exception` when no viewer has been
    created yet.
    """
    viewer = self._pyrender_viewer
    if viewer is None:
        gs.raise_exception("Viewer must be built successfully before calling this method.")
    viewer.run()
def stop(self):
    """Close the pyrender viewer if it is currently alive; otherwise do nothing."""
    if not self.is_alive():
        return
    self._pyrender_viewer.close()
def is_alive(self):
    """
    Tell whether a pyrender viewer exists and reports itself as active.

    Returns
    -------
    False when no viewer has been created, otherwise the viewer's
    `is_active` value.
    """
    viewer = self._pyrender_viewer
    if viewer is None:
        return False
    return viewer.is_active
def setup_camera(self):
    """
    Add the viewer's perspective camera node to the rendering context.

    The pose is built from the initial position, lookat point and up vector.
    The stored up vector is then replaced by the second column of the
    rotation part of that pose.
    """
    initial_pose = gu.pos_lookat_up_to_T(self._camera_init_pos, self._camera_init_lookat, self._camera_up)
    self._camera_up = initial_pose[:3, 1].copy()

    # Field of view is configured in degrees; the camera expects radians
    yfov_rad = self._camera_fov / 180.0 * np.pi
    camera = pyrender.PerspectiveCamera(yfov=yfov_rad)
    self._camera_node = self.context.add_node(camera, pose=initial_pose)
[docs] def update(self, auto_refresh=None, force=False):
"""
Update the viewer after a simulation step.
Parameters
----------
auto_refresh : bool, optional
Whether to refresh the viewer. If None, it is enabled only when the calling
thread is the viewer thread (or the main thread when the viewer has no thread).
A refresh is never issued when the pyrender viewer runs in its own thread.
force : bool, optional
Forwarded to `self.context.update`.
"""
# Move the camera first when an entity is being followed
if self._followed_entity is not None:
self.update_following()
self._pyrender_viewer.update_on_sim_step()
# NOTE(review): indentation is not visible here - confirm which of the statements
# below (refresh, FPS throttle) are meant to run while the render lock is held.
with self.lock:
# Update context
self.context.update(force)
# Refresh viewer by default if and only if this is possible
if auto_refresh is None:
viewer_thread = self._pyrender_viewer._thread or threading.main_thread()
auto_refresh = viewer_thread == threading.current_thread()
if auto_refresh and not self._pyrender_viewer.run_in_thread:
self._pyrender_viewer.refresh()
# lock FPS (self.rate only exists when max_FPS was configured)
if self._max_FPS is not None:
self.rate.sleep()
def close_offscreen(self, render_target):
    """
    Forward `close_offscreen` to the pyrender viewer.

    Parameters
    ----------
    render_target : object
        Passed through unchanged.

    Returns
    -------
    Whatever the pyrender viewer's `close_offscreen` returns.
    """
    viewer = self._pyrender_viewer
    return viewer.close_offscreen(render_target)
def render_offscreen(self, camera_node, render_target, rgb=True, depth=False, seg=False, normal=False):
    """
    Forward an offscreen render request to the pyrender viewer.

    Parameters
    ----------
    camera_node, render_target : object
        Passed through unchanged.
    rgb, depth, seg, normal : bool, optional
        Output selection flags, passed through positionally in this order.

    Returns
    -------
    Whatever the pyrender viewer's `render_offscreen` returns.
    """
    viewer = self._pyrender_viewer
    outputs = (rgb, depth, seg, normal)
    return viewer.render_offscreen(camera_node, render_target, *outputs)
def set_camera_pose(self, pose=None, pos=None, lookat=None):
    """
    Set viewer camera pose.

    Parameters
    ----------
    pose : [4,4] float, optional
        Camera-to-world pose. If provided, `pos` and `lookat` will be ignored.
    pos : (3,) float, optional
        Camera position. Falls back to the initial camera position.
    lookat : (3,) float, optional
        Camera lookat point. Falls back to the initial lookat point.
    """
    if pose is not None:
        # An explicit pose only needs a shape check
        if np.array(pose).shape != (4, 4):
            gs.raise_exception("pose should be a 4x4 matrix.")
    else:
        eye = self._camera_init_pos if pos is None else pos
        target = self._camera_init_lookat if lookat is None else lookat
        pose = gu.pos_lookat_up_to_T(eye, target, self._camera_up)
        # Keep the stored up vector in sync with the pose that was just built
        self._camera_up = pose[:3, 1].copy()

    self._pyrender_viewer._trackball.set_camera_pose(pose)
def follow_entity(self, entity, fixed_axis=(None, None, None), smoothing=None, fix_orientation=False):
    """
    Make the viewer camera follow an entity.

    Only the following configuration is recorded here; the camera itself is
    moved by `update_following`.

    Parameters
    ----------
    entity : genesis.Entity
        The entity to follow.
    fixed_axis : (float, float, float), optional
        Per-axis camera position constraint. A None entry leaves that axis
        free; a float pins the camera to that value on that axis, e.g.
        [None, None, 0.5] keeps the camera's z coordinate at 0.5.
    smoothing : float, optional
        Smoothing factor in ]0,1[ for the camera movement. None disables smoothing.
    fix_orientation : bool, optional
        If True, the camera keeps its current orientation. If False, it looks
        at the followed entity.
    """
    # The lookat point starts from the initial one and is updated while following
    self._follow_lookat = self._camera_init_lookat
    self._follow_fix_orientation = fix_orientation
    self._follow_smoothing = smoothing
    self._follow_fixed_axis = fixed_axis
    self._followed_entity = entity
def update_following(self):
    """
    Move the viewer camera according to the current position of the followed entity.

    The target camera position is the entity position offset by the initial
    camera position, optionally low-pass filtered and pinned on some axes.
    """
    entity_pos = tensor_to_array(self._followed_entity.get_pos())
    if entity_pos.ndim > 1:
        # Multiple environments: follow the first one
        entity_pos = entity_pos[0]

    trackball = self._pyrender_viewer._trackball
    # np.array(..., copy=True) is used because numpy < 2.0 has no copy keyword in np.asarray()
    camera_transform = np.array(trackball.pose, copy=True)

    target_pos = entity_pos + self._camera_init_pos
    alpha = self._follow_smoothing
    if alpha is None:
        camera_pos = target_pos
        self._follow_lookat = entity_pos
    else:
        # Low-pass filter on both the camera position and the lookat point
        previous_pos = np.array(trackball.pose[:3, 3])
        camera_pos = alpha * previous_pos + (1 - alpha) * target_pos
        self._follow_lookat = alpha * self._follow_lookat + (1 - alpha) * entity_pos

    # Pin the camera on every axis for which a fixed value was requested
    for axis, pinned_value in enumerate(self._follow_fixed_axis):
        if pinned_value is not None:
            camera_pos[axis] = pinned_value

    if not self._follow_fix_orientation:
        self.set_camera_pose(pos=camera_pos, lookat=self._follow_lookat)
        return

    # Fixed orientation: only the translation part of the current pose changes
    camera_transform[:3, 3] = camera_pos
    self.set_camera_pose(pose=camera_transform)
@gs.assert_built
def register_keybinds(self, *keybinds: Keybind) -> None:
    """
    Register keybinds on the pyrender viewer.

    Parameters
    ----------
    keybinds : Keybind
        One or more Keybind objects to register. See Keybind documentation for usage.
    """
    viewer = self._pyrender_viewer
    viewer.register_keybinds(*keybinds)
@gs.assert_built
def remap_keybind(
    self,
    keybind_name: str,
    new_key: Key,
    new_key_mods: tuple[KeyMod] | None,
    new_key_action: KeyAction = KeyAction.PRESS,
) -> None:
    """
    Remap an existing keybind, identified by name, to a new key combination.

    Parameters
    ----------
    keybind_name : str
        The name of the keybind to remap.
    new_key : Key
        The new key.
    new_key_mods : tuple[KeyMod] | None
        The new modifier keys, or None.
    new_key_action : KeyAction, optional
        The new type of key action. Defaults to `KeyAction.PRESS`.
    """
    viewer = self._pyrender_viewer
    viewer.remap_keybind(keybind_name, new_key, new_key_mods, new_key_action)
@gs.assert_built
def remove_keybind(self, keybind_name: str) -> None:
    """
    Remove an existing keybind from the pyrender viewer.

    Parameters
    ----------
    keybind_name : str
        The name of the keybind to remove.
    """
    viewer = self._pyrender_viewer
    viewer.remove_keybind(keybind_name)
def add_plugin(self, plugin: "ViewerPlugin") -> "ViewerPlugin":
    """
    Add a viewer plugin to the viewer.

    The plugin is always appended to the viewer's plugin list. If the viewer
    has already been built, it is also registered on the live pyrender viewer.

    Parameters
    ----------
    plugin : ViewerPlugin
        The viewer plugin to add.

    Returns
    -------
    ViewerPlugin
        The plugin that was passed in.
    """
    self._viewer_plugins.append(plugin)
    if self.is_built:
        # The live viewer is stored as `_pyrender_viewer`; this class never sets `_viewer`.
        self._pyrender_viewer.register_plugin(plugin)
    return plugin
# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
@property
def is_built(self):
    """Whether `build` has completed for this viewer."""
    return self._is_built
@property
def res(self):
    """The viewer resolution taken from the viewer options."""
    return self._res
@property
def refresh_rate(self):
    """The refresh rate taken from the viewer options."""
    return self._refresh_rate
@property
def max_FPS(self):
    """The FPS cap taken from the viewer options; None means no cap."""
    return self._max_FPS
@property
def camera_pos(self):
    """
    Get the camera's current position.

    Returns a copy of the translation part of the trackball pose.
    """
    current_pose = self._pyrender_viewer._trackball._n_pose
    return np.array(current_pose[:3, 3])
@property
def camera_lookat(self):
    """
    Get the camera's current lookat point.

    Computed as the camera position minus the third column of the rotation
    part of the trackball pose.
    """
    current_pose = self._pyrender_viewer._trackball._n_pose
    position = np.array(current_pose[:3, 3])
    return position - current_pose[:3, 2]
@property
def camera_pose(self):
    """
    Get the camera's current pose represented by a 4x4 matrix.

    Returns a copy of the trackball pose.
    """
    current_pose = self._pyrender_viewer._trackball._n_pose
    return np.array(current_pose)
@property
def camera_up(self):
    """The up vector currently stored for the viewer camera."""
    return self._camera_up
@property
def camera_fov(self):
    """The camera field of view taken from the viewer options."""
    return self._camera_fov