import io
import itertools
import sys
import threading
import time
from collections import defaultdict
from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, T
import numpy as np
import torch
from PIL import Image
import genesis as gs
from genesis.options.recorders import (
BasePlotterOptions,
LinePlotterMixinOptions,
PyQtLinePlot as PyQtLinePlotterOptions,
MPLLinePlot as MPLLinePlotterOptions,
MPLImagePlot as MPLImagePlotterOptions,
)
from genesis.utils import has_display, tensor_to_array
from .base_recorder import Recorder
from .recorder_manager import RecorderManager, register_recording
# Optional dependency: the PyQt-based plotters below can only be built when pyqtgraph imports successfully.
try:
    import pyqtgraph as pg
except ImportError:
    IS_PYQTGRAPH_AVAILABLE = False
else:
    IS_PYQTGRAPH_AVAILABLE = True
def _parse_version_tuple(version: str, length: int = 3) -> tuple[int, ...]:
    """
    Extract the leading numeric release components of a version string.

    Pre-release / dev / local suffixes are ignored instead of crashing ``int()``, e.g. "3.8.0rc1" -> (3, 8, 0) and
    "3.9.0.dev12+gabc" -> (3, 9, 0). Missing components are zero-padded so that "3.7" -> (3, 7, 0).

    Parameters
    ----------
    version : str
        The version string to parse.
    length : int
        Number of release components to return.

    Returns
    -------
    version_tuple : tuple[int, ...]
        Exactly ``length`` integers.
    """
    components: list[int] = []
    for part in version.replace("+", ".").split(".")[:length]:
        digits = "".join(itertools.takewhile(str.isdigit, part))
        if not digits:
            break
        components.append(int(digits))
        if len(digits) < len(part):
            # Non-numeric suffix (e.g. "0rc1"): whatever follows is not part of the release number
            break
    components.extend([0] * (length - len(components)))
    return tuple(components)


# Optional dependency: matplotlib-based plotters require matplotlib >= 3.7.0 (e.g. for "outside" figure legends).
IS_MATPLOTLIB_AVAILABLE = False
try:
    import matplotlib as mpl

    IS_MATPLOTLIB_AVAILABLE = _parse_version_tuple(mpl.__version__) >= (3, 7, 0)
except ImportError:
    pass
# Lower bound of the right-hand margin (in x-axis units) added past the latest sample whenever the x-axis of a
# matplotlib line plot is extended.
MPL_PLOTTER_RESCALE_MIN_X = 0.5
# Right-hand x margin, as a fraction of the currently displayed time span.
MPL_PLOTTER_RESCALE_RATIO_X = 0.15
# Margin added below and above the data on the y-axis, as a fraction of the data range.
MPL_PLOTTER_RESCALE_RATIO_Y = 0.15

# Module-wide color cycle shared by every line plotter (iterating a string yields its single-character colors).
COLORS = itertools.cycle("rgbcmy")
def _data_to_array(data: Sequence) -> np.ndarray:
    """
    Convert a scalar, sequence, numpy array or torch tensor into a numpy array of at least one dimension.

    Torch tensors go through ``tensor_to_array`` first; the result (or any other input) is then passed to
    ``np.atleast_1d``, so a scalar becomes a 1-element array.
    """
    as_array = tensor_to_array(data) if isinstance(data, torch.Tensor) else data
    return np.atleast_1d(as_array)
