# Source code for phantom.utils.rllib.rollout

import math
import os
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import (
    Any,
    Dict,
    Generator,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)

import cloudpickle
import ray
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.policy import Policy as RLlibPolicy
from ray.rllib.utils.spaces.space_utils import unsquash_action
from ray.util.queue import Queue

from ...env import PhantomEnv
from ...fsm import FiniteStateMachineEnv
from ...metrics import Metric, logging_helper
from ...policy import Policy
from ...types import AgentID
from ..rollout import Rollout, Step
from .. import (
    collect_instances_of_type_with_paths,
    contains_type,
    rich_progress,
    show_pythonhashseed_warning,
    update_val,
    Range,
    Sampler,
)
from . import construct_results_paths


# Maps agent IDs to custom (non-RLlib) Policy classes. During rollouts each class is
# instantiated with the agent's observation and action spaces and used to compute
# that agent's actions instead of the policy loaded from the training checkpoint.
CustomPolicyMapping = Mapping[AgentID, Type[Policy]]


def rollout(
    directory: Union[str, Path],
    env_class: Optional[Type[PhantomEnv]] = None,
    env_config: Optional[Dict[str, Any]] = None,
    custom_policy_mapping: Optional[CustomPolicyMapping] = None,
    num_repeats: int = 1,
    num_workers: Optional[int] = None,
    checkpoint: Optional[int] = None,
    metrics: Optional[Mapping[str, Metric]] = None,
    record_messages: bool = False,
    show_progress_bar: bool = True,
    policy_inference_batch_size: int = 1,
) -> Generator[Rollout, None, None]:
    """Performs rollouts for a previously trained Phantom experiment.

    Any objects that inherit from the Range class in the env_config parameter will be
    expanded out into a multidimensional space of rollouts.

    For example, if two distinct UniformRanges are used, one with a length of 10 and
    one with a length of 5, 10 * 5 = 50 rollouts will be performed.

    If num_repeats is also given, say with a value of 2, then each of the 50 rollouts
    will be repeated twice, each time with a different random seed.

    Arguments:
        directory: Results directory containing trained policies. By default, this is
            located within `~/ray_results/`. If LATEST is given as the last element of
            the path, the parent directory will be scanned for the most recent run and
            this will be used.
        env_class: Optionally pass the Environment class to use. If not given, will
            fall back to the copy of the environment class saved during training.
        env_config: Configuration parameters to pass to the environment init method.
        custom_policy_mapping: Optionally replace agent policies with custom fixed
            policies.
        num_repeats: Number of rollout repeats to perform, distributed over all
            workers.
        num_workers: Number of rollout worker processes to initialise (defaults to
            'NUM CPU - 1'). If 0, all rollouts are run in the calling process.
        checkpoint: Checkpoint to use (defaults to most recent).
        metrics: Optional set of metrics to record and log.
        record_messages: If True the full list of episode messages for each of the
            rollouts will be recorded.
        show_progress_bar: If True shows a progress bar in the terminal output.
        policy_inference_batch_size: Number of policy inferences to perform in one go.

    Returns:
        A Generator of Rollouts.

    Raises:
        TypeError: If ``env_config`` contains Sampler instances.
        ValueError: If ``policy_inference_batch_size > 1`` and the environment class
            (either given or loaded from the results directory) is a
            FiniteStateMachineEnv.

    As this is a generator function, argument validation only happens once the
    generator is first iterated.

    .. note::
        It is the users responsibility to invoke rollouts via the provided ``phantom``
        command or ensure the ``PYTHONHASHSEED`` environment variable is set before
        starting the Python interpreter to run this code. Not setting this may lead to
        reproducibility issues.
    """
    assert num_repeats > 0, "num_repeats must be at least 1"

    assert (
        policy_inference_batch_size > 0
    ), "policy_inference_batch_size must be at least 1"

    if num_workers is not None:
        assert num_workers >= 0, "num_workers must be at least 0"

    show_pythonhashseed_warning()

    metrics = metrics or {}
    env_config = env_config or {}
    custom_policy_mapping = custom_policy_mapping or {}

    if contains_type(env_config, Sampler):
        raise TypeError(
            "env_config should not contain instances of classes inheriting from BaseSampler"
        )

    directory, checkpoint_path = construct_results_paths(directory, checkpoint)

    # Every combination of the expanded Range values, as [rollout_params, env_config].
    variations = _expand_range_variations(env_config)

    # Apply the number of repeats requested to each 'variation'.
    rollout_configs = [
        _RolloutConfig(
            i * num_repeats + j,
            j,
            variation_env_config,
            rollout_params,
        )
        for i, (rollout_params, variation_env_config) in enumerate(variations)
        for j in range(num_repeats)
    ]

    # Load configs from results directory.
    with open(Path(directory, "params.pkl"), "rb") as params_file:
        config = cloudpickle.load(params_file)

    with open(Path(directory, "phantom-training-params.pkl"), "rb") as params_file:
        ph_config = cloudpickle.load(params_file)

    if env_class is None:
        env_class = ph_config["env_class"]

    # This check must come after env_class has been resolved: when the caller does not
    # pass env_class it is None until loaded above, and issubclass(None, ...) raises.
    if policy_inference_batch_size > 1 and issubclass(env_class, FiniteStateMachineEnv):
        raise ValueError("Cannot use FSM env when policy_inference_batch_size > 1")

    # os.cpu_count() returns None when the count cannot be determined; in that case
    # fall back to running in this process (0 workers).
    num_workers_ = ((os.cpu_count() or 1) - 1) if num_workers is None else num_workers

    print(
        f"Starting {len(rollout_configs):,} rollout(s) using {num_workers_} worker process(es)"
    )

    policy_specs = ph_config["policy_specs"]

    # Start the rollouts
    if num_workers_ == 0:
        # If num_workers is 0, run all the rollouts in this thread.
        rollouts = _rollout_task_fn(
            config,
            checkpoint_path,
            rollout_configs,
            env_class,
            policy_specs,
            custom_policy_mapping,
            policy_inference_batch_size,
            metrics,
            record_messages,
        )

        if show_progress_bar:
            with rich_progress("Rollouts...") as progress:
                yield from progress.track(rollouts, total=len(rollout_configs))
        else:
            yield from rollouts
    else:
        # Remote workers push finished Rollouts onto this shared Ray queue; results are
        # yielded in completion order, not rollout_id order.
        q = Queue()

        # Distribute the rollouts evenly amongst the number of workers.
        rollouts_per_worker = int(
            math.ceil(len(rollout_configs) / max(num_workers_, 1))
        )

        @ray.remote
        def remote_rollout_task_fn(*args):
            for x in _rollout_task_fn(*args):
                q.put(x)

        worker_payloads = [
            (
                config,
                checkpoint_path,
                rollout_configs[i : i + rollouts_per_worker],
                env_class,
                policy_specs,
                custom_policy_mapping,
                policy_inference_batch_size,
                metrics,
                record_messages,
            )
            for i in range(0, len(rollout_configs), rollouts_per_worker)
        ]

        for payload in worker_payloads:
            remote_rollout_task_fn.remote(*payload)

        if show_progress_bar:
            with rich_progress("Rollouts...") as progress:
                for _ in progress.track(range(len(rollout_configs))):
                    yield q.get()
        else:
            for _ in range(len(rollout_configs)):
                yield q.get()


