Source code for phantom.utils.rollout

import io
import json
from collections import Counter
from dataclasses import asdict, dataclass
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
)

import numpy as np
import pandas as pd

from ..message import Message
from ..types import AgentID, StageID


[docs]@dataclass(frozen=True) class AgentStep: """Describes a step taken by a single agent in an episode.""" i: int observation: Optional[Any] reward: Optional[float] done: bool info: Optional[Dict[str, Any]] action: Optional[Any] stage: Optional[StageID] = None
[docs]@dataclass(frozen=True) class Step: """Describes a step taken in an episode.""" i: int observations: Dict[AgentID, Any] rewards: Dict[AgentID, float] terminations: Dict[AgentID, bool] truncations: Dict[AgentID, bool] infos: Dict[AgentID, Dict[str, Any]] actions: Dict[AgentID, Any] messages: Optional[List[Message]] = None stage: Optional[StageID] = None
[docs]@dataclass(frozen=True) class Rollout: rollout_id: int repeat_id: int env_config: Mapping[str, Any] rollout_params: Dict[str, Any] steps: List[Step] metrics: Dict[str, np.ndarray]
[docs] def observations_for_agent( self, agent_id: AgentID, drop_nones: bool = False, stages: Optional[Iterable[StageID]] = None, ) -> List[Optional[Any]]: """Helper method to filter all observations for a single agent. Arguments: agent_id: The ID of the agent to filter observations for. drop_nones: Drops any None values if True. stages: Optionally also filter by multiple stages. """ return [ step.observations.get(agent_id, None) for step in self.steps if (drop_nones is False or agent_id in step.observations) and (stages is None or step.stage in stages) ]
[docs] def rewards_for_agent( self, agent_id: AgentID, drop_nones: bool = False, stages: Optional[Iterable[StageID]] = None, ) -> List[Optional[float]]: """Helper method to filter all rewards for a single agent. Arguments: agent_id: The ID of the agent to filter rewards for. drop_nones: Drops any None values if True. stages: Optionally also filter by multiple stages. """ return [ step.rewards.get(agent_id, None) for step in self.steps if ( drop_nones is False or (agent_id in step.rewards and step.rewards[agent_id] is not None) ) and (stages is None or step.stage in stages) ]
[docs] def terminations_for_agent( self, agent_id: AgentID, drop_nones: bool = False, stages: Optional[Iterable[StageID]] = None, ) -> List[Optional[bool]]: """Helper method to filter all 'terminations' for a single agent. Arguments: agent_id: The ID of the agent to filter 'terminations' for. drop_nones: Drops any None values if True. stages: Optionally also filter by multiple stages. """ return [ step.terminations.get(agent_id, None) for step in self.steps if (drop_nones is False or agent_id in step.terminations) and (stages is None or step.stage in stages) ]
[docs] def truncations_for_agent( self, agent_id: AgentID, drop_nones: bool = False, stages: Optional[Iterable[StageID]] = None, ) -> List[Optional[bool]]: """Helper method to filter all 'truncations' for a single agent. Arguments: agent_id: The ID of the agent to filter 'truncations' for. drop_nones: Drops any None values if True. stages: Optionally also filter by multiple stages. """ return [ step.truncations.get(agent_id, None) for step in self.steps if (drop_nones is False or agent_id in step.truncations) and (stages is None or step.stage in stages) ]
[docs] def infos_for_agent( self, agent_id: AgentID, drop_nones: bool = False, stages: Optional[Iterable[StageID]] = None, ) -> List[Optional[Dict[str, Any]]]: """Helper method to filter all 'infos' for a single agent. Arguments: agent_id: The ID of the agent to filter 'infos' for. drop_nones: Drops any None values if True. stages: Optionally also filter by multiple stages. """ return [ step.infos.get(agent_id, None) for step in self.steps if (drop_nones is False or agent_id in step.infos) and (stages is None or step.stage in stages) ]
[docs] def actions_for_agent( self, agent_id: AgentID, drop_nones: bool = False, stages: Optional[Iterable[StageID]] = None, ) -> List[Optional[Any]]: """Helper method to filter all actions for a single agent. Arguments: agent_id: The ID of the agent to filter actions for. drop_nones: Drops any None values if True. stages: Optionally also filter by multiple stages. """ return [ step.actions.get(agent_id, None) for step in self.steps if (drop_nones is False or agent_id in step.actions) and (stages is None or step.stage in stages) ]
[docs] def steps_for_agent( self, agent_id: AgentID, stages: Optional[Iterable[StageID]] = None ) -> List[AgentStep]: """Helper method to filter all steps for a single agent. Arguments: agent_id: The ID of the agent to filter steps for. stages: Optionally also filter by multiple stages. """ if stages is None: steps = self.steps else: steps = [step for step in self.steps if step.stage in stages] return [ AgentStep( step.i, step.observations.get(agent_id, None), step.rewards.get(agent_id, None), step.terminations.get(agent_id, None), step.truncations.get(agent_id, None), step.infos.get(agent_id, None), step.actions.get(agent_id, None), step.stage, ) for step in steps ]
[docs] def count_actions( self, stages: Optional[Iterable[StageID]] = None ) -> List[Tuple[Any, int]]: """Helper method to count the occurances of all actions for all agents. Arguments: stages: Optionally filter by multiple stages. """ if stages is None: filtered_actions = ( action for step in self.steps for action in step.actions.values() ) else: filtered_actions = ( action for step in self.steps for action in step.actions.values() if step.stage in stages ) return Counter(filtered_actions).most_common()
[docs] def count_agent_actions( self, agent_id: AgentID, stages: Optional[Iterable[StageID]] = None ) -> List[Tuple[Any, int]]: """Helper method to count the occurances of all actions for a single agents. Arguments: agent_id: The ID of the agent to count actions for. stages: Optionally also filter by multiple stages. """ if stages is None: filtered_actions = (step.actions.get(agent_id, None) for step in self.steps) else: filtered_actions = ( step.actions.get(agent_id, None) for step in self.steps if step.stage in stages ) return Counter(filtered_actions).most_common()
def __getitem__(self, index: int): """Returns a step for a given index in the episode.""" try: return self.steps[index] except KeyError: raise KeyError(f"Index {index} not valid for trajectory")
def rollouts_to_dataframe( rollouts: Iterable[Rollout], avg_over_repeats: bool = True, index_value_precision: Optional[int] = None, ) -> pd.DataFrame: """ Converts a list of Rollouts into a MultiIndex DataFrame with rollout params as the indexes and metrics as the columns. Arguments: rollouts: The list/iterator of Phantom Rollout objects to use. avg_over_repeats: If True will average all metric values over each set of repeats. This is very useful for reducing the overall data size if individual rollouts are not required. index_value_precision: If given will round the index values to the given precision and convert to strings. This can be useful for avoiding floating point inaccuracies when indexing (eg. 2.0 != 2.000000001). Returns: A Pandas DataFrame containing the results. """ # Consume iterator (if applicable), throw away everything except params and metrics rollouts = [(rollout.rollout_params, rollout.metrics) for rollout in rollouts] index_cols = list(rollouts[0][0].keys()) df = pd.DataFrame([{**params, **metrics} for params, metrics in rollouts]) if index_value_precision is not None: for col in index_cols: df[col] = df[col].round(index_value_precision).astype(str) if len(index_cols) > 0: if avg_over_repeats: df = df.groupby(index_cols).mean().reset_index() df = df.set_index(index_cols) return df def rollouts_to_jsonl( rollouts: Iterable[Rollout], file_obj: io.TextIOBase, human_readable: bool = False, ) -> None: """ Writes multiple rollouts to a file using the JSONL (JSON Lines) format. Arguments: rollouts: The list/iterator of Phantom Rollout objects to use. file_obj: A writable file object to output to. human_readable: If True the output will be 'pretty printed'. """ for rollout in rollouts: json.dump( rollout, file_obj, indent=2 if human_readable else None, cls=RolloutJSONEncoder, ) file_obj.write("\n") file_obj.flush() class RolloutJSONEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, np.ndarray): return o.tolist() if isinstance(o, np.bool_): return bool(o) if isinstance(o, np.floating): return float(o) if isinstance(o, np.number): return int(o) if isinstance(o, Rollout): return asdict(o) if isinstance(o, Step): return asdict(o) return json.JSONEncoder.default(self, o)