class BasePlotter(Recorder):
    """
    Base class for plotters: recorders that draw incoming data and can optionally export the drawing as a video.

    Subclasses implement ``_update_plot`` (refresh the drawing with the latest data) and ``get_image_array``
    (capture the current drawing as an image). When ``options.save_to_filename`` is set, ``build`` registers an
    auxiliary ``VideoFile`` recorder that consumes the frames captured by ``process``.
    """

    def __init__(self, manager: "RecorderManager", options: BasePlotterOptions, data_func: Callable[[], T]):
        # Only open a window by default when a display is actually available.
        if options.show_window is None:
            options.show_window = has_display()
        # Defined up-front (not only in 'build') so that 'cleanup' never hits a missing attribute on an unbuilt plotter.
        self.video_writer = None
        super().__init__(manager, options, data_func)
        # Frames captured by 'process', consumed in FIFO order by the video writer.
        self._frames_buffer: list[np.ndarray] = []

    def build(self):
        super().build()

        self.video_writer = None
        if self._options.save_to_filename:

            def _get_video_frame_buffer(plotter):
                # Make sure that all the data in the pipe has been processed before rendering anything.
                # NOTE(review): '_data_queue' comes from 'Recorder'; presumably it is only set when the recorder
                # processes data in a background thread — confirm.
                if not plotter._frames_buffer:
                    if plotter._data_queue is not None and not plotter._data_queue.empty():
                        while not plotter._frames_buffer:
                            time.sleep(0.1)
                return plotter._frames_buffer.pop(0)

            self.video_writer = self._manager.add_recorder(
                data_func=partial(_get_video_frame_buffer, self),
                rec_options=gs.recorders.VideoFile(
                    filename=self._options.save_to_filename,
                    hz=self._options.hz,
                ),
            )

    def process(self, data, cur_time):
        """Refresh the plot, then capture a video frame if the plot is being saved to a file."""
        # Update plot
        self._update_plot()

        # Render frame if necessary
        if self._options.save_to_filename:
            self._frames_buffer.append(self.get_image_array())

    def cleanup(self):
        """Stop the auxiliary video writer (if any) and drop pending frames."""
        if self.video_writer is not None:
            self.video_writer.stop()
        self._frames_buffer.clear()
        self.video_writer = None

    def _update_plot(self):
        """
        Update plot.
        """
        raise NotImplementedError(f"[{type(self).__name__}] _update_plot() is not implemented.")

    def get_image_array(self):
        """
        Capture the plot image as a video frame.

        Returns
        -------
        image_array : np.ndarray
            The RGB image as a numpy array.
        """
        raise NotImplementedError(f"[{type(self).__name__}] get_image_array() is not implemented.")
