import csv
import os
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
import genesis as gs
from genesis.options.recorders import (
VideoFile as VideoFileWriterOptions,
CSVFile as CSVFileWriterOptions,
NPZFile as NPZFileWriterOptions,
)
from genesis.utils import tensor_to_array
from .base_recorder import Recorder
from .recorder_manager import register_recording
try:
import av
except ImportError:
pass
[docs]class BaseFileWriter(Recorder):
"""
Base class for file writers.
Handles filename counter when save_on_reset is True.
"""
[docs] def build(self):
super().build()
self.counter = 0
os.makedirs(os.path.abspath(os.path.dirname(self._options.filename)), exist_ok=True)
self._initialize_writer()
[docs] def reset(self, envs_idx=None):
super().reset(envs_idx)
# no envs specific saving supported
if self._options.save_on_reset:
self.cleanup()
self.counter += 1
self._initialize_writer()
def _get_filename(self):
if self._options.save_on_reset:
path, ext = os.path.splitext(self._options.filename)
return f"{path}_{self.counter}{ext}"
return self._options.filename
def _initialize_writer(self):
pass
[docs]@register_recording(VideoFileWriterOptions)
class VideoFileWriter(BaseFileWriter):
video_container: "av.container.OutputContainer | None"
video_stream: "av.video.stream.VideoStream | None"
video_frame: "av.video.frame.VideoFrame | None"
video_buffer: "np.ndarray | None"
[docs] def build(self):
self.video_container = None
self.video_stream = None
self.video_frame = None
self.video_buffer = None
self.fps = int(
round(
1.0 / (self._steps_per_sample * self._manager._step_dt)
if self._options.fps is None
else self._options.fps
)
)
super().build()
def _initialize_writer(self):
video_path = self._get_filename()
video_name = self._options.name or Path(video_path).stem
# Create ffmpeg video container
self.video_container = av.open(video_path, mode="w")
self.video_container.metadata["title"] = video_name
def _initialize_data(self, data):
assert isinstance(data, (np.ndarray, torch.Tensor))
is_color = data.ndim == 3 and data.shape[-1] == 3
if isinstance(data, np.ndarray):
is_dtype_int = np.issubdtype(data.dtype, np.integer)
else:
is_dtype_int = not torch.is_floating_point(data)
if data.ndim != 2 + is_color or not is_dtype_int:
gs.raise_exception(f"[{type(self).__name__}] Data must be either grayscale [H, W] or color [H, W, RGB]")
height, width, *_ = data.shape
# Create ffmpeg video stream
self.video_stream = self.video_container.add_stream(self._options.codec, rate=self.fps)
assert isinstance(self.video_stream, av.video.stream.VideoStream)
self.video_stream.width, self.video_stream.height = (width, height)
self.video_stream.pix_fmt = "yuv420p"
self.video_stream.bit_rate = int(self._options.bitrate * (8 * 1024**2))
self.video_stream.codec_context.options = self._options.codec_options
# Create frame storage once for efficiency
if is_color:
self.video_frame = av.VideoFrame(width, height, "rgb24")
frame_plane = self.video_frame.planes[0]
self.video_buffer = np.asarray(memoryview(frame_plane)).reshape((-1, frame_plane.line_size // 3, 3))
else:
self.video_frame = av.VideoFrame(width, height, "gray8")
frame_plane = self.video_frame.planes[0]
self.video_buffer = np.asarray(memoryview(frame_plane)).reshape((-1, frame_plane.line_size))
[docs] def process(self, data, cur_time):
if self.video_buffer is None:
self._initialize_data(data)
if isinstance(data, torch.Tensor):
data = tensor_to_array(data)
data = data.astype(np.uint8)
# Write frame
self.video_buffer[: data.shape[0], : data.shape[1]] = data
for packet in self.video_stream.encode(self.video_frame):
self.video_container.mux(packet)
[docs] def cleanup(self):
if self.video_container is not None:
# Finalize video recording.
# Note that 'video_stream' may be None if 'process' what never called.
if self.video_stream is not None:
for packet in self.video_stream.encode(None):
self.video_container.mux(packet)
self.video_container.close()
gs.logger.info(f'Video saved to "~<{self._options.filename}>~".')
self.video_container = None
self.video_stream = None
self.video_frame = None
self.video_buffer = None
@property
def run_in_thread(self) -> bool:
return False
[docs]@register_recording(CSVFileWriterOptions)
class CSVFileWriter(BaseFileWriter):
def _initialize_writer(self):
self.wrote_data = False
self.file_handle = open(self._get_filename(), "w", encoding="utf-8", newline="")
self.csv_writer = csv.writer(self.file_handle)
def _sanitize_to_list(self, value):
if isinstance(value, (torch.Tensor, np.ndarray)):
return value.reshape((-1,)).tolist()
elif isinstance(value, (int, float, bool)):
return [value]
elif isinstance(value, (list, tuple)):
return value
else:
gs.raise_exception(f"[{type(self).__name__}] Unsupported data type: {type(value)}")
[docs] def process(self, data, cur_time):
row_data = [cur_time]
if isinstance(data, dict):
for value in data.values():
row_data.extend(self._sanitize_to_list(value))
else:
row_data.extend(self._sanitize_to_list(data))
if not self.wrote_data: # write header
header = ["timestamp"]
if self._options.header:
header.extend(self._options.header)
else:
if isinstance(data, dict):
for key, val in data.items():
if hasattr(val, "__len__"):
header.extend([f"{key}_{i}" for i in range(len(val))])
else:
header.append(key)
else:
header.extend([f"data_{i}" for i in range(1, len(row_data))])
if len(header) != len(row_data):
gs.raise_exception(f"[{type(self).__name__}] header length does not match data length.")
self.csv_writer.writerow(header)
self.wrote_data = True
self.csv_writer.writerow(row_data)
if self._options.save_every_write:
self.file_handle.flush()
[docs] def cleanup(self):
if self.file_handle:
if self.wrote_data:
self.file_handle.close()
gs.logger.info(f'[CSVFileWriter] Saved to ~<"{self._get_filename()}">~.')
else:
self.file_handle.close()
os.remove(self._get_filename()) # delete empty file
@property
def run_in_thread(self) -> bool:
return True
[docs]@register_recording(NPZFileWriterOptions)
class NPZFileWriter(BaseFileWriter):
[docs] def build(self):
self.all_data: dict[str, list] = defaultdict(list)
super().build()
[docs] def process(self, data, cur_time):
self.all_data["timestamp"].append(cur_time)
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, torch.Tensor):
value = tensor_to_array(value)
assert isinstance(value, (int, float, bool, list, tuple, np.ndarray))
self.all_data[key].append(value)
else:
self.all_data["data"].append(tensor_to_array(data))
[docs] def cleanup(self):
filename = self._get_filename()
if self.all_data["timestamp"]: # at least one data point was collected
try:
np.savez_compressed(filename, **self.all_data)
except ValueError as error:
gs.logger.warning(f"NPZFileWriter: saving as dtype=object due to ValueError: {error}")
np.savez_compressed(filename, **{k: np.array(v, dtype=object) for k, v in self.all_data.items()})
gs.logger.info(f'[NPZFileWriter] Saved data with keys {list(self.all_data.keys())} to ~<"{filename}">~.')
self.all_data.clear()
@property
def run_in_thread(self) -> bool:
return True