11"""Wrapper for reward labeling for transitions sampled from a replay buffer."""
22
3+ from typing import Callable
34from typing import Mapping , Type
45
56import numpy as np
1011from imitation .rewards .reward_function import RewardFn
1112from imitation .util import util
1213from imitation .util .networks import RunningNorm
13- from typing import Callable
1414
1515
1616def _samples_to_reward_fn_input (
@@ -59,6 +59,7 @@ def __init__(
         *,
         replay_buffer_class: Type[ReplayBuffer],
         reward_fn: RewardFn,
+        on_initialized_callback: Optional[Callable[["ReplayBufferRewardWrapper"], None]] = None,
         **kwargs,
     ):
         """Builds ReplayBufferRewardWrapper.
@@ -69,6 +70,9 @@ def __init__(
             action_space: Action space
             replay_buffer_class: Class of the replay buffer.
             reward_fn: Reward function for reward relabeling.
+            on_initialized_callback: Callback called with a reference to this object
+                after it is fully initialized. This provides a hook to access the
+                buffer after it is created from inside a Stable Baselines algorithm.
             **kwargs: keyword arguments for ReplayBuffer.
         """
         # Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of
@@ -86,6 +90,8 @@ def __init__(
         self.reward_fn = reward_fn
         _base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]}
         super().__init__(buffer_size, observation_space, action_space, **_base_kwargs)
+        if on_initialized_callback is not None:
+            on_initialized_callback(self)
 
         # TODO(juan) remove the type ignore once the merged PR
         # https://github.com/python/mypy/pull/13475
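
For context, a minimal sketch of how the new hook might be used. Stable Baselines 3 constructs the replay buffer internally when the algorithm is created, so `on_initialized_callback` is the natural place to capture a reference to the wrapper. The choice of `SAC`, the `Pendulum-v1` environment, and the zero-valued `my_reward_fn` stand-in below are illustrative assumptions, not part of this diff:

```python
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.buffers import ReplayBuffer

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper


def my_reward_fn(state, action, next_state, done):
    # Illustrative stand-in for a learned reward model; matches the
    # RewardFn signature (batched arrays in, np.ndarray of rewards out).
    return np.zeros(len(state), dtype=np.float32)


captured = {}

model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    replay_buffer_class=ReplayBufferRewardWrapper,
    replay_buffer_kwargs=dict(
        # Inner buffer type that actually stores the transitions.
        replay_buffer_class=ReplayBuffer,
        reward_fn=my_reward_fn,
        # Called at the end of the wrapper's __init__ with the wrapper itself.
        on_initialized_callback=lambda buf: captured.update(buffer=buf),
    ),
)

# The callback fired inside the buffer's __init__, before the instance was
# assigned to model.replay_buffer, so this is how outside code gets a handle.
assert captured["buffer"] is model.replay_buffer
```

The hook matters because callers hand SB3 a buffer *class* rather than an instance, so without it there is no clean place to wire the freshly created buffer into other components (e.g. a reward model that later relabels stored transitions).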