class LinePlotHelper:
    """
    Helper class that manages line plot data.

    Use composition pattern: line plotters own an instance of this class. It stores the rolling time axis
    (``x_data``) and the per-channel values (``y_data[subplot_key][channel_label]``). It also infers the
    subplot/channel layout (``subplot_structure``) from a representative data sample.

    Parameters
    ----------
    options : LinePlotterMixinOptions
        Provides ``history_length`` (maximum number of samples kept) and the optional ``labels``.
    data : dict[str, Sequence] | Sequence
        A representative data sample. A dict yields one subplot per key; anything else yields a single subplot
        named "main".
    """

    def __init__(self, options: LinePlotterMixinOptions, data: dict[str, Sequence] | Sequence):
        self.x_data: list[float] = []
        self.y_data: defaultdict[str, defaultdict[str, list[float]]] = defaultdict(lambda: defaultdict(list))
        self._history_length = options.history_length

        # Note that these attributes will be set during first data processing or initialization
        self._is_dict_data: bool | None = None
        self._subplot_structure: dict[str, tuple[str, ...]] = {}

        if isinstance(data, dict):
            self._is_dict_data = True
            if options.labels is not None:
                assert isinstance(options.labels, dict), (
                    f"[{type(self).__name__}] Labels must be a dict when data is a dict"
                )
                assert set(options.labels.keys()) == set(data.keys()), (
                    f"[{type(self).__name__}] Label keys must match data keys"
                )
                for key in data.keys():
                    data_values = _data_to_array(data[key])
                    label_values = options.labels[key]
                    # A bare string is one label, not a sequence of one-character labels
                    if isinstance(label_values, str):
                        label_values = (label_values,)
                    assert len(label_values) == len(data_values), (
                        f"[{type(self).__name__}] Label count must match data count for key '{key}'"
                    )
                    self._subplot_structure[key] = tuple(label_values)
            else:
                for key, values in data.items():
                    values = _data_to_array(values)
                    self._subplot_structure[key] = tuple(f"{key}_{i}" for i in range(len(values)))
        else:
            self._is_dict_data = False
            data = _data_to_array(data)
            if options.labels is not None:
                # Normalize a single label (including a bare string, which is itself a Sequence) into a 1-tuple
                if isinstance(options.labels, str) or not isinstance(options.labels, Sequence):
                    options.labels = (options.labels,)
                assert len(options.labels) == len(data), f"[{type(self).__name__}] Label count must match data count"
                plot_labels = tuple(options.labels)
            else:
                plot_labels = tuple(f"data_{i}" for i in range(len(data)))
            self._subplot_structure = {"main": plot_labels}

    def clear_data(self):
        """Drop all recorded samples (the subplot structure is kept)."""
        self.x_data.clear()
        self.y_data.clear()

    def process(self, data, cur_time):
        """Process new data point and update plot."""
        if self._is_dict_data:
            processed_data = {}
            for key, values in data.items():
                if key not in self._subplot_structure:
                    continue  # skip keys not included in subplot structure
                processed_data[key] = _data_to_array(values)
        else:
            processed_data = {"main": _data_to_array(data)}

        # Update time data
        self.x_data.append(cur_time)

        # Update y data for each subplot
        for subplot_key, subplot_data in processed_data.items():
            channel_labels = self._subplot_structure[subplot_key]
            if len(subplot_data) != len(channel_labels):
                gs.logger.warning(
                    f"[{type(self).__name__}] Data length ({len(subplot_data)}) doesn't match "
                    f"expected number of channels ({len(channel_labels)}) for subplot '{subplot_key}', skipping..."
                )
                continue
            # Lengths are equal at this point, so zip pairs every channel with exactly one value
            for channel_label, value in zip(channel_labels, subplot_data):
                self.y_data[subplot_key][channel_label].append(float(value))

        # Maintain rolling history window
        if len(self.x_data) > self._history_length:
            self.x_data.pop(0)
            for subplot_key in self.y_data:
                for channel_label in self.y_data[subplot_key]:
                    try:
                        self.y_data[subplot_key][channel_label].pop(0)
                    except IndexError:
                        break  # empty, nothing to do.

    @property
    def history_length(self):
        """Maximum number of samples kept in the rolling window."""
        return self._history_length

    @property
    def is_dict_data(self):
        """Whether the data samples are dicts (one subplot per key)."""
        return self._is_dict_data

    @property
    def subplot_structure(self):
        """Mapping from subplot key to the tuple of its channel labels."""
        return self._subplot_structure
class BasePyQtPlotter(BasePlotter):
    """
    Base class for PyQt based plotters.

    Plotting happens on the main thread (``run_in_thread`` is False) inside a ``pg.GraphicsLayoutWidget``.
    """

    def __init__(self, manager: "RecorderManager", options: BasePlotterOptions, data_func: Callable[[], T]):
        super().__init__(manager, options, data_func)
        if threading.current_thread() is not threading.main_thread():
            gs.raise_exception("Impossible to run PyQtPlotter in background thread.")

    def build(self):
        if not IS_PYQTGRAPH_AVAILABLE:
            gs.raise_exception(
                f"{type(self).__name__} pyqtgraph is not installed. Please install it with `pip install pyqtgraph`."
            )

        super().build()

        self.app: pg.QtWidgets.QApplication | None = None
        self.widget: pg.GraphicsLayoutWidget | None = None
        self.plot_widgets: list[pg.PlotWidget] = []

        # Reuse the running Qt application if there is one; Qt allows only a single instance per process
        if not pg.QtWidgets.QApplication.instance():
            self.app = pg.QtWidgets.QApplication([])
        else:
            self.app = pg.QtWidgets.QApplication.instance()

        self.widget = pg.GraphicsLayoutWidget(show=self._options.show_window, title=self._options.title)
        if self._options.show_window:
            gs.logger.info(f"[{type(self).__name__}] created PyQtGraph window")
        # Sized regardless of visibility, since offscreen grabs for video frames use the widget size as well
        self.widget.resize(*self._options.window_size)

    def cleanup(self):
        super().cleanup()

        if self.widget:
            try:
                self.widget.close()
                gs.logger.debug(f"[{type(self).__name__}] closed PyQtGraph window")
            except Exception as e:
                gs.logger.warning(f"[{type(self).__name__}] Error closing window: {e}")
            finally:
                self.plot_widgets.clear()
                self.widget = None

    @property
    def run_in_thread(self) -> bool:
        return False

    def get_image_array(self):
        """
        Capture the plot image as a video frame.

        Returns
        -------
        image_array : np.ndarray
            The RGB image as a (height, width, 3) numpy array, matching the frame layout of the other plotters.
        """
        qimage = self.widget.grab().toImage()
        # 'pg.imageToArray' yields (b,g,r,a) channels (on little-endian hosts) and, with its default 'transpose=True',
        # a (width, height, 4) array. Keep the natural (height, width) layout and reorder the channels to RGB.
        # https://pyqtgraph.readthedocs.io/en/latest/api_reference/functions.html#pyqtgraph.functions.imageToArray
        bgra_array = pg.imageToArray(qimage, copy=True, transpose=False)
        return np.ascontiguousarray(bgra_array[..., 2::-1])
