Skip to content

Commit 1a232aa

Browse files
committed
feat: dedicated diiagnostics dir & more robust CudaCallback
1 parent 3bd2d67 commit 1a232aa

4 files changed

Lines changed: 56 additions & 39 deletions

File tree

dmlcloud/core/callbacks/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def pre_run(self, pipe: 'Pipeline'):
3535
dml_checkpoint.create_checkpoint_dir(self.run_dir)
3636
dml_checkpoint.save_config(pipe.config, self.run_dir)
3737

38-
with open(pipe.run_dir / "environment.txt", 'w') as f:
38+
with open(pipe.run_dir / 'diagnostics' / 'environment.txt', 'w') as f:
3939
for k, v in os.environ.items():
4040
f.write(f"{k}={v}\n")
4141

dmlcloud/core/callbacks/cuda.py

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,55 +8,69 @@
88
from .common import Callback
99

1010

11+
def _call_pynvml(method, *args, **kwargs):
12+
try:
13+
return method(*args, **kwargs)
14+
except pynvml.NVMLError:
15+
return None
16+
17+
18+
def _get_pynvml_handler(device):
19+
try:
20+
return torch.cuda._get_pynvml_handler(device)
21+
except pynvml.NVMLError:
22+
return None
23+
24+
25+
def _gather_cuda_info(handler):
26+
info = {
27+
'name': _call_pynvml(pynvml.nvmlDeviceGetName, handler),
28+
'uuid': _call_pynvml(pynvml.nvmlDeviceGetUUID, handler),
29+
'serial': _call_pynvml(pynvml.nvmlDeviceGetSerial, handler),
30+
'minor_number': _call_pynvml(pynvml.nvmlDeviceGetMinorNumber, handler),
31+
'architecture': _call_pynvml(pynvml.nvmlDeviceGetArchitecture, handler),
32+
'brand': _call_pynvml(pynvml.nvmlDeviceGetBrand, handler),
33+
'vbios_version': _call_pynvml(pynvml.nvmlDeviceGetVbiosVersion, handler),
34+
'driver_version': _call_pynvml(pynvml.nvmlSystemGetDriverVersion),
35+
'cuda_driver_version': _call_pynvml(pynvml.nvmlSystemGetCudaDriverVersion_v2),
36+
'nvml_version': _call_pynvml(pynvml.nvmlSystemGetNVMLVersion),
37+
'total_memory': _call_pynvml(pynvml.nvmlDeviceGetMemoryInfo, handler, pynvml.nvmlMemory_v2).total,
38+
'reserved_memory': _call_pynvml(pynvml.nvmlDeviceGetMemoryInfo, handler, pynvml.nvmlMemory_v2).reserved,
39+
'num_gpu_cores': _call_pynvml(pynvml.nvmlDeviceGetNumGpuCores, handler),
40+
'power_managment_limit': _call_pynvml(pynvml.nvmlDeviceGetPowerManagementLimit, handler),
41+
'power_managment_default_limit': _call_pynvml(pynvml.nvmlDeviceGetPowerManagementDefaultLimit, handler),
42+
'cuda_compute_capability': _call_pynvml(pynvml.nvmlDeviceGetCudaComputeCapability, handler),
43+
}
44+
return info
45+
46+
1147
class CudaCallback(Callback):
1248
"""
1349
Logs various properties pertaining to CUDA devices.
1450
"""
1551

