Source code for phantom.utils.rllib.train

import os
import tempfile
from datetime import datetime
from inspect import isclass
from pathlib import Path
from typing import (
    Any,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)

import cloudpickle
import gymnasium as gym
import numpy as np
import ray
import rich.pretty
from ray import rllib
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.evaluation import Episode, MultiAgentEpisode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import TensorStructType, TensorType

from ...agents import Agent
from ...env import PhantomEnv
from ...metrics import Metric, logging_helper
from ...policy import Policy
from ...types import AgentID
from .. import check_env_config, rich_progress, show_pythonhashseed_warning
from .wrapper import RLlibEnvWrapper


# A policy class accepted by ``train``: either a Phantom ``Policy`` subclass or an
# RLlib ``Policy`` subclass.
PolicyClass = Union[Type[Policy], Type[rllib.Policy]]

# Type of the ``policies`` argument of ``train``: maps a policy ID (str) to a policy
# setup. A setup selects agents either by ``Agent`` subclass or by an explicit list
# of agent IDs, and may be wrapped in a tuple that prepends the policy class to use
# and optionally appends a mapping of options for that policy.
PolicyMapping = Mapping[
    str,
    Union[
        Type[Agent],
        List[AgentID],
        Tuple[PolicyClass, Type[Agent]],
        Tuple[PolicyClass, Type[Agent], Mapping[str, Any]],
        Tuple[PolicyClass, List[AgentID]],
        Tuple[PolicyClass, List[AgentID], Mapping[str, Any]],
    ],
]


