from abc import abstractmethod, ABC
from .context import Context
class RewardFunction(ABC):
    """
    A trait for types that can compute rewards from a local context.

    Note: this trait only supports scalar rewards for the time being.
    """

    @abstractmethod
    def reward(self, ctx: Context) -> float:
        """Compute the reward from context.

        Subclasses must override this method.

        Arguments:
            ctx: The local network context.

        Returns:
            The scalar reward for ``ctx``.
        """
        raise NotImplementedError

    def reset(self):
        """Resets the reward function.

        The base implementation does nothing and returns ``None``; subclasses
        that hold internal state may override it.
        """
class Constant(RewardFunction):
    """
    A reward function that always returns a given constant.

    Attributes:
        value: The reward to be returned in any state.
    """

    def __init__(self, value: float = 0.0) -> None:
        self.value = value

    def reward(self, ctx: Context) -> float:
        """Return the configured constant, ignoring the context.

        Arguments:
            ctx: The local network context. It is unused here, but keeps the
                same name as in ``RewardFunction.reward`` so keyword calls
                (``fn.reward(ctx=...)``) work on every reward function.

        Returns:
            ``self.value``, regardless of ``ctx``.
        """
        return self.value