diff --git a/dmlcloud/core/distributed.py b/dmlcloud/core/distributed.py index 2708a9e..925f59d 100644 --- a/dmlcloud/core/distributed.py +++ b/dmlcloud/core/distributed.py @@ -277,7 +277,7 @@ def _init_process_group_slurm(port=DEFAULT_PORT, **kwargs): tasks_per_node.extend([int(ntasks)] * int(nnodes[:-1])) else: tasks_per_node.append(int(t)) - local_world_size = tasks_per_node[_WorkerInfo.NODE_ID] + local_world_size = tasks_per_node[node_id] _initialize_via_tcp( ip=ip,