def train(
    algorithm: str,
    env_class: Type[PhantomEnv],
    policies: PolicyMapping,
    iterations: int,
    checkpoint_freq: Optional[int] = None,
    num_workers: Optional[int] = None,
    env_config: Optional[Mapping[str, Any]] = None,
    rllib_config: Optional[Mapping[str, Any]] = None,
    metrics: Optional[Mapping[str, Metric]] = None,
    results_dir: str = ray.tune.result.DEFAULT_RESULTS_DIR,
    show_training_metrics: bool = False,
) -> Algorithm:
    """Performs training of a Phantom experiment using the RLlib library.

    Any objects that inherit from BaseSampler in the env_supertype or
    agent_supertypes parameters will be automatically sampled from and fed back
    into the environment at the start of each episode.

    Arguments:
        algorithm: Name of the RL algorithm to use (e.g. ``"PPO"``), as registered
            with Ray Tune.
        env_class: A PhantomEnv subclass.
        policies: A mapping of policy IDs to policy configurations.
        iterations: Number of training iterations to perform (must be > 0).
        checkpoint_freq: The iteration frequency to save policy checkpoints at
            (must be > 0 if given). The check is made against the zero-based
            iteration index, so a checkpoint is always saved after the first
            iteration and then every ``checkpoint_freq`` iterations after that.
        num_workers: Number of Ray rollout workers to use (defaults to
            'NUM CPU - 1').
        env_config: Configuration parameters to pass to the environment init
            method.
        rllib_config: Optional algorithm parameters dictionary to pass to RLlib.
            Values given here override the defaults chosen by this function.
        metrics: Optional set of metrics to record and log.
        results_dir: A custom results directory, default is ~/ray_results/
        show_training_metrics: Set to True to print training metrics every
            iteration.

    The ``policies`` parameter defines which agents will use which policy. This
    is key to performing shared policy learning. The function expects a mapping
    of ``{<policy_id> : <policy_setup>}``. The policy setup values can take one
    of the following forms:

    * ``Type[Agent]``: All agents that are an instance of this class will learn
      the same RLlib policy.
    * ``List[AgentID]``: All agents that have IDs in this list will learn the
      same RLlib policy.
    * ``Tuple[PolicyClass, Type[Agent]]``: All agents that are an instance of
      this class will use the same fixed/learnt policy.
    * ``Tuple[PolicyClass, Type[Agent], Mapping[str, Any]]``: All agents that are
      an instance of this class will use the same fixed/learnt policy configured
      with the given options.
    * ``Tuple[PolicyClass, List[AgentID]]``: All agents that have IDs in this
      list use the same fixed/learnt policy.
    * ``Tuple[PolicyClass, List[AgentID], Mapping[str, Any]]``: All agents that
      have IDs in this list use the same fixed/learnt policy configured with the
      given options.

    Returns:
        The trained RLlib ``Algorithm`` instance.

    Raises:
        AssertionError: If ``iterations``, ``num_workers`` or ``checkpoint_freq``
            is out of range.
        TypeError: If a policy setup is not one of the supported forms.
        ValueError: If a policy setup does not resolve to at least one agent.

    .. note::
        It is the user's responsibility to invoke training via the provided
        ``phantom`` command or ensure the ``PYTHONHASHSEED`` environment variable
        is set before starting the Python interpreter to run this code. Not
        setting this may lead to reproducibility issues.
    """
    show_pythonhashseed_warning()

    assert iterations > 0, "'iterations' parameter must be > 0"

    if num_workers is not None:
        assert num_workers >= 0, "'num_workers' parameter must be >= 0"

    if checkpoint_freq is not None:
        # A frequency of zero would otherwise surface as a ZeroDivisionError in the
        # middle of training, after the first iteration has already run.
        assert checkpoint_freq > 0, "'checkpoint_freq' parameter must be > 0"

    env_config = env_config or {}
    rllib_config = rllib_config or {}
    metrics = metrics or {}

    check_env_config(env_config)

    ray.init(ignore_reinit_error=True)

    # A local environment instance is only needed to look up agents and their
    # spaces; the rollout workers build their own via the registered env creator.
    env = env_class(**env_config)
    env.reset()

    policy_specs: Dict[str, rllib.policy.policy.PolicySpec] = {}
    policy_mapping: Dict[AgentID, str] = {}
    policies_to_train: List[str] = []

    for policy_name, params in policies.items():
        policy_class = None
        policy_config = None

        if isinstance(params, list):
            agent_ids = params
            policies_to_train.append(policy_name)

        elif isclass(params) and issubclass(params, Agent):
            agent_ids = list(env.network.get_agents_with_type(params).keys())
            policies_to_train.append(policy_name)

        elif isinstance(params, tuple):
            if len(params) == 2:
                policy_class, agent_ids = params
            else:
                policy_class, agent_ids, policy_config = params

            # Phantom policies must be wrapped to expose the RLlib Policy interface.
            if issubclass(policy_class, Policy):
                policy_class = make_rllib_wrapped_policy_class(policy_class)

            if isclass(agent_ids) and issubclass(agent_ids, Agent):
                agent_ids = list(env.network.get_agents_with_type(agent_ids).keys())

        else:
            raise TypeError(type(params))

        if not agent_ids:
            raise ValueError(
                f"Policy '{policy_name}' does not apply to any agent in the environment"
            )

        # All agents sharing a policy are assumed to share the spaces of the first.
        first_agent = env.agents[agent_ids[0]]

        policy_specs[policy_name] = rllib.policy.policy.PolicySpec(
            policy_class=policy_class,
            action_space=first_agent.action_space,
            observation_space=first_agent.observation_space,
            config=policy_config,
        )

        for agent_id in agent_ids:
            policy_mapping[agent_id] = policy_name

    def policy_mapping_fn(agent_id, *args, **kwargs):
        return policy_mapping[agent_id]

    ray.tune.registry.register_env(
        env_class.__name__, lambda config: RLlibEnvWrapper(env_class(**config))
    )

    # os.cpu_count() returns None when the CPU count cannot be determined.
    num_workers_ = ((os.cpu_count() or 1) - 1) if num_workers is None else num_workers

    timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
    logdir_prefix = f"{algorithm}_{env.__class__.__name__}_{timestr}"

    results_dir = os.path.expanduser(results_dir)

    def logger_creator(config):
        os.makedirs(results_dir, exist_ok=True)
        logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=results_dir)
        return ray.tune.logger.UnifiedLogger(config, logdir, loggers=None)

    algo_config = {
        "callbacks": RLlibMetricLogger(metrics),
        "enable_connectors": True,
        "env": env_class.__name__,
        "env_config": env_config,
        "framework": "torch",
        "logger_creator": logger_creator,
        "num_sgd_iter": 10,
        "num_rollout_workers": num_workers_,
        "rollout_fragment_length": env.num_steps,
        "seed": 0,
        "train_batch_size": env.num_steps * max(1, num_workers_),
    }

    # User supplied settings take precedence over the defaults above.
    algo_config.update(rllib_config)

    algo_config["multiagent"] = {
        "policies": policy_specs,
        "policy_mapping_fn": policy_mapping_fn,
        "policies_to_train": policies_to_train,
    }
    algo_config["multiagent"].update(rllib_config.get("multiagent", {}))

    # Derive a minibatch size from the (possibly user-overridden) train batch size,
    # but never clobber a minibatch size the user set explicitly.
    if algorithm == "PPO" and "sgd_minibatch_size" not in rllib_config:
        algo_config["sgd_minibatch_size"] = max(
            int(algo_config["train_batch_size"] / 10), 1
        )

    algo = (
        ray.tune.registry.get_trainable_cls(algorithm)
        .get_default_config()
        .from_dict(algo_config)
        .build()
    )

    with rich_progress("Training...") as progress:
        for i in progress.track(range(iterations)):
            result = algo.train()

            if show_training_metrics:
                rich.pretty.pprint(
                    {
                        "iteration": i + 1,
                        "metrics": result["custom_metrics"],
                        "rewards": {
                            "policy_reward_min": result["policy_reward_min"],
                            "policy_reward_max": result["policy_reward_max"],
                            "policy_reward_mean": result["policy_reward_mean"],
                        },
                    }
                )

            if i == 0:
                # Written once, after the first iteration, next to the RLlib logs.
                training_params = {
                    "algorithm": algorithm,
                    "env_class": env_class,
                    "iterations": iterations,
                    "checkpoint_freq": checkpoint_freq,
                    "policy_specs": policy_specs,
                    "policy_mapping": policy_mapping,
                    "policies_to_train": policies_to_train,
                    "env_config": env_config,
                    "rllib_config": rllib_config,
                    "metrics": metrics,
                }

                with open(Path(algo.logdir, "phantom-training-params.pkl"), "wb") as f:
                    cloudpickle.dump(training_params, f)

            if checkpoint_freq is not None and i % checkpoint_freq == 0:
                algo.save()

    print(f"Logs & checkpoints saved to: {algo.logdir}")

    return algo


