Skip to content

Commit 62157e6

Browse files
author
Jan Michelfeit
committed
#625 fix more pre-commit errors
1 parent c787877 commit 62157e6

6 files changed

Lines changed: 10 additions & 5 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""PEBBLE specific algorithms."""

src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class PebbleRewardPhase(Enum):
2525
class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
2626
"""Reward function for implementation of the PEBBLE learning algorithm.
2727
28-
See https://arxiv.org/pdf/2106.05091.pdf .
28+
See https://arxiv.org/abs/2106.05091 .
2929
3030
The rewards returned by this function go through the three phases:
3131
1. Before enough samples are collected for entropy calculation, the

src/imitation/algorithms/preference_comparisons.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def __init__(
363363

364364
def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
365365
self.train(steps, **kwargs)
366-
self.reward_fn.unsupervised_exploration_finish() # type: ignore[attribute-error]
366+
self.reward_fn.unsupervised_exploration_finish() # type: ignore
367367

368368

369369
def _get_trajectories(

src/imitation/policies/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class SAC1024Policy(sac_policies.SACPolicy):
7676
"""Actor and value networks with two hidden layers of 1024 units respectively.
7777
7878
This matches the implementation of SAC policies in the PEBBLE paper. See:
79-
https://arxiv.org/pdf/2106.05091.pdf
79+
https://arxiv.org/abs/2106.05091
8080
https://github.com/denisyarats/pytorch_sac/blob/master/config/agent/sac.yaml
8181
8282
Note: This differs from stable_baselines3 SACPolicy by having 1024 hidden units

src/imitation/rewards/reward_function.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import numpy as np
77

8+
import imitation.policies.replay_buffer_wrapper
9+
810

911
class RewardFn(Protocol):
1012
"""Abstract class for reward function.
@@ -40,6 +42,8 @@ class ReplayBufferAwareRewardFn(RewardFn, abc.ABC):
4042
@abc.abstractmethod
4143
def on_replay_buffer_initialized(
4244
self,
43-
replay_buffer: "ReplayBufferRewardWrapper", # type: ignore[name-defined] # noqa
45+
replay_buffer: (
46+
"imitation.policies.replay_buffer_wrapper.ReplayBufferRewardWrapper"
47+
),
4448
):
4549
pass

src/imitation/scripts/config/train_preference_comparisons.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def train_defaults():
6161
checkpoint_interval = 0 # Num epochs between saving (<0 disables, =0 final only)
6262
query_schedule = "hyperbolic"
6363

64-
# Whether to use the PEBBLE algorithm (https://arxiv.org/pdf/2106.05091.pdf)
64+
# Whether to use the PEBBLE algorithm (https://arxiv.org/abs/2106.05091)
6565
pebble_enabled = False
6666
unsupervised_agent_pretrain_frac = 0.0
6767

0 commit comments

Comments (0)