@@ -741,7 +741,8 @@ def train(
741741 num_batches = args .num_minibatches * args .gradient_accumulation_steps ,
742742 get_logits_and_value = partial (policy .apply , method = policy .get_logits_and_value ),
743743 impala_cfg = args .loss ,
744- )
744+ ),
745+ donate_argnames = ("agent_state" , "key" ),
745746 ),
746747 axis_name = SINGLE_DEVICE_UPDATE_DEVICES_AXIS ,
747748 devices = runtime_info .global_learner_devices ,
@@ -754,7 +755,11 @@ def train(
754755 unreplicated_params = agent_state .params
755756 key , * actor_keys = jax .random .split (key , 1 + len (args .actor_device_ids ))
756757 for d_idx , d_id in enumerate (args .actor_device_ids ):
757- device_params = jax .device_put (unreplicated_params , runtime_info .local_devices [d_id ])
758+ # Copy device_params so they cannot alias agent_state, whose buffers are donated to multi_device_update
759+ device_params = jax .tree .map (
760+ partial (jnp .array , copy = True ),
761+ jax .device_put (unreplicated_params , runtime_info .local_devices [d_id ]),
762+ )
758763 for thread_id in range (args .num_actor_threads ):
759764 params_queues .append (queue .Queue (maxsize = 1 ))
760765 rollout_queues .append (queue .Queue (maxsize = 1 ))
@@ -800,15 +805,19 @@ def train(
800805
801806 key , * epoch_keys = jax .random .split (key , 1 + args .train_epochs )
802807 permutation_key = jax .random .split (epoch_keys [0 ], len (runtime_info .global_learner_devices ))
803- (agent_state , metrics_dict ) = multi_device_update (agent_state , sharded_storages , key = permutation_key )
808+ (agent_state , metrics_dict ) = multi_device_update (agent_state , sharded_storages , permutation_key )
804809 for epoch in range (1 , args .train_epochs ):
805810 permutation_key = jax .random .split (epoch_keys [epoch ], len (runtime_info .global_learner_devices ))
806- (agent_state , metrics_dict ) = multi_device_update (agent_state , sharded_storages , key = permutation_key )
811+ (agent_state , metrics_dict ) = multi_device_update (agent_state , sharded_storages , permutation_key )
807812
808813 unreplicated_params = unreplicate (agent_state .params )
809814 if update > args .actor_update_cutoff or update % args .actor_update_frequency == 0 :
810815 for d_idx , d_id in enumerate (args .actor_device_ids ):
811- device_params = jax .device_put (unreplicated_params , runtime_info .local_devices [d_id ])
816+ # Copy device_params so the actors' params cannot alias agent_state, whose buffers are donated to multi_device_update
817+ device_params = jax .tree .map (
818+ partial (jnp .array , copy = True ),
819+ jax .device_put (unreplicated_params , runtime_info .local_devices [d_id ]),
820+ )
812821 for thread_id in range (args .num_actor_threads ):
813822 params_queues [d_idx * args .num_actor_threads + thread_id ].put (
814823 ParamsPayload (params = device_params , policy_version = args .learner_policy_version ),
@@ -954,11 +963,13 @@ def load_train_state(
954963 pass # must be already unreplicated
955964 if isinstance (args .net , ConvLSTMConfig ):
956965 for i in range (args .net .n_recurrent ):
957- train_state .params ["params" ]["network_params" ][f"cell_list_{ i } " ]["fence" ]["kernel" ] = jnp .sum (
958- train_state .params ["params" ]["network_params" ][f"cell_list_{ i } " ]["fence" ]["kernel" ],
959- axis = 2 ,
960- keepdims = True ,
961- )
966+ this_cell = train_state .params ["params" ]["network_params" ][f"cell_list_{ i } " ]
967+ if "fence" in this_cell :
968+ this_cell ["fence" ]["kernel" ] = jnp .sum (
969+ this_cell ["fence" ]["kernel" ],
970+ axis = 2 ,
971+ keepdims = True ,
972+ )
962973
963974 if finetune_with_noop_head :
964975 loaded_head = train_state .params ["params" ]["actor_params" ]["Output" ]
0 commit comments