Source code for phantom.agents

from abc import ABC
from collections import defaultdict
from itertools import chain
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
)

from .context import Context
from .decoders import Decoder
from .encoders import Encoder
from .message import Message, MsgPayload
from .reward_functions import RewardFunction
from .supertype import Supertype
from .telemetry import logger
from .types import AgentID
from .views import AgentView


# Generic placeholder for whatever an agent policy emits; passed to
# ``StrategicAgent.decode_action``.
Action = TypeVar("Action")
# Generic placeholder for whatever ``StrategicAgent.encode_observation`` returns.
Observation = TypeVar("Observation")

# Alias for a list of (agent ID, message) pairs.
MessageList = List[Tuple[AgentID, Message]]


class Agent(ABC):
    """
    Representation of an agent in the network.

    Instances of :class:`phantom.Agent` occupy the nodes on the network graph.
    They are responsible for storing and monitoring internal state, constructing
    :class:`View` instances and handling messages.

    Arguments:
        agent_id: Unique identifier for the agent.
        supertype: Optional :class:`Supertype` instance. When the agent's reset
            function is called the supertype will be sampled from and the values
            set as the agent's :attr:`type` property.

    Implementations can make use of the ``msg_handler`` function decorator:

    .. code-block:: python

        class SomeAgent(ph.Agent):
            ...

            @ph.agents.msg_handler(RequestMessage)
            def handle_request_msg(self, ctx: ph.Context, message: ph.Message):
                response_msgs = do_something_with_msg(message)

                return [response_msgs]
    """

    def __init__(
        self,
        agent_id: AgentID,
        supertype: Optional[Supertype] = None,
    ) -> None:
        self._id = agent_id
        # Payload type -> bound handler methods registered for that type via the
        # ``msg_handler`` decorator.
        self.__handlers: DefaultDict[Type[MsgPayload], List[Handler]] = defaultdict(
            list
        )

        self.supertype = supertype

        # Sample the supertype now (when one is available, either passed in or as a
        # default-constructible ``Supertype`` class attribute) so that ``type`` is set
        # from construction onwards, preventing 'Agent.type is undefined' errors.
        # NOTE(review): ``Agent.reset`` is called rather than ``self.reset``,
        # presumably so subclass overrides are not run before the subclass has
        # finished initialising its own attributes.
        Agent.reset(self)

        # Register every callable tagged by ``msg_handler`` (which sets
        # ``_message_type``). ``dir`` returns names sorted alphabetically, which
        # fixes the order in which multiple handlers for one type are invoked.
        for name in dir(self):
            attr = getattr(self, name)
            if callable(attr) and hasattr(attr, "_message_type"):
                # Reuse ``attr`` rather than calling getattr a second time, which
                # would needlessly re-run any descriptor/property logic.
                self.__handlers[attr._message_type].append(attr)

    @property
    def id(self) -> AgentID:
        """The unique ID of the agent."""
        return self._id

    def view(self, neighbour_id: Optional[AgentID] = None) -> Optional[AgentView]:
        """Return an immutable view to the agent's public state.

        The default implementation exposes no view and returns ``None``.
        """
        return None

    def pre_message_resolution(self, ctx: Context) -> None:
        """Perform internal, pre-message resolution updates to the agent."""

    def post_message_resolution(self, ctx: Context) -> None:
        """Perform internal, post-message resolution updates to the agent."""

    def handle_batch(
        self, ctx: Context, batch: Sequence[Message]
    ) -> List[Tuple[AgentID, MsgPayload]]:
        """
        Handle a batch of messages from multiple potential senders.

        Each message is logged on receipt and passed to :meth:`handle_message`;
        the responses are concatenated in batch order.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.
            batch: The incoming batch of messages to handle.

        Returns:
            A list of receiver ID / message payload pairs to form into messages in
            response to further resolve.
        """
        all_responses: List[Tuple[AgentID, MsgPayload]] = []

        for message in batch:
            logger.log_msg_recv(message)
            responses = self.handle_message(ctx, message)

            # ``handle_message`` may be overridden to return None for "no reply".
            if responses is not None:
                all_responses += responses

        return all_responses

    def handle_message(
        self, ctx: Context, message: Message
    ) -> List[Tuple[AgentID, MsgPayload]]:
        """
        Handle a message sent from another agent.

        The default implementation dispatches to every method registered for the
        message's payload type with the ``msg_handler`` decorator and concatenates
        their responses, skipping handlers that return ``None``.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.
            message: The contents of the message.

        Returns:
            A list of receiver ID / message payload pairs to form into messages in
            response to further resolve.

        Raises:
            ValueError: If no handler is registered for the payload's type.
        """
        ptype = type(message.payload)

        # Membership test on the defaultdict does not insert a new key.
        if ptype not in self.__handlers:
            raise ValueError(
                f"Unknown message type {ptype} in message sent from '{message.sender_id}' to '{self.id}'. Agent '{self.id}' needs a message handler function capable of receiving this mesage type."
            )

        return list(
            chain.from_iterable(
                filter(
                    lambda x: x is not None,
                    (
                        bound_handler(ctx, message)
                        for bound_handler in self.__handlers[ptype]
                    ),
                )
            )
        )

    def generate_messages(self, ctx: Context) -> List[Tuple[AgentID, MsgPayload]]:
        """
        Produce messages for the agent to send unprompted.

        The default implementation sends nothing and returns an empty list.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.

        Returns:
            A list of receiver ID / message payload pairs.
        """
        return []

    def reset(self) -> None:
        """
        Resets the Agent.

        If a ``supertype`` instance was provided it is sampled and the result stored
        as :attr:`type`. Otherwise, if the agent defines a ``Supertype`` attribute,
        it is constructed with default arguments and sampled instead.

        Can be extended by subclasses to provide additional functionality.

        Raises:
            Exception: If the ``Supertype`` attribute cannot be constructed with
                default values (the original ``TypeError`` is chained as the cause).
        """
        if self.supertype is not None:
            self.type = self.supertype.sample()
        elif hasattr(self, "Supertype"):
            try:
                self.type = self.Supertype().sample()
            except TypeError as e:
                raise Exception(
                    f"Tried to initialise agent {self.id}'s Supertype with default values but failed:\n\t{e}"
                ) from e

    def __repr__(self) -> str:
        return f"[{self.__class__.__name__} {self.id}]"
class StrategicAgent(Agent):
    """
    Representation of a behavioural agent in the network.

    Instances of :class:`phantom.StrategicAgent` occupy the nodes on the network
    graph. They are responsible for storing and monitoring internal state,
    constructing :class:`View` instances and handling messages.

    Arguments:
        agent_id: Unique identifier for the agent.
        observation_encoder: Optional :class:`Encoder` instance, otherwise define an
            :meth:`encode_observation` method on the :class:`Agent` sub-class.
        action_decoder: Optional :class:`Decoder` instance, otherwise define an
            :meth:`decode_action` method on the :class:`Agent` sub-class.
        reward_function: Optional :class:`RewardFunction` instance, otherwise define
            an :meth:`compute_reward` method on the :class:`Agent` sub-class.
        supertype: Optional :class:`Supertype` instance. When the agent's reset
            function is called the supertype will be sampled from and the values
            set as the agent's :attr:`type` property.
    """

    def __init__(
        self,
        agent_id: AgentID,
        observation_encoder: Optional[Encoder] = None,
        action_decoder: Optional[Decoder] = None,
        reward_function: Optional[RewardFunction] = None,
        supertype: Optional[Supertype] = None,
    ) -> None:
        super().__init__(agent_id, supertype)

        self.observation_encoder = observation_encoder
        self.action_decoder = action_decoder
        self.reward_function = reward_function

        # Spaces come from the encoder/decoder when given. Otherwise keep any
        # ``action_space``/``observation_space`` a subclass already defines, and
        # only fall back to None when none exists.
        if action_decoder is not None:
            self.action_space = action_decoder.action_space
        elif "action_space" not in dir(self):
            self.action_space = None

        if observation_encoder is not None:
            self.observation_space = observation_encoder.observation_space
        elif "observation_space" not in dir(self):
            self.observation_space = None

    def encode_observation(self, ctx: Context) -> Observation:
        """
        Encodes a local view of the environment state into a set of observations.

        Note:
            This method may be extended by sub-classes to provide additional
            functionality.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.

        Returns:
            The encoded observation, as returned by ``observation_encoder.encode``.

        Raises:
            NotImplementedError: If no ``observation_encoder`` is set and the method
                has not been overridden.
        """
        if self.observation_encoder is None:
            raise NotImplementedError(
                f"Agent '{self.id}' does not have an Encoder instance set as 'observation_encoder' or a custom 'encode_observation' method defined"
            )

        return self.observation_encoder.encode(ctx)

    def decode_action(
        self, ctx: Context, action: Action
    ) -> Optional[List[Tuple[AgentID, MsgPayload]]]:
        """
        Decodes an action taken by the agent policy into a set of messages to be
        sent to other agents in the network.

        Note:
            This method may be extended by sub-classes to provide additional
            functionality.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.
            action: The action taken by the agent.

        Returns:
            A list of receiver ID / message payload pairs to form into messages in
            response to further resolve.

        Raises:
            NotImplementedError: If no ``action_decoder`` is set and the method has
                not been overridden.
        """
        if self.action_decoder is None:
            raise NotImplementedError(
                f"Agent '{self.id}' does not have an Decoder instance set as 'action_decoder' or a custom 'decode_action' method defined"
            )

        return self.action_decoder.decode(ctx, action)

    def compute_reward(self, ctx: Context) -> float:
        """
        Computes a reward value based on an agent's current state.

        Note:
            This method may be extended by sub-classes to provide additional
            functionality.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.

        Returns:
            A float representing the present reward value.

        Raises:
            NotImplementedError: If no ``reward_function`` is set and the method has
                not been overridden.
        """
        if self.reward_function is None:
            raise NotImplementedError(
                f"Agent '{self.id}' does not have an RewardFunction instance set as 'reward_function' or a custom 'compute_reward' method defined"
            )

        return self.reward_function.reward(ctx)

    def is_terminated(self, ctx: Context) -> bool:
        """
        Indicates whether 'a `terminal state` (as defined under the MDP of the task)
        is reached' for the agent.

        The default implementation always returns ``False``, i.e. the agent never
        reaches a terminal state on its own.

        Note:
            This method may be extended by sub-classes to provide additional
            functionality.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.

        Returns:
            A boolean representing the terminal status of the agent.
        """
        return False

    def is_truncated(self, ctx: Context) -> bool:
        """
        Indicates whether 'a truncation condition outside the scope of the MDP is
        satisfied' for the agent.

        The default implementation always returns ``False``.

        Note:
            This method may be extended by sub-classes to provide additional
            functionality.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.

        Returns:
            A boolean representing the truncated status of the agent.
        """
        return False

    def collect_infos(self, ctx: Context) -> Dict[str, Any]:
        """
        Provides diagnostic information about the agent, useful for debugging.

        The default implementation returns an empty dictionary.

        Note:
            This method may be extended by sub-classes to provide additional
            functionality.

        Arguments:
            ctx: A Context object representing the agent's local view of the
                environment.

        Returns:
            A dictionary containing information about the agent.
        """
        return {}
# Signature of a (bound) message handler: it receives the agent's context and the
# incoming message and returns receiver ID / payload pairs to send in response.
Handler = Callable[[Context, Message], List[Tuple[AgentID, MsgPayload]]]


def msg_handler(message_type: Type[MsgPayload]) -> Callable[[Handler], Handler]:
    """
    Build a decorator that marks a method as the handler for ``message_type``.

    The decorated function is returned unchanged apart from gaining a
    ``_message_type`` attribute; ``Agent.__init__`` scans for that attribute to
    register the method as a handler for payloads of ``message_type``.

    Arguments:
        message_type: The payload class the decorated method should handle.

    Returns:
        A decorator that tags and returns the function it is applied to.
    """

    def _tag(handler_fn: Handler) -> Handler:
        # Tag in place; no wrapper is created, so the function's identity is kept.
        setattr(handler_fn, "_message_type", message_type)
        return handler_fn

    return _tag