from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple
from .env import PhantomEnv
from .network import Network
from .supertype import Supertype
from .telemetry import logger
from .types import AgentID, StageID
from .views import AgentView, EnvView
[docs]class FSMValidationError(Exception):
"""
Error raised when validating the FSM when initialising the
:class:`FiniteStateMachineEnv`.
"""
[docs]class FSMRuntimeError(Exception):
"""
Error raised when validating FSM stage changes when running an episode using the
:class:`FiniteStateMachineEnv`.
"""
[docs]class FSMStage:
"""
Decorator used in the :class:`FiniteStateMachineEnv` to declare the finite state
machine structure and assign handler functions to stages.
A 'stage' corresponds to a state in the finite state machine, however to avoid any
confusion with Environment states we refer to them as stages.
Attributes:
id: The name of this stage.
acting_agents: The agents that will take an action at the end of the steps that
belong to this stage..
rewarded_agents: If provided, only the given agents will calculate and return a
reward at the end of the step for this stage. If not provided, a reward will
be computed for all acting agents for the current stage.
next_stages: The stages that this stage can transition to.
handler: Environment class method to be called when the FSM enters this stage.
"""
def __init__(
self,
stage_id: StageID,
acting_agents: Sequence[AgentID],
rewarded_agents: Optional[Sequence[AgentID]] = None,
next_stages: Optional[Sequence[StageID]] = None,
handler: Optional[Callable[[], StageID]] = None,
) -> None:
self.id = stage_id
self.acting_agents = acting_agents
self.rewarded_agents = rewarded_agents
self.next_stages = next_stages or []
self.handler = handler
def __call__(self, handler_fn: Callable[..., Optional[StageID]]):
setattr(handler_fn, "_decorator", self)
self.handler = handler_fn
return handler_fn
[docs]@dataclass(frozen=True)
class FSMEnvView(EnvView):
"""
Extension of the :class:`EnvView` class that records the current stage that the
environment is in.
"""
stage: StageID
[docs]class FiniteStateMachineEnv(PhantomEnv):
"""
Base environment class that allows implementation of a finite state machine to
handle complex environment multi-step setups.
This class should not be used directly and instead should be subclassed.
Use the :class:`FSMStage` decorator on handler methods within subclasses of this
class to register stages to the FSM.
A 'stage' corresponds to a state in the finite state machine, however to avoid any
confusion with Environment states we refer to them as stages.
Stage IDs can be anything type that is hashable, eg. strings, ints, enums.
Arguments:
num_steps: The maximum number of steps the environment allows per episode.
network: A Network class or derived class describing the connections between
agents and agents in the environment.
initial_stage: The initial starting stage of the FSM. When the reset() method is
called the environment is initialised into this stage.
env_supertype: Optional Supertype class instance for the environment. If this is
set, it will be sampled from and the :attr:`env_type` property set on the
class with every call to :meth:`reset()`.
agent_supertypes: Optional mapping of agent IDs to Supertype class instances. If
these are set, each supertype will be sampled from and the :attr:`type`
property set on the related agent with every call to :meth:`reset()`.
stages: List of FSM stages. FSM stages can be defined via this list or
alternatively via the :class:`FSMStage` decorator.
"""
def __init__(
self,
# from phantom env:
num_steps: int,
network: Network,
# fsm env specific:
initial_stage: StageID,
# from phantom env:
env_supertype: Optional[Supertype] = None,
agent_supertypes: Optional[Mapping[AgentID, Supertype]] = None,
# fsm env specific:
stages: Optional[Sequence[FSMStage]] = None,
) -> None:
super().__init__(num_steps, network, env_supertype, agent_supertypes)
self._initial_stage = initial_stage
self._rewards: Dict[AgentID, Optional[float]] = {}
self._observations: Dict[AgentID, Any] = {}
self._infos: Dict[AgentID, Dict[str, Any]] = {}
self._stages: Dict[StageID, FSMStage] = {}
self._current_stage = self.initial_stage
self.previous_stage: Optional[StageID] = None
# Register stages via optional class initialiser list
for stage in stages or []:
if stage.id not in self._stages:
self._stages[stage.id] = stage
# Register stages via FSMStage decorator
for attr_name in dir(self):
attr = getattr(self, attr_name)
if callable(attr):
handler_fn = attr
if hasattr(handler_fn, "_decorator"):
if handler_fn._decorator.id in self._stages:
raise FSMValidationError(
f"Found multiple stages with ID '{handler_fn._decorator.id}'"
)
self._stages[handler_fn._decorator.id] = handler_fn._decorator
# Check there is at least one stage
if len(self._stages) == 0:
raise FSMValidationError(
"No registered stages. Please use the 'FSMStage' decorator or the stage_definitions init parameter"
)
# Check initial stage is valid
if self.initial_stage not in self._stages:
raise FSMValidationError(
f"Initial stage '{self.initial_stage}' is not a valid stage"
)
# Check all 'next stages' are valid
for stage in self._stages.values():
for next_stage in stage.next_stages:
if next_stage not in self._stages:
raise FSMValidationError(
f"Next stage '{next_stage}' given in stage '{stage.id}' is not a valid stage"
)
# Check stages without handler have exactly one next stage
for stage in self._stages.values():
if len(stage.next_stages) != 1:
raise FSMValidationError(
f"Stage '{stage.id}' without handler must have exactly one next stage (got {len(stage.next_stages)})"
)
@property
def initial_stage(self) -> StageID:
"""Returns the initial stage of the FSM Env."""
return self._initial_stage
@property
def current_stage(self) -> StageID:
"""Returns the current stage of the FSM Env."""
return self._current_stage
[docs] def view(self, agent_views: Dict[AgentID, AgentView]) -> FSMEnvView:
"""Return an immutable view to the FSM environment's public state."""
return FSMEnvView(
self.current_step, self.current_step / self.num_steps, self.current_stage
)
[docs] def reset(
self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[Dict[AgentID, Any], Dict[str, Any]]:
"""
Reset the environment and return an initial observation.
This method resets the step count and the :attr:`network`. This includes all the
agents in the network.
Args:
seed: An optional seed to use for the new episode.
options : Additional information to specify how the environment is reset.
Returns:
- A dictionary mapping Agent IDs to observations made by the respective
agents. It is not required for all agents to make an initial observation.
- An optional dictionary with auxillary information, equivalent to the info
dictionary in `env.step()`.
"""
logger.log_reset()
# Reset the clock and stage
self._current_step = 0
self._current_stage = self.initial_stage
# Generate initial sampled values in samplers
for sampler in self._samplers:
sampler.sample()
if self.env_supertype is not None:
self.env_type = self.env_supertype.sample()
# Reset the network and call reset method on all agents in the network.
self.network.reset()
# Reset the agents' done statuses stored by the environment
self._terminations = set()
self._truncations = set()
# Set initial null reward values
self._rewards = {aid: None for aid in self.strategic_agent_ids}
# Generate all contexts for agents taking actions
acting_agents = self._stages[self.current_stage].acting_agents
self._make_ctxs(
[aid for aid in acting_agents if aid in self.strategic_agent_ids]
)
# Generate initial observations for agents taking actions
obs = {
ctx.agent.id: ctx.agent.encode_observation(ctx)
for ctx in self._ctxs.values()
}
logger.log_observations(obs)
return {k: v for k, v in obs.items() if v is not None}, {}
[docs] def step(self, actions: Mapping[AgentID, Any]) -> PhantomEnv.Step:
"""
Step the simulation forward one step given some set of agent actions.
Arguments:
actions: Actions output by the agent policies to be translated into
messages and passed throughout the network.
Returns:
A :class:`PhantomEnv.Step` object containing observations, rewards,
terminations, truncations and infos.
"""
# Increment the clock
self._current_step += 1
logger.log_step(self.current_step, self.num_steps)
logger.log_actions(actions)
logger.log_start_decoding_actions()
# Generate contexts for all agents taking actions / generating messages
self._make_ctxs(self.agent_ids)
# Decode action/generate messages for agents and send to the network
acting_agents = self._stages[self.current_stage].acting_agents
self._handle_acting_agents(acting_agents, actions)
env_handler = self._stages[self.current_stage].handler
if env_handler is None:
# If no handler has been set, manually resolve the network messages.
self.resolve_network()
next_stages = self._stages[self.current_stage].next_stages
if len(next_stages) == 0:
raise ValueError(
f"Current stage '{self.current_stage}' does not have an env handler or a next stage defined"
)
next_stage = next_stages[0]
elif hasattr(env_handler, "__self__"):
# If the FSMStage is defined with the stage definitions the handler will be
# a bound method of the env class.
next_stage = env_handler()
else:
# If the FSMStage is defined as a decorator the handler will be an unbound
# function.
next_stage = env_handler(self)
if next_stage not in self._stages[self.current_stage].next_stages:
raise FSMRuntimeError(
f"FiniteStateMachineEnv attempted invalid transition from '{self.current_stage}' to {next_stage}"
)
observations: Dict[AgentID, Any] = {}
rewards: Dict[AgentID, float] = {}
terminations: Dict[AgentID, bool] = {}
truncations: Dict[AgentID, bool] = {}
infos: Dict[AgentID, Dict[str, Any]] = {}
if self._stages[self.current_stage].rewarded_agents is None:
rewarded_agents = self.strategic_agent_ids
next_acting_agents = self.strategic_agent_ids
else:
rewarded_agents = self._stages[self.current_stage].rewarded_agents
next_acting_agents = self._stages[next_stage].acting_agents
for aid in self.strategic_agent_ids:
if aid in self._terminations or aid in self._truncations:
continue
ctx = self._ctxs[aid]
if aid in next_acting_agents:
obs = ctx.agent.encode_observation(ctx)
if obs is not None:
observations[aid] = obs
infos[aid] = ctx.agent.collect_infos(ctx)
if aid in rewarded_agents:
rewards[aid] = ctx.agent.compute_reward(ctx)
terminations[aid] = ctx.agent.is_terminated(ctx)
truncations[aid] = ctx.agent.is_truncated(ctx)
if terminations[aid]:
self._terminations.add(aid)
if truncations[aid]:
self._truncations.add(aid)
logger.log_step_values(observations, rewards, terminations, truncations, infos)
logger.log_metrics(self)
self._observations.update(observations)
self._rewards.update(rewards)
self._infos.update(infos)
logger.log_fsm_transition(self.current_stage, next_stage)
self.previous_stage, self._current_stage = self.current_stage, next_stage
terminations["__all__"] = self.is_terminated()
truncations["__all__"] = self.is_truncated()
if (
self.current_stage is None
or terminations["__all__"]
or truncations["__all__"]
):
logger.log_episode_done()
# This is the terminal stage, return most recent observations, rewards and
# infos from all agents.
return self.Step(
observations=self._observations,
rewards=self._rewards,
terminations=terminations,
truncations=truncations,
infos=self._infos,
)
# Otherwise not in terminal stage:
rewards = {aid: self._rewards[aid] for aid in observations}
return self.Step(observations, rewards, terminations, truncations, infos)