"""genesis.grad.creation_ops — genesis wrappers of torch tensor-creation ops."""

import sys
from functools import wraps

import torch

import genesis as gs

from .tensor import Tensor

# Torch tensor-creation ops that receive a genesis wrapper; each one is
# re-exported from this module under the same name (see the registration
# loop at the bottom of the file).
_torch_ops = (
    torch.tensor,
    torch.asarray,
    torch.as_tensor,
    torch.as_strided,
    torch.from_numpy,
    torch.zeros,
    torch.zeros_like,
    torch.ones,
    torch.ones_like,
    torch.arange,
    torch.range,  # NOTE(review): deprecated in torch in favor of arange — confirm it should stay
    torch.linspace,
    torch.logspace,
    torch.eye,
    torch.empty,
    torch.empty_like,
    torch.empty_strided,
    torch.full,
    torch.full_like,
    torch.rand,
    torch.rand_like,
    torch.randn,
    torch.randn_like,
    torch.randint,
    torch.randint_like,
    torch.randperm,
)


def torch_op_wrapper(torch_op):
    """Build a genesis-flavored wrapper around a torch creation op.

    The wrapper runs ``torch_op`` as usual, then hands the resulting torch
    tensor to :func:`from_torch` so the caller gets a detached genesis
    ``Tensor`` instead. Device selection is rejected because all genesis
    tensors live on the GPU, and genesis must already be initialized.
    """

    @wraps(torch_op)
    def _wrapper(*args, dtype=None, requires_grad=False, scene=None, **kwargs):
        if "device" in kwargs:
            gs.raise_exception("Device selection not supported. All genesis tensors are on GPU.")
        if not gs._initialized:
            gs.raise_exception("Genesis not initialized yet.")
        # torch.from_numpy takes a single positional argument only, so any
        # extra kwargs are dropped for that op.
        torch_tensor = torch_op(*args) if torch_op is torch.from_numpy else torch_op(*args, **kwargs)
        return from_torch(torch_tensor, dtype, requires_grad, detach=True, scene=scene)

    # @wraps copied torch_op's docstring onto _wrapper; prepend a note that
    # this is the genesis-side version.
    _wrapper.__doc__ = (
        f"This method is the genesis wrapper of `torch.{torch_op.__name__}`.\n\n"
        f"------------------\n{_wrapper.__doc__}"
    )
    return _wrapper
def from_torch(torch_tensor, dtype=None, requires_grad=False, detach=True, scene=None):
    """Convert a ``torch.Tensor`` into a genesis ``Tensor``.

    By default, ``detach`` is True, meaning that this function returns a new
    leaf tensor which is not connected to ``torch_tensor``'s computation graph.

    Args:
        torch_tensor: Source torch tensor.
        dtype: Target dtype; ``None`` keeps the source dtype. Float dtypes map
            to ``gs.tc_float``, integer dtypes to ``gs.tc_int``, bool stays
            ``torch.bool``. Anything else raises via ``gs.raise_exception``.
        requires_grad: Whether the returned tensor should require gradients.
        detach: If True, detach the result from the source autograd graph.
        scene: Optional scene associated with the new tensor.
    """
    if dtype is None:
        dtype = torch_tensor.dtype
    # Normalize python/torch dtypes onto the genesis-wide dtype aliases.
    if dtype in (float, torch.float32, torch.float64):
        dtype = gs.tc_float
    elif dtype in (int, torch.int32, torch.int64):
        dtype = gs.tc_int
    elif dtype in (bool, torch.bool):
        dtype = torch.bool
    else:
        gs.raise_exception(f"Unsupported dtype: {dtype}")

    # Keeping the graph connection (detach=False) on a grad-tracking parent
    # forces requires_grad=True on the result; warn when the caller asked
    # otherwise.
    if torch_tensor.requires_grad and (not detach) and (not requires_grad):
        gs.logger.warning(
            "The parent torch tensor requires grad and detach is set to False. Ignoring requires_grad=False."
        )
        requires_grad = True

    # Move/cast onto the genesis device, then clone so the result owns its
    # own storage.
    gs_tensor = Tensor(torch_tensor.to(device=gs.device, dtype=dtype), scene=scene).clone()
    if detach:
        gs_tensor = gs_tensor.detach(sceneless=False)
    if requires_grad:
        gs_tensor = gs_tensor.requires_grad_()
    return gs_tensor
# Register a genesis-wrapped version of every torch creation op on this
# module, under the same public name as the torch original.
_this_module = sys.modules[__name__]
for _torch_op in _torch_ops:
    setattr(_this_module, _torch_op.__name__, torch_op_wrapper(_torch_op))