Source code for quaterion_models.utils.tensors
import torch
def move_to_device(obj, device: torch.device):
    """
    Recursively move every Tensor in ``obj`` to ``device``.

    ``obj`` may be a single Tensor or an arbitrarily nested structure of
    dicts, lists, tuples and NamedTuples. Containers are rebuilt with their
    contents moved; any other value is returned as-is.

    If ``device`` is a CPU device (``"cpu"``, ``"cpu:0"``, ...), ``obj`` is
    returned unchanged: nothing is copied, and Tensors that already live on
    another device are NOT moved back to the CPU.

    Args:
        obj: Tensor or nested container possibly holding Tensors.
        device: Target device. Anything accepted by ``torch.device(...)``
            works: a ``torch.device``, a string such as ``"cuda:0"``, or an
            integer device ordinal.

    Returns:
        A structure of the same shape with every Tensor moved to ``device``.
    """
    # Normalize first: a plain string like "cpu" never compares equal to a
    # torch.device, and "cpu:0" != torch.device("cpu"), so compare by type.
    device = torch.device(device)
    if device.type == "cpu":
        return obj
    return _move_to_device(obj, device)


def _move_to_device(obj, device: torch.device):
    """Recursively move Tensors inside ``obj`` to an already-normalized ``device``."""
    if isinstance(obj, torch.Tensor):
        # ``.to`` works for any device type (cuda, mps, xla, meta, ...),
        # unlike ``.cuda`` which only accepts CUDA devices.
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: _move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_move_to_device(item, device) for item in obj]
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # NamedTuple: rebuild with positional args so the concrete class is kept.
        return obj.__class__(*(_move_to_device(item, device) for item in obj))
    if isinstance(obj, tuple):
        return tuple(_move_to_device(item, device) for item in obj)
    return obj