class RLlibMetricLogger(DefaultCallbacks):
    """RLlib callback that logs Phantom metrics.

    Per-step metric values are accumulated in ``episode.user_data`` and reduced
    into ``episode.custom_metrics`` when the episode ends.
    """

    def __init__(self, metrics: Mapping[str, "Metric"]) -> None:
        super().__init__()
        self.metrics = metrics

    def on_episode_start(self, *, episode, **kwargs) -> None:
        # Start every episode with an empty value list for each metric.
        for metric_id in self.metrics:
            episode.user_data[metric_id] = []

    def on_episode_step(self, *, base_env, episode, **kwargs) -> None:
        # Only the first sub-environment of the base env is inspected.
        env = base_env.envs[0]

        logging_helper(env, self.metrics, episode.user_data)

    def on_episode_end(self, *, episode, **kwargs) -> None:
        for metric_id, metric in self.metrics.items():
            episode.custom_metrics[metric_id] = metric.reduce(
                episode.user_data[metric_id], mode="train"
            )

    def __call__(self) -> "RLlibMetricLogger":
        # An *instance* of this class is placed in the RLlib config as "callbacks";
        # returning self lets that instance be called like a factory.
        # NOTE(review): presumably RLlib calls the configured object to construct
        # the callbacks - confirm against the RLlib version in use.
        return self


def make_rllib_wrapped_policy_class(policy_class: Type[Policy]) -> Type[rllib.Policy]:
    """Creates an RLlib ``Policy`` subclass that delegates to a Phantom ``Policy``.

    Arguments:
        policy_class: The Phantom policy class to wrap.

    Returns:
        A new ``rllib.Policy`` subclass whose action computation methods call
        ``compute_action`` on an instance of ``policy_class``. The wrapper holds
        no weights and performs no learning.
    """

    class RLlibPolicyWrapper(rllib.Policy):
        # NOTE:
        # If the action space is larger than -1.0 < x < 1.0, RLlib will attempt to
        # 'unsquash' the values leading to potentially unintended results.
        # (https://github.com/ray-project/ray/pull/16531)

        def __init__(
            self,
            observation_space: gym.Space,
            action_space: gym.Space,
            config: Mapping[str, Any],
            **kwargs,
        ):
            self.policy = policy_class(observation_space, action_space, **kwargs)

            super().__init__(observation_space, action_space, config)

        def get_weights(self):
            # The wrapped policy exposes no weights to RLlib.
            return None

        def set_weights(self, weights):
            pass

        def learn_on_batch(self, samples):
            # Nothing is learnt by the wrapper; report no learner stats.
            return {}

        def compute_single_action(
            self,
            obs: Optional[TensorStructType] = None,
            state: Optional[List[TensorType]] = None,
            *,
            prev_action: Optional[TensorStructType] = None,
            prev_reward: Optional[TensorStructType] = None,
            info: Optional[dict] = None,
            input_dict: Optional[SampleBatch] = None,
            episode: Optional[Episode] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            # Kwargs placeholder for future compatibility.
            **kwargs,
        ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
            # Only the observation is used; no RNN state or extra outputs.
            return self.policy.compute_action(obs), [], {}

        def compute_actions(
            self,
            obs_batch: Union[List[TensorStructType], TensorStructType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[
                Union[List[TensorStructType], TensorStructType]
            ] = None,
            prev_reward_batch: Optional[
                Union[List[TensorStructType], TensorStructType]
            ] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List[MultiAgentEpisode]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            # Kwargs placeholder for future compatibility.
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            # One action per observation, computed independently.
            unbatched = [self.policy.compute_action(obs) for obs in obs_batch]

            # Workaround due to known issue in RLlib
            # https://github.com/ray-project/ray/issues/10009
            if isinstance(self.action_space, gym.spaces.Tuple):
                # Transpose a list of per-observation action tuples into a tuple of
                # per-component arrays. An empty batch yields an empty tuple.
                actions = tuple(np.array(component) for component in zip(*unbatched))
            else:
                actions = unbatched

            return (actions, [], {})

    return RLlibPolicyWrapper