@register_recording(PyQtLinePlotterOptions)
class PyQtLinePlotter(BasePyQtPlotter):
    """
    Live line plotter rendered with pyqtgraph.

    Every subplot key of the underlying ``LinePlotHelper`` gets its own plot, stacked vertically, with one colored
    curve per channel.
    """

    def build(self):
        super().build()

        self.line_plot = LinePlotHelper(options=self._options, data=self._data_func())
        self.curves: dict[str, list[pg.PlotCurveItem]] = {}

        for row, (key, labels) in enumerate(self.line_plot.subplot_structure.items()):
            # Stack plots vertically: every plot after the first one starts a new layout row
            if row > 0:
                self.widget.nextRow()

            title = key if self.line_plot.is_dict_data else self._options.title
            plot = self.widget.addPlot(title=title)
            plot.setLabel("bottom", self._options.x_label)
            plot.setLabel("left", self._options.y_label)
            plot.showGrid(x=True, y=True, alpha=0.3)
            plot.addLegend()

            # One curve per channel, colored from the module-wide color cycle
            key_curves = [
                plot.plot(pen=pg.mkPen(color=color, width=2), name=label) for color, label in zip(COLORS, labels)
            ]

            self.plot_widgets.append(plot)
            if self._options.show_window:
                plot.show()
            self.curves[key] = key_curves

    def process(self, data, cur_time):
        """Record the new sample, then refresh the plot (and capture a frame if needed)."""
        self.line_plot.process(data, cur_time)
        super().process(data, cur_time)

    def _update_plot(self):
        # Push the whole current history of every channel into its curve
        helper = self.line_plot
        for key, key_curves in self.curves.items():
            for curve, label in zip(key_curves, helper.subplot_structure[key]):
                curve.setData(x=helper.x_data, y=helper.y_data[key][label])

        # Let Qt repaint
        if self.app:
            self.app.processEvents()

    def cleanup(self):
        """Close the window, then drop the recorded data and curve handles."""
        super().cleanup()
        self.line_plot.clear_data()
        self.curves.clear()
