Skip to content

Commit 9ebf2b3

Browse files
committed
chore: formatting
1 parent 333af21 commit 9ebf2b3

3 files changed

Lines changed: 17 additions & 11 deletions

File tree

dmlcloud/core/distributed.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,12 @@ def _initialize_via_tcp(
105105
msg += f'\n master port: {port}'
106106
print(msg, flush=True)
107107

108-
# TODO: Add check that ip == rank0 host
108+
# TODO: Add check that ip == rank0 host
109109

110110
store = dist.TCPStore(
111-
host_name=ip,
111+
host_name=ip,
112112
port=port,
113-
world_size=_WorkerInfo.WORLD_SIZE,
113+
world_size=_WorkerInfo.WORLD_SIZE,
114114
is_master=_WorkerInfo.RANK == 0,
115115
)
116116

@@ -144,7 +144,6 @@ def has_environment():
144144
return 'MASTER_PORT' in os.environ
145145

146146

147-
148147
def rank():
149148
"""
150149
Returns the rank of the current process.
@@ -200,7 +199,6 @@ def local_node():
200199
return _WorkerInfo.NODE_ID
201200

202201

203-
204202
def _init_process_group_env(**kwargs):
205203
"""
206204
Intialize using "env://" method.
@@ -210,7 +208,7 @@ def _init_process_group_env(**kwargs):
210208
"""
211209
if not has_environment():
212210
raise RuntimeError('Environment variables for env:// initialization not found')
213-
211+
214212
_initialize_via_tcp(
215213
ip=os.environ['MASTER_ADDR'],
216214
port=int(os.environ['MASTER_PORT']),
@@ -356,7 +354,6 @@ def deinitialize_torch_distributed(fail_silently=False):
356354
dist.destroy_process_group()
357355

358356

359-
360357
def is_root(group: dist.ProcessGroup = None):
361358
"""
362359
Check if the current rank is the root rank (rank 0).
@@ -518,7 +515,6 @@ def root_first(group: dist.ProcessGroup = None):
518515
pass
519516

520517

521-
522518
def all_gather_object(obj, group=None):
523519
"""
524520
Gather objects from all ranks in the group.

dmlcloud/core/pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def current_stage() -> Stage | None:
6666
return current_pipe().current_stage
6767

6868

69-
def log_metric(name: str, value: Any, reduction: str = 'mean', prefixed: bool = True, ignore_missing_stage: bool = True):
69+
def log_metric(
70+
name: str, value: Any, reduction: str = 'mean', prefixed: bool = True, ignore_missing_stage: bool = True
71+
):
7072
"""
7173
Shorthand for current_stage().log
7274

dmlcloud/core/stage.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,21 @@ def add_callback(self, callback: 'Callback', priority: int = 1):
167167
"""
168168
self.callbacks.append(callback, priority)
169169

170-
def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True, log_step: bool = True, synchronize: bool = True):
170+
def log(
171+
self,
172+
name: str,
173+
value: Any,
174+
reduction: str = 'mean',
175+
prefixed: bool = True,
176+
log_step: bool = True,
177+
synchronize: bool = True,
178+
):
171179
"""
172180
Logs a metric for the current step and epoch.
173181
174182
If `synchronize` is True, the metric will be (all-)reduced across distributed processes before being logged.
175183
Care must be taken to ensure that every process participates in this reduction to avoid hangs and failures.
176-
184+
177185
Args:
178186
name (str): The name of the metric.
179187
value (Any): The value of the metric.

0 commit comments

Comments
 (0)