diff --git a/examples/online/main_isaaclab_onpolicy.py b/examples/online/main_isaaclab_onpolicy.py index dd69bac..855101d 100644 --- a/examples/online/main_isaaclab_onpolicy.py +++ b/examples/online/main_isaaclab_onpolicy.py @@ -116,7 +116,11 @@ def collect_rollouts(self) -> RolloutBatch: jnp.array(obs_norm), deterministic=False ) actions_np = np.array(actions) - actions_clipped = np.clip(actions_np, -1.0, 1.0) + actions_env = ( + np.clip(actions_np, -1.0, 1.0) + if self.env.action_bound is not None + else actions_np + ) all_actions[t] = actions_np # Generically collect algorithm-specific info @@ -128,7 +132,7 @@ def collect_rollouts(self) -> RolloutBatch: for k, v in info.items(): all_extras[k][t] = np.array(v) - next_obs, rewards, terminated, truncated, infos = self.env.step(actions_clipped) + next_obs, rewards, terminated, truncated, infos = self.env.step(actions_env) all_rewards[t] = rewards[..., np.newaxis] all_terminated[t] = terminated[..., np.newaxis] @@ -209,8 +213,13 @@ def eval_and_save(self): actions, _ = self.agent.sample_actions( jnp.array(obs_norm), deterministic=True ) - actions_clipped = np.clip(np.array(actions), -1.0, 1.0) - obs, rewards, terminated, truncated, _ = self.env.step(actions_clipped) + actions_np = np.array(actions) + actions_env = ( + np.clip(actions_np, -1.0, 1.0) + if self.env.action_bound is not None + else actions_np + ) + obs, rewards, terminated, truncated, _ = self.env.step(actions_env) eval_returns += rewards * (1 - eval_dones) eval_lengths += 1 * (1 - eval_dones) diff --git a/flowrl/config/online/onpolicy_isaaclab_config.py b/flowrl/config/online/onpolicy_isaaclab_config.py index cdc6dd6..86abd40 100644 --- a/flowrl/config/online/onpolicy_isaaclab_config.py +++ b/flowrl/config/online/onpolicy_isaaclab_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, Optional from .hb_config import EvalConfig, LogConfig @@ -10,7 +10,7 @@ class Config: device: str task: str algo: Any - action_bound: float + action_bound: Optional[float] disable_bootstrap: bool norm_obs: bool train_frames: int