class BaseMPLPlotter(BasePlotter):
    """
    Base class for matplotlib based plotters.

    Subclasses create ``self.fig`` in ``build`` and must hold ``self._lock`` while touching the figure, since frames
    may be captured through ``get_image_array``.
    """

    def __init__(self, manager: "RecorderManager", options: BasePlotterOptions, data_func: Callable[[], T]):
        super().__init__(manager, options, data_func)
        if threading.current_thread() is not threading.main_thread():
            gs.raise_exception("Impossible to run MPLPlotter in background thread.")

    def build(self):
        if not IS_MATPLOTLIB_AVAILABLE:
            gs.raise_exception(
                f"{type(self).__name__} matplotlib is not installed. Please install it with `pip install matplotlib>=3.7.0`."
            )

        super().build()

        import matplotlib.pyplot as plt

        self.fig: plt.Figure | None = None
        self._lock = threading.Lock()

        # matplotlib figsize uses inches
        dpi = mpl.rcParams.get("figure.dpi", 100)
        self.figsize = (self._options.window_size[0] / dpi, self._options.window_size[1] / dpi)

    def _show_fig(self):
        """Show the figure window if requested by the options."""
        if self._options.show_window:
            self.fig.show()
            gs.logger.info(f"[{type(self).__name__}] created matplotlib window")

    def cleanup(self):
        """Clean up matplotlib resources."""
        super().cleanup()

        # Logger may not be available anymore
        logger_exists = hasattr(gs, "logger")

        if self.fig is not None:
            try:
                import matplotlib.pyplot as plt

                plt.close(self.fig)
                if logger_exists:
                    gs.logger.debug(f"[{type(self).__name__}] Closed matplotlib window")
            except Exception as e:
                if logger_exists:
                    gs.logger.warning(f"[{type(self).__name__}] Error closing window: {e}")
            finally:
                self.fig = None

    def get_image_array(self):
        """
        Capture the plot image as a video frame.

        Returns
        -------
        image_array : np.ndarray
            The RGB image as a numpy array, resized to ``options.window_size`` when read from an Agg buffer.
        """
        from matplotlib.backends.backend_agg import FigureCanvasAgg

        # 'with' guarantees the lock is released even if reading the canvas fails; a leaked lock would deadlock
        # every subsequent plot update.
        with self._lock:
            if isinstance(self.fig.canvas, FigureCanvasAgg):
                # Read internal buffer
                width, height = self.fig.canvas.get_width_height(physical=True)
                rgba_array_flat = np.frombuffer(self.fig.canvas.buffer_rgba(), dtype=np.uint8)
                rgb_array = rgba_array_flat.reshape((height, width, 4))[..., :3]

                # Rescale image if necessary (e.g. HiDPI canvas)
                if (width, height) != tuple(self._options.window_size):
                    img = Image.fromarray(rgb_array)
                    img = img.resize(tuple(self._options.window_size), resample=Image.BILINEAR)
                    rgb_array = np.asarray(img)
                else:
                    # Detach from the live canvas buffer, which is overwritten by the next draw
                    rgb_array = rgb_array.copy()
            else:
                # Slower but more generic fallback only if necessary
                with io.BytesIO() as buffer:
                    self.fig.canvas.print_figure(buffer, format="png", dpi="figure")
                    buffer.seek(0)
                    with Image.open(buffer) as img:
                        rgb_array = np.asarray(img.convert("RGB"))

        return rgb_array

    @property
    def run_in_thread(self) -> bool:
        from matplotlib.backends.backend_agg import FigureCanvasAgg

        if sys.platform == "darwin":
            return False
        if self._is_built:
            assert self.fig is not None
            # All Agg-based backends derives from the surfaceless Agg backend, so 'isinstance' cannot be used to
            # discriminate the latter from others.
            return type(self.fig.canvas) is FigureCanvasAgg
        return not self._options.show_window
