Skip to content

Commit 1419cc6

Browse files
author
Adrià Garriga-Alonso
committed
Fix loading fenceless checkpoints
1 parent 891ffde commit 1419cc6

2 files changed

Lines changed: 22 additions & 11 deletions

File tree

cleanba/cleanba_impala.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,8 @@ def train(
741741
num_batches=args.num_minibatches * args.gradient_accumulation_steps,
742742
get_logits_and_value=partial(policy.apply, method=policy.get_logits_and_value),
743743
impala_cfg=args.loss,
744-
)
744+
),
745+
donate_argnames=("agent_state", "key"),
745746
),
746747
axis_name=SINGLE_DEVICE_UPDATE_DEVICES_AXIS,
747748
devices=runtime_info.global_learner_devices,
@@ -754,7 +755,11 @@ def train(
754755
unreplicated_params = agent_state.params
755756
key, *actor_keys = jax.random.split(key, 1 + len(args.actor_device_ids))
756757
for d_idx, d_id in enumerate(args.actor_device_ids):
757-
device_params = jax.device_put(unreplicated_params, runtime_info.local_devices[d_id])
758+
# Copy device_params so we can donate the agent_state in the multi_device_update
759+
device_params = jax.tree.map(
760+
partial(jnp.array, copy=True),
761+
jax.device_put(unreplicated_params, runtime_info.local_devices[d_id]),
762+
)
758763
for thread_id in range(args.num_actor_threads):
759764
params_queues.append(queue.Queue(maxsize=1))
760765
rollout_queues.append(queue.Queue(maxsize=1))
@@ -800,15 +805,19 @@ def train(
800805

801806
key, *epoch_keys = jax.random.split(key, 1 + args.train_epochs)
802807
permutation_key = jax.random.split(epoch_keys[0], len(runtime_info.global_learner_devices))
803-
(agent_state, metrics_dict) = multi_device_update(agent_state, sharded_storages, key=permutation_key)
808+
(agent_state, metrics_dict) = multi_device_update(agent_state, sharded_storages, permutation_key)
804809
for epoch in range(1, args.train_epochs):
805810
permutation_key = jax.random.split(epoch_keys[epoch], len(runtime_info.global_learner_devices))
806-
(agent_state, metrics_dict) = multi_device_update(agent_state, sharded_storages, key=permutation_key)
811+
(agent_state, metrics_dict) = multi_device_update(agent_state, sharded_storages, permutation_key)
807812

808813
unreplicated_params = unreplicate(agent_state.params)
809814
if update > args.actor_update_cutoff or update % args.actor_update_frequency == 0:
810815
for d_idx, d_id in enumerate(args.actor_device_ids):
811-
device_params = jax.device_put(unreplicated_params, runtime_info.local_devices[d_id])
816+
# Copy device_params so we can donate the agent_state in the multi_device_update
817+
device_params = jax.tree.map(
818+
partial(jnp.array, copy=True),
819+
jax.device_put(unreplicated_params, runtime_info.local_devices[d_id]),
820+
)
812821
for thread_id in range(args.num_actor_threads):
813822
params_queues[d_idx * args.num_actor_threads + thread_id].put(
814823
ParamsPayload(params=device_params, policy_version=args.learner_policy_version),
@@ -954,11 +963,13 @@ def load_train_state(
954963
pass # must be already unreplicated
955964
if isinstance(args.net, ConvLSTMConfig):
956965
for i in range(args.net.n_recurrent):
957-
train_state.params["params"]["network_params"][f"cell_list_{i}"]["fence"]["kernel"] = jnp.sum(
958-
train_state.params["params"]["network_params"][f"cell_list_{i}"]["fence"]["kernel"],
959-
axis=2,
960-
keepdims=True,
961-
)
966+
this_cell = train_state.params["params"]["network_params"][f"cell_list_{i}"]
967+
if "fence" in this_cell:
968+
this_cell["fence"]["kernel"] = jnp.sum(
969+
this_cell["fence"]["kernel"],
970+
axis=2,
971+
keepdims=True,
972+
)
962973

963974
if finetune_with_noop_head:
964975
loaded_head = train_state.params["params"]["actor_params"]["Output"]

cleanba/impala_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,11 @@ def tree_flatten_and_concat(x) -> jax.Array:
302302
def single_device_update(
303303
agent_state: TrainState,
304304
sharded_storages: List[Rollout],
305+
key: jax.Array,
305306
*,
306307
get_logits_and_value: GetLogitsAndValueFn,
307308
num_batches: int,
308309
impala_cfg: ImpalaLossConfig,
309-
key: jax.Array,
310310
) -> tuple[TrainState, dict[str, jax.Array]]:
311311
def update_minibatch(agent_state: TrainState, minibatch: Rollout):
312312
(loss, metrics_dict), grads = jax.value_and_grad(impala_cfg.loss, has_aux=True)(

0 commit comments

Comments (0)