from abc import abstractmethod, ABC
from typing import Any, Dict, Generic, Iterable, List, Mapping, Tuple, TypeVar
import gymnasium as gym
import numpy as np
from .context import Context
from .utils import flatten
# Type variable for the observation type an encoder produces. It is the
# generic parameter of ``Encoder`` and the return type of ``Encoder.encode``.
Observation = TypeVar("Observation")
class Encoder(Generic[Observation], ABC):
    """Interface for types that encode an agent's context into an observation.

    Subclasses must provide :attr:`observation_space` and :meth:`encode`;
    :meth:`chain` and :meth:`reset` have default implementations.
    """

    @property
    @abstractmethod
    def observation_space(self) -> gym.Space:
        """The output space of the encoder type."""

    @abstractmethod
    def encode(self, ctx: Context) -> Observation:
        """Turn the data held by a network context into an observation.

        Arguments:
            ctx: The local network context.

        Returns:
            An observation encoding properties of the provided context.
        """

    def chain(self, others: Iterable["Encoder"]) -> "ChainedEncoder":
        """Combine this encoder with a further collection of encoders.

        Arguments:
            others: The encoders to place after this one.

        Returns:
            A :class:`ChainedEncoder` whose observation space is a tuple with
            one element per chained encoder, each element being the
            observation space of the corresponding encoder.
        """
        # ``self`` comes first so its observation is the first tuple element.
        members = flatten([self, others])
        return ChainedEncoder(members)

    def reset(self):
        """Resets the encoder.

        The base implementation does nothing; subclasses may override it.
        """

    def __repr__(self) -> str:
        # An encoder is represented by the space it emits into.
        space = self.observation_space
        return repr(space)

    def __str__(self) -> str:
        space = self.observation_space
        return str(space)
class EmptyEncoder(Encoder[np.ndarray]):
    """Generates an empty observation.

    The observation is a one-element array holding zero, regardless of the
    context passed to :meth:`encode`.
    """

    @property
    def observation_space(self) -> gym.spaces.Box:
        """An unbounded box of shape ``(1,)`` with ``float32`` entries."""
        # float32 is spelled out (it is also the Box default) so the dtype
        # contract shared with ``encode`` is visible here.
        return gym.spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32)

    def encode(self, _: Context) -> np.ndarray:
        """Return the placeholder observation; the context is ignored.

        Arguments:
            _: The local network context (unused).

        Returns:
            A ``float32`` array of shape ``(1,)`` containing zero.
        """
        # np.zeros defaults to float64, which cannot be safely cast to the
        # float32 dtype of the observation space, so the space's membership
        # check would reject it. Emit float32 to stay inside the space.
        return np.zeros((1,), dtype=np.float32)
class ChainedEncoder(Encoder[Tuple]):
    """Combines n encoders into a single encoder with a tuple observation space.

    Attributes:
        encoders: The chained encoders, obtained by passing the constructor
            argument through ``flatten``.
    """

    def __init__(self, encoders: Iterable[Encoder]):
        flattened: List[Encoder] = flatten(encoders)
        self.encoders: List[Encoder] = flattened

    @property
    def observation_space(self) -> gym.Space:
        """Tuple space with one sub-space per chained encoder, in order."""
        member_spaces = tuple(member.observation_space for member in self.encoders)
        return gym.spaces.Tuple(member_spaces)

    def encode(self, ctx: Context) -> Tuple:
        """Encode ``ctx`` with every chained encoder.

        Arguments:
            ctx: The local network context.

        Returns:
            A tuple holding each encoder's observation, in chain order.
        """
        observations = [member.encode(ctx) for member in self.encoders]
        return tuple(observations)

    def chain(self, others: Iterable["Encoder"]) -> "ChainedEncoder":
        """Return a new chain made of this chain's encoders followed by ``others``."""
        appended = list(others)
        return ChainedEncoder(self.encoders + appended)

    def reset(self):
        """Reset every chained encoder, in chain order."""
        for member in self.encoders:
            member.reset()
class DictEncoder(Encoder[Dict[str, Any]]):
    """Combines n encoders into a single encoder with a dict observation space.

    Attributes:
        encoders: A mapping of encoder names to encoders, stored as a new
            ``dict`` built from the constructor argument.
    """

    def __init__(self, encoders: Mapping[str, Encoder]):
        named: Dict[str, Encoder] = dict(encoders)
        self.encoders: Dict[str, Encoder] = named

    @property
    def observation_space(self) -> gym.Space:
        """Dict space mapping each encoder name to that encoder's space."""
        subspaces = {}
        for key, member in self.encoders.items():
            subspaces[key] = member.observation_space
        return gym.spaces.Dict(subspaces)

    def encode(self, ctx: Context) -> Dict[str, Any]:
        """Encode ``ctx`` with every named encoder.

        Arguments:
            ctx: The local network context.

        Returns:
            A dict mapping each encoder name to that encoder's observation.
        """
        observation: Dict[str, Any] = {}
        for key, member in self.encoders.items():
            observation[key] = member.encode(ctx)
        return observation

    def reset(self):
        """Reset every contained encoder."""
        for member in self.encoders.values():
            member.reset()
class Constant(Encoder[np.ndarray]):
    """Encoder that always returns a constant valued Box Space.

    Arguments:
        shape: Shape of the returned box.
        value: Value that the box is filled with.
    """

    def __init__(self, shape: Tuple[int, ...], value: float = 0.0) -> None:
        self._shape = shape
        self._value = value

    @property
    def observation_space(self) -> gym.spaces.Box:
        """An unbounded ``float32`` box with the configured shape."""
        return gym.spaces.Box(-np.inf, np.inf, shape=self._shape, dtype=np.float32)

    def encode(self, _: Context) -> np.ndarray:
        """Return the constant observation; the context is ignored.

        Arguments:
            _: The local network context (unused).

        Returns:
            A ``float32`` array of the configured shape filled with the
            configured value.
        """
        # Without an explicit dtype np.full infers it from the fill value
        # (float64 for a Python float), which does not match the float32
        # dtype declared by ``observation_space``.
        return np.full(self._shape, self._value, dtype=np.float32)