16-
@staticmethod
17-
def _call_pynvml(method, *args, **kwargs):
18-
try:
19-
return method(*args, **kwargs)
20-
except pynvml.NVMLError:
21-
return None
22-
2352
def pre_run(self, pipe):
24-
handle = torch.cuda._get_pynvml_handler(pipe.device)
25-
26-
info = {
27-
'name': self._call_pynvml(pynvml.nvmlDeviceGetName, handle),
28-
'uuid': self._call_pynvml(pynvml.nvmlDeviceGetUUID, handle),
29-
'serial': self._call_pynvml(pynvml.nvmlDeviceGetSerial, handle),
30-
'torch_device': str(pipe.device),
31-
'minor_number': self._call_pynvml(pynvml.nvmlDeviceGetMinorNumber, handle),
32-
'architecture': self._call_pynvml(pynvml.nvmlDeviceGetArchitecture, handle),
33-
'brand': self._call_pynvml(pynvml.nvmlDeviceGetBrand, handle),
34-
'vbios_version': self._call_pynvml(pynvml.nvmlDeviceGetVbiosVersion, handle),
35-
'driver_version': self._call_pynvml(pynvml.nvmlSystemGetDriverVersion),
36-
'cuda_driver_version': self._call_pynvml(pynvml.nvmlSystemGetCudaDriverVersion_v2),
37-
'nvml_version': self._call_pynvml(pynvml.nvmlSystemGetNVMLVersion),
38-
'total_memory': self._call_pynvml(pynvml.nvmlDeviceGetMemoryInfo, handle, pynvml.nvmlMemory_v2).total,
39-
'reserved_memory': self._call_pynvml(pynvml.nvmlDeviceGetMemoryInfo, handle, pynvml.nvmlMemory_v2).reserved,
40-
'num_gpu_cores': self._call_pynvml(pynvml.nvmlDeviceGetNumGpuCores, handle),
41-
'power_managment_limit': self._call_pynvml(pynvml.nvmlDeviceGetPowerManagementLimit, handle),
42-
'power_managment_default_limit': self._call_pynvml(pynvml.nvmlDeviceGetPowerManagementDefaultLimit, handle),
43-
'cuda_compute_capability': self._call_pynvml(pynvml.nvmlDeviceGetCudaComputeCapability, handle),
44-
}
45-
all_devices = all_gather_object(info)
53+
handler = _get_pynvml_handler(pipe.device)
54+
info = _gather_cuda_info(handler) if handler is not None else {}
55+
info['torch_device'] = str(pipe.device)
56+
57+
all_infos = all_gather_object(info)
4658

4759
msg = '* CUDA-DEVICES:\n'
48-
info_strings = [
49-
f'{info["torch_device"]} -> /dev/nvidia{info["minor_number"]} -> {info["name"]} (UUID: {info["uuid"]}) (VRAM: {info["total_memory"] / 1000 ** 2:.0f} MB)'
50-
for info in all_devices
51-
]
60+
info_strings = []
61+
for info in all_infos:
62+
if 'minor_number' in info and 'name' in info and 'uuid' in info:
63+
info_strings.append(
64+
f'{info["torch_device"]} -> /dev/nvidia{info["minor_number"]} -> {info["name"]} (UUID: {info["uuid"]}) (VRAM: {info["total_memory"] / 1000 ** 2:.0f} MB)'
65+
)
5266
msg += '\n'.join(f' - [{i}] {info_str}' for i, info_str in enumerate(info_strings))
5367
dml_logging.info(msg)
5468

5569
if pipe.run_dir and is_root():
56-
self._save(pipe.run_dir / 'cuda_devices.json', all_devices)
70+
self._save(pipe.run_dir / 'diagnostics' / 'cuda_devices.json', all_infos)
5771

58-
def _save(self, path, all_devices):
72+
def _save(self, path, all_infos):
5973
with open(path, 'w') as f:
60-
devices = {f'rank_{i}': device for i, device in enumerate(all_devices)}
61-
obj = {'devices': devices}
74+
dct = {f'rank_{i}': info for i, info in enumerate(all_infos)}
75+
obj = {'devices': dct}
6276
json.dump(obj, f, indent=4)

dmlcloud/core/callbacks/git.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def pre_run(self, pipe):
3838
return
3939

4040
if pipe.run_dir and is_root():
41-
self._save(pipe.run_dir / 'git_diff.txt', diff)
41+
self._save(pipe.run_dir / 'diagnostics' / 'git_diff.txt', diff)
4242

4343
self._log_diff(diff)
4444

dmlcloud/core/checkpoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def create_checkpoint_dir(path: Path | str, name: Optional[str] = None) -> Path:
6363
indicator_file = path / '.dmlcloud'
6464
indicator_file.touch()
6565

66+
diagnostics_dir = path / 'diagnostics'
67+
diagnostics_dir.mkdir(exist_ok=True)
68+
6669
if slurm_job_id() is not None:
6770
with open(path / '.slurm-jobid', 'w') as f:
6871
f.write(slurm_job_id())

0 commit comments

Comments
 (0)