Source code for phantom.supertype

from abc import ABC
from dataclasses import dataclass, fields
from typing import Any, Dict, Union

import gymnasium as gym
import numpy as np

from .utils.samplers import Sampler


ObsSpaceCompatibleTypes = Union[dict, list, np.ndarray, tuple]


@dataclass
class Supertype(ABC):
    """
    Abstract dataclass base whose fields hold either plain values or
    :class:`Sampler` objects.

    Subclasses declare their parameters as dataclass fields. The methods here
    resolve sampler-valued fields into a new instance and convert the field
    values into observation-space compatible structures.
    """

    def sample(self) -> "Supertype":
        """
        Build a new instance of the same class with every ``Sampler`` field
        resolved.

        If the instance carries a ``_managed`` attribute, the sampler's current
        ``value`` is read; otherwise a fresh value is drawn with ``sample()``.
        Fields that are not ``Sampler`` instances are passed through unchanged.

        Returns:
            A new instance of ``type(self)`` built from the resolved fields.
        """
        # NOTE(review): `_managed` is presumably set by whatever drives the
        # samplers externally -- confirm against the code that assigns it.
        is_managed = hasattr(self, "_managed")

        resolved = {}
        # `fields()` skips ClassVar/InitVar pseudo-fields, which are present in
        # `__dataclass_fields__` but are not constructor arguments.
        for spec in fields(self):
            if not spec.init:
                # Not accepted by the generated __init__; passing it would
                # raise TypeError.
                continue
            value = getattr(self, spec.name)
            if isinstance(value, Sampler):
                value = value.value if is_managed else value.sample()
            resolved[spec.name] = value

        return self.__class__(**resolved)

    def to_obs_space_compatible_type(self) -> Dict[str, ObsSpaceCompatibleTypes]:
        """
        Converts the parameters of the Supertype into a dict for use in
        observation spaces.

        Returns:
            A dict keyed by field name holding each field's converted value.

        Raises:
            ValueError: If a field holds a value of an unsupported type.
        """
        return {
            spec.name: _to_compatible_type(spec.name, getattr(self, spec.name))
            for spec in fields(self)
        }

    def to_obs_space(self, low=-np.inf, high=np.inf) -> gym.Space:
        """
        Converts the parameters of the Supertype into a `gym.Space`
        representing the space.

        All elements of the space span the same range given by the `low` and
        `high` arguments.

        Arguments:
            low: Optional 'low' bound for the space (default is -∞)
            high: Optional 'high' bound for the space (default is ∞)

        Returns:
            A ``gym.spaces.Dict`` keyed by field name.

        Raises:
            ValueError: If a field holds a value of an unsupported type.
        """
        return gym.spaces.Dict(
            {
                spec.name: _to_obs_space(
                    spec.name, getattr(self, spec.name), low, high
                )
                for spec in fields(self)
            }
        )
# Scalar types encoded as a single-element float32 vector. The NumPy abstract
# scalar types are listed explicitly because e.g. np.float32 and np.int64 are
# not subclasses of the builtin float/int.
_SCALAR_TYPES = (float, int, np.floating, np.integer)


def _to_compatible_type(field: str, obj: Any) -> ObsSpaceCompatibleTypes:
    """
    Recursively convert ``obj`` into a structure made of dicts, lists, tuples
    and NumPy arrays.

    Arguments:
        field: Label of the value being converted; only used in the error
            message. Nested elements are labelled ``field[key]``/``field[i]``.
        obj: The value to convert.

    Returns:
        Scalars become a float32 array of shape ``(1,)``; dicts, lists and
        tuples are converted element-wise, keeping their container type;
        NumPy arrays are returned unchanged.

    Raises:
        ValueError: If ``obj`` (or a nested element) has an unsupported type.
    """
    if isinstance(obj, dict):
        return {
            key: _to_compatible_type(f"{field}[{key!r}]", value)
            for key, value in obj.items()
        }

    if isinstance(obj, _SCALAR_TYPES):
        return np.array([obj], dtype=np.float32)

    if isinstance(obj, list):
        return [
            _to_compatible_type(f"{field}[{i}]", value)
            for i, value in enumerate(obj)
        ]

    if isinstance(obj, tuple):
        return tuple(
            _to_compatible_type(f"{field}[{i}]", value)
            for i, value in enumerate(obj)
        )

    if isinstance(obj, np.ndarray):
        # NOTE(review): the array is returned with its own dtype while
        # `_to_obs_space` declares float32 for arrays -- confirm that
        # non-float32 arrays are acceptable to consumers.
        return obj

    raise ValueError(
        f"Can't encode field '{field}' with type '{type(obj)}' into obs space compatible type"
    )


def _to_obs_space(field: str, obj: Any, low: float, high: float) -> gym.Space:
    """
    Recursively build the `gym.Space` that describes ``obj``.

    Arguments:
        field: Label of the value being described; only used in the error
            message. Nested elements are labelled ``field[key]``/``field[i]``.
        obj: The value whose space is built.
        low: Lower bound applied to every ``Box`` in the result.
        high: Upper bound applied to every ``Box`` in the result.

    Returns:
        ``Dict`` for dicts, ``Tuple`` for lists and tuples, a float32 ``Box``
        of shape ``(1,)`` for scalars and a float32 ``Box`` of ``obj.shape``
        for NumPy arrays.

    Raises:
        ValueError: If ``obj`` (or a nested element) has an unsupported type.
    """
    if isinstance(obj, dict):
        return gym.spaces.Dict(
            {
                key: _to_obs_space(f"{field}[{key!r}]", value, low, high)
                for key, value in obj.items()
            }
        )

    if isinstance(obj, _SCALAR_TYPES):
        return gym.spaces.Box(low, high, (1,), np.float32)

    if isinstance(obj, (list, tuple)):
        return gym.spaces.Tuple(
            [
                _to_obs_space(f"{field}[{i}]", value, low, high)
                for i, value in enumerate(obj)
            ]
        )

    if isinstance(obj, np.ndarray):
        return gym.spaces.Box(low, high, obj.shape, np.float32)

    raise ValueError(
        f"Can't encode field '{field}' with type '{type(obj)}' into gym.Space"
    )