import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import DefaultDict, List, Mapping, Optional, TYPE_CHECKING
import numpy as np
from .context import Context
from .telemetry import logger
from .types import AgentID
from .message import Message
if TYPE_CHECKING:
from .network import Network
class Resolver(ABC):
    """Network message resolver.

    This type is responsible for resolution processing. That is, the order in which
    (and any special logic therein) messages are handled in a Network.

    In many cases, this type can be arbitrary since the sequence doesn't matter (i.e.
    the problem is not path dependent). In other cases, however, this is not the case;
    e.g. processing incoming market orders in an LOB.

    Implementations of this class must provide implementations of the abstract methods
    below.

    Arguments:
        enable_tracking: If True, the resolver will save all messages in a time-ordered
            list that can be accessed with :attr:`tracked_messages`.
    """

    def __init__(self, enable_tracking: bool = False) -> None:
        self.enable_tracking = enable_tracking
        # Messages in the order they were pushed; only appended to while
        # `enable_tracking` is True.
        self._tracked_messages: List[Message] = []

    def push(self, message: Message) -> None:
        """Called by the Network to add messages to the resolver.

        The message is recorded (if tracking is enabled) and logged before being
        handed to :meth:`handle_push`.

        Arguments:
            message: The message to add to the resolver.
        """
        if self.enable_tracking:
            self._tracked_messages.append(message)

        logger.log_msg_send(message)

        self.handle_push(message)

    def clear_tracked_messages(self) -> None:
        """Clears any stored messages.

        Useful for when incrementally processing/storing batches of tracked messages.
        """
        self._tracked_messages.clear()

    @property
    def tracked_messages(self) -> List[Message]:
        """
        Returns all messages that have passed through the resolver if tracking is
        enabled.

        Note:
            The returned list is the resolver's own internal list, not a copy.
        """
        return self._tracked_messages

    @abstractmethod
    def handle_push(self, message: Message) -> None:
        """
        Called by :meth:`push` to handle a single new message. Any further created
        messages (e.g. responses from agents) must be handled by being passed to the
        `push` method (not `handle_push`).

        Arguments:
            message: The message that has just been pushed to the resolver.
        """
        raise NotImplementedError

    @abstractmethod
    def resolve(self, network: "Network", contexts: Mapping[AgentID, Context]) -> None:
        """Process queued messages for a (sub) set of network contexts.

        Arguments:
            network: An instance of the Network class to resolve.
            contexts: The contexts for all agents for the current step.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self) -> None:
        """Resets the resolver and clears any potential message queues.

        Note:
            Does not clear any tracked messages.
        """
        raise NotImplementedError
class BatchResolver(Resolver):
    """
    Resolver that handles messages in multiple discrete rounds. Each round, all agents
    have the opportunity to respond to previously received messages with new messages.
    These messages are held in a queue until all agents have been processed before being
    consumed in the next round. Messages for each agent are delivered in a batch,
    allowing the recipient agent to decide how to handle the messages within the batch.

    Arguments:
        enable_tracking: If True, the resolver will save all messages in a time-ordered
            list that can be accessed with :attr:`tracked_messages`.
        round_limit: The maximum number of rounds of messages to resolve. If messages
            are still queued once the limit is reached an exception will be raised.
            By default the resolver will keep resolving until no more messages are
            sent.
        shuffle_batches: If True, the order in which messages for a particular
            recipient are sent to the recipient will be randomised.
    """

    def __init__(
        self,
        enable_tracking: bool = False,
        round_limit: Optional[int] = None,
        shuffle_batches: bool = False,
    ) -> None:
        super().__init__(enable_tracking)

        self.round_limit = round_limit
        self.shuffle_batches = shuffle_batches

        # Pending messages for the next round, grouped by receiver ID.
        self.messages: DefaultDict[AgentID, List[Message]] = defaultdict(list)

    def reset(self) -> None:
        """Clears the queue of pending messages."""
        self.messages.clear()

    def handle_push(self, message: Message) -> None:
        """Queues the message under its receiver's ID for the next round."""
        self.messages[message.receiver_id].append(message)

    def resolve(self, network: "Network", contexts: Mapping[AgentID, Context]) -> None:
        """Delivers queued messages in rounds until the queue is empty.

        Each round takes the current queue, starts a fresh one for responses, and
        passes each receiver its batch via ``handle_batch``. Queued messages whose
        receiver is not in ``contexts`` are discarded, and messages whose
        sender/receiver pair is not an edge in ``network`` are filtered out of the
        batch. Responses are sent through ``network.send``.

        Arguments:
            network: An instance of the Network class to resolve.
            contexts: The contexts for all agents for the current step.

        Raises:
            RuntimeError: If messages are still queued after ``round_limit`` rounds.
        """
        iterator = (
            itertools.count() if self.round_limit is None else range(self.round_limit)
        )

        for i in iterator:
            if not self.messages:
                break

            logger.log_resolver_round(i, self.round_limit)

            # Swap in a fresh queue so that responses generated during this round
            # are held back until the next round.
            processing_messages = self.messages
            self.messages = defaultdict(list)

            for receiver_id, messages in processing_messages.items():
                if receiver_id not in contexts:
                    continue

                msgs = [
                    m for m in messages if network.has_edge(m.sender_id, m.receiver_id)
                ]

                if self.shuffle_batches:
                    np.random.shuffle(msgs)

                ctx = contexts[receiver_id]
                # NOTE: the batch is delivered even if filtering left it empty.
                responses = ctx.agent.handle_batch(ctx, msgs)

                if responses is not None:
                    for sub_receiver_id, sub_payload in responses:
                        network.send(receiver_id, sub_receiver_id, sub_payload)

        # Only reachable with messages left over when a round limit was set and
        # exhausted; the unbounded iterator exits solely via the empty-queue break.
        if self.messages:
            raise RuntimeError(
                f"{len(self.messages)} message(s) still in queue "
                "after BatchResolver round limit reached."
            )