@register_recording(MPLLinePlotterOptions)
class MPLLinePlotter(BaseMPLPlotter):
    """
    Live line plotter rendered with matplotlib, using blitting for fast incremental updates.

    One subplot per subplot key of the underlying ``LinePlotHelper`` (sharing the x-axis), one line per channel.
    """

    def build(self):
        super().build()

        self.line_plot = LinePlotHelper(options=self._options, data=self._data_func())

        import matplotlib.pyplot as plt

        self.axes: list[plt.Axes] = []
        self.lines: dict[str, list[plt.Line2D]] = {}
        # Per-axes blitting backgrounds, and the last x value already baked into them
        self.caches_bbox: list[Any] = []
        self.cache_xmax: float = -1

        # Create figure and subplots
        n_subplots = len(self.line_plot.subplot_structure)
        if n_subplots == 1:
            self.fig, ax = plt.subplots(figsize=self.figsize)
            self.axes = [ax]
        else:
            self.fig, axes = plt.subplots(n_subplots, 1, figsize=self.figsize, sharex=True, constrained_layout=True)
            self.axes = axes if isinstance(axes, (list, tuple, np.ndarray)) else [axes]
        self.fig.suptitle(self._options.title)

        # Create lines for each subplot
        for subplot_idx, (subplot_key, channel_labels) in enumerate(self.line_plot.subplot_structure.items()):
            ax = self.axes[subplot_idx]
            ax.set_xlabel(self._options.x_label)
            ax.set_ylabel(self._options.y_label)
            ax.grid(True, alpha=0.3)
            if self.line_plot.is_dict_data and n_subplots > 1:
                ax.set_title(subplot_key)

            subplot_lines = []
            for color, channel_label in zip(COLORS, channel_labels):
                (line,) = ax.plot([], [], color=color, label=channel_label, linewidth=2)
                subplot_lines.append(line)
            self.lines[subplot_key] = subplot_lines

        # Legend must be outside, otherwise it will not play well with blitting
        self.fig.legend(ncol=sum(map(len, self.lines.values())), loc="outside lower center")

        self.fig.canvas.draw()
        for ax in self.axes:
            self.caches_bbox.append(self.fig.canvas.copy_from_bbox(ax.bbox))

        self._show_fig()

    def process(self, data, cur_time):
        """Record the new sample, then refresh the plot (and capture a frame if needed)."""
        self.line_plot.process(data, cur_time)
        super().process(data, cur_time)

    def _update_plot(self):
        # 'with' guarantees the lock is released even if drawing raises; a leaked lock deadlocks 'get_image_array'
        with self._lock:
            x_data = self.line_plot.x_data

            # Update limits for each subplot if necessary
            limits_changed = False
            if len(x_data) > 1:
                # First, check if the limits on y-axis must be extended to display all the available data
                subplots_ylim_data = []
                must_update_limit_y = False
                for ax, subplot_key in zip(self.axes, self.lines.keys()):
                    subplot_y_data = self.line_plot.y_data[subplot_key]
                    subplot_ylim_data = None
                    if subplot_y_data:
                        all_y_values = list(itertools.chain.from_iterable(subplot_y_data.values()))
                        # Channel lists may exist yet be empty; min()/max() of an empty list would raise
                        if all_y_values:
                            subplot_ylim_data = y_min_data, y_max_data = min(all_y_values), max(all_y_values)
                            y_min_plot, y_max_plot = ax.get_ylim()
                            if y_min_data < y_min_plot or y_max_plot < y_max_data:
                                must_update_limit_y = True
                    subplots_ylim_data.append(subplot_ylim_data)

                # Next, adjust the limits on x-axis if they must be extended or adjusting y-axis is already planned.
                # The x-axis is shared between subplots, so the last one is used explicitly (instead of relying on
                # the loop variable leaking out of the loop above).
                x_limits_changed = False
                x_ax = self.axes[-1]
                x_min_plot, x_max_plot = x_ax.get_xlim()
                x_min_data, x_max_data = x_data[0], x_data[-1]
                if must_update_limit_y or x_min_plot < 0.0 or x_max_plot < x_max_data:
                    x_min_plot = max(0.0, x_min_data)
                    x_max_plot = x_max_data + max(
                        MPL_PLOTTER_RESCALE_RATIO_X * (x_max_data - x_min_data), MPL_PLOTTER_RESCALE_MIN_X
                    )
                    x_ax.set_xlim((x_min_plot - gs.EPS, x_max_plot + gs.EPS))
                    x_limits_changed = True

                # Finally, adjust the limits on y-axis if either x- or y-axis must be extended
                if x_limits_changed or must_update_limit_y:
                    for ax, subplot_ylim_data in zip(self.axes, subplots_ylim_data):
                        if subplot_ylim_data is not None:
                            y_min_data, y_max_data = subplot_ylim_data
                            y_min_plot = y_min_data - MPL_PLOTTER_RESCALE_RATIO_Y * (y_max_data - y_min_data)
                            y_max_plot = y_max_data + MPL_PLOTTER_RESCALE_RATIO_Y * (y_max_data - y_min_data)
                            ax.set_ylim((y_min_plot - gs.EPS, y_max_plot + gs.EPS))
                    limits_changed = True

            # Must redraw the entire figure if the limits have changed
            if limits_changed:
                self.fig.canvas.draw()

            # Update background if the entire figure has been updated, or the buffer size has been exceeded
            if limits_changed or (len(x_data) > 1 and self.cache_xmax < x_data[0] + gs.EPS):
                self.caches_bbox = [self.fig.canvas.copy_from_bbox(ax.bbox) for ax in self.axes]
                self.cache_xmax = x_data[-2]

            # Update lines for each subplot
            for ax, cache_bbox, (subplot_key, subplot_lines) in zip(self.axes, self.caches_bbox, self.lines.items()):
                # Restore background and update line data for this subplot
                self.fig.canvas.restore_region(cache_bbox)

                # Update lines
                channel_labels = self.line_plot.subplot_structure[subplot_key]
                for line, channel_label in zip(subplot_lines, channel_labels):
                    y_data = self.line_plot.y_data[subplot_key][channel_label]
                    line.set_data(x_data, y_data)
                    ax.draw_artist(line)

                # Blit the updated subplot
                self.fig.canvas.blit(ax.bbox)

            self.fig.canvas.flush_events()

    def cleanup(self):
        """Close the figure, then drop the recorded data, line handles and blitting caches."""
        super().cleanup()
        self.line_plot.clear_data()
        self.lines.clear()
        self.caches_bbox.clear()
        self.cache_xmax = -1