def _expand_range_variations(env_config: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
    """Expands every Range object found in ``env_config`` into concrete variations.

    Returns a list of ``[rollout_params, env_config]`` pairs, one for every
    combination of the expanded Range values. ``rollout_params`` maps each Range's
    name (or an auto-generated ``range-N`` name for unnamed Ranges) to the value used,
    and ``env_config`` is a deep copy with every occurrence of each Range replaced by
    that value. With no Ranges present, a single pair ``[{}, copy of env_config]`` is
    returned.
    """
    # We find all instances of objects that inherit from BaseRange in the env supertype
    # and agent supertypes. We keep a track of where in this structure they came from
    # so we can easily replace the values at a later stage.
    # Each Range object can have multiple paths as it can exist at multiple points within
    # the data structure. Eg. shared across multiple agents.
    ranges = collect_instances_of_type_with_paths(Range, ({}, env_config))

    # This 'variations' list is where we build up every combination of the expanded
    # values from the list of Ranges.
    variations: List[List[Dict[str, Any]]] = [deepcopy([{}, env_config])]

    unnamed_range_count = 0

    # For each iteration of this outer loop we expand another Range object.
    for range_obj, paths in reversed(ranges):
        values = range_obj.values()

        name = range_obj.name
        if name is None:
            name = f"range-{unnamed_range_count}"
            unnamed_range_count += 1

        expanded = []
        for value in values:
            for variation in variations:
                # Deep copy so that each combination owns an independent env_config.
                variation = deepcopy(variation)
                variation[0][name] = value
                for path in paths:
                    update_val(variation, path, value)
                expanded.append(variation)

        variations = expanded

    return variations
def _rollout_task_fn(
    config,
    checkpoint_path: Path,
    all_configs: List["_RolloutConfig"],
    env_class: Type[PhantomEnv],
    policy_specs,
    custom_policy_mapping: CustomPolicyMapping,
    policy_inference_batch_size: int,
    metric_objects: Optional[Mapping[str, Metric]] = None,
    record_messages: bool = False,
) -> Generator[Rollout, None, None]:
    """Internal function: runs the given rollouts and yields one Rollout per config.

    Rollouts are processed in chunks of up to ``policy_inference_batch_size``
    environments. The environments in a chunk are stepped in lock-step so that
    inference for checkpointed RLlib policies is batched across them.

    Arguments:
        config: Algorithm config loaded from the results directory; its
            ``framework_str`` and ``policy_mapping_fn`` are used here.
        checkpoint_path: Checkpoint directory. Policies are loaded lazily from
            ``checkpoint_path / "policies" / <policy_id>`` on first use.
        all_configs: The rollouts to perform.
        env_class: Environment class, instantiated with each rollout's env_config.
        policy_specs: Per-policy specs; ``policy_specs[policy_id].observation_space``
            is used to build the observation preprocessor.
        custom_policy_mapping: Agents whose actions come from custom policy classes
            instead of the checkpoint.
        policy_inference_batch_size: Maximum number of environments stepped together.
        metric_objects: Optional metrics, logged every step and reduced with
            ``"evaluate"`` at the end of each rollout.
        record_messages: If True, each Step carries the messages tracked by the
            environment's network resolver during that step.

    Yields:
        One Rollout per entry of ``all_configs``, in the same order.
    """
    # Normalise so the per-rollout metric reduction below never iterates None.
    metric_objects = metric_objects or {}

    if not all_configs:
        return

    def chunker(seq, size):
        return (seq[pos : pos + size] for pos in range(0, len(seq), size))

    # Lazily load checkpointed policy objects
    saved_policies: Dict[str, Tuple[RLlibPolicy, Preprocessor]] = {}

    # Setting seed needs to come after algo setup
    ray.rllib.utils.debug.update_global_seed_if_necessary(
        config.framework_str, all_configs[0].rollout_id
    )

    for configs in chunker(all_configs, policy_inference_batch_size):
        batch_size = len(configs)

        vec_envs = [env_class(**rollout_config.env_config) for rollout_config in configs]

        if record_messages:
            for env in vec_envs:
                env.network.resolver.enable_tracking = True

        vec_metrics = [defaultdict(list) for _ in range(batch_size)]
        vec_all_steps = [[] for _ in range(batch_size)]

        # Each env is reset with its rollout_id as the seed.
        vec_observations = [
            env.reset(seed=rollout_config.rollout_id)[0]
            for env, rollout_config in zip(vec_envs, configs)
        ]

        # Custom policies are built from the first env's agent spaces and shared by
        # all envs in the chunk.
        initted_policy_mapping = {
            agent_id: policy(
                vec_envs[0][agent_id].observation_space,
                vec_envs[0][agent_id].action_space,
            )
            for agent_id, policy in custom_policy_mapping.items()
        }

        # Run rollout steps.
        for i in range(vec_envs[0].num_steps):
            actions = {}

            # Transpose per-env {agent_id: obs} dicts into {agent_id: [obs per env]}.
            # The set of agents is taken from the first env's observations.
            dict_observations = {
                k: [dic[k] for dic in vec_observations] for k in vec_observations[0]
            }

            for agent_id, vec_agent_obs in dict_observations.items():
                if agent_id in initted_policy_mapping:
                    actions[agent_id] = [
                        initted_policy_mapping[agent_id].compute_action(agent_obs)
                        for agent_obs in vec_agent_obs
                    ]
                else:
                    policy_id = config.policy_mapping_fn(agent_id, 0, 0)

                    if policy_id not in saved_policies:
                        policy = RLlibPolicy.from_checkpoint(
                            checkpoint_path / "policies" / policy_id
                        )
                        obs_s = policy_specs[policy_id].observation_space
                        preprocessor = get_preprocessor(obs_s)(obs_s)
                        saved_policies[policy_id] = (policy, preprocessor)
                    else:
                        policy, preprocessor = saved_policies[policy_id]

                    processed_obs = [preprocessor.transform(ob) for ob in vec_agent_obs]
                    squashed_actions = policy.compute_actions(
                        processed_obs, explore=False
                    )[0]
                    actions[agent_id] = [
                        unsquash_action(action, policy.action_space_struct)
                        for action in squashed_actions
                    ]

            # hack for no agent acting step in Ops
            if len(dict_observations) == 0:
                # A fresh dict per env, so envs never share one actions object.
                vec_actions = [{} for _ in range(batch_size)]
            else:
                # Transpose back into one {agent_id: action} dict per env.
                vec_actions = [dict(zip(actions, t)) for t in zip(*actions.values())]

            vec_steps = [
                env.step(env_actions)
                for env, env_actions in zip(vec_envs, vec_actions)
            ]

            for j in range(batch_size):
                logging_helper(vec_envs[j], metric_objects, vec_metrics[j])

                if record_messages:
                    messages = deepcopy(vec_envs[j].network.resolver.tracked_messages)
                    vec_envs[j].network.resolver.clear_tracked_messages()
                else:
                    messages = None

                vec_all_steps[j].append(
                    Step(
                        i,
                        vec_observations[j],
                        vec_steps[j].rewards,
                        vec_steps[j].terminations,
                        vec_steps[j].truncations,
                        vec_steps[j].infos,
                        vec_actions[j],
                        messages,
                        vec_envs[j].previous_stage
                        if isinstance(vec_envs[j], FiniteStateMachineEnv)
                        else None,
                    )
                )

            vec_observations = [step.observations for step in vec_steps]

        for j in range(batch_size):
            reduced_metrics = {
                metric_id: metric_objects[metric_id].reduce(
                    vec_metrics[j][metric_id], "evaluate"
                )
                for metric_id in metric_objects
            }

            yield Rollout(
                configs[j].rollout_id,
                configs[j].repeat_id,
                configs[j].env_config,
                configs[j].rollout_params,
                vec_all_steps[j],
                reduced_metrics,
            )


@dataclass(frozen=True)
class _RolloutConfig:
    """Internal class: the parameters for a single rollout."""

    # Unique index of the rollout; also used as the env reset seed.
    rollout_id: int
    # Index of this repeat of the same env_config variation.
    repeat_id: int
    # Keyword arguments passed to the environment class constructor.
    env_config: Mapping[str, Any]
    # Values chosen for each expanded Range, keyed by Range name.
    rollout_params: Mapping[str, Any]