@register_recording(MPLImagePlotterOptions)
class MPLImagePlotter(BaseMPLPlotter):
    """
    Live image viewer using matplotlib.

    The image data should be an array-like object with shape (H, W), (H, W, 1), (H, W, 3), or (H, W, 4).
    """

    def build(self):
        super().build()

        import matplotlib.pyplot as plt

        self.image_plot = None
        self.background = None
        # Most recent image received by 'process', drawn by '_update_plot'
        self._latest_image: np.ndarray | None = None

        self.fig, self.ax = plt.subplots(figsize=self.figsize)
        self.fig.tight_layout(pad=0)
        self.ax.set_axis_off()
        self.fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        self.image_plot = self.ax.imshow(np.zeros((1, 1)), cmap="plasma", origin="upper", aspect="auto")

        self._show_fig()

    def process(self, data, cur_time):
        """Process new image data and update display."""
        if isinstance(data, torch.Tensor):
            self._latest_image = tensor_to_array(data)
        else:
            self._latest_image = np.asarray(data)
        # Go through the base class so that the display is refreshed via '_update_plot' AND a video frame is
        # captured when 'save_to_filename' is set (the video writer consumes those frames).
        super().process(data, cur_time)

    def _update_plot(self):
        img_data = self._latest_image
        if img_data is None:
            return

        with self._lock:
            vmin, vmax = np.min(img_data), np.max(img_data)
            current_vmin, current_vmax = self.image_plot.get_clim()
            # Full redraw, and a fresh blitting background, whenever the color range changes. Also on the first
            # frame, since no background exists yet even if the color range happens to match.
            if self.background is None or vmin != current_vmin or vmax != current_vmax:
                self.image_plot.set_clim(vmin, vmax)
                self.fig.canvas.draw()
                self.background = self.fig.canvas.copy_from_bbox(self.ax.bbox)

            self.fig.canvas.restore_region(self.background)
            self.image_plot.set_data(img_data)
            self.ax.draw_artist(self.image_plot)
            self.fig.canvas.blit(self.ax.bbox)
            self.fig.canvas.flush_events()

    def cleanup(self):
        """Close the figure and drop references to matplotlib artists and the last image."""
        super().cleanup()
        self.ax = None
        self.image_plot = None
        self.background = None
        self._latest_image = None