Skip to content

Remove deepcopy:pickle error, allow non-falsy num_workers=0#237

Open
gpwolfe wants to merge 1 commit into
openkim:mainfrom
gpwolfe:main
Open

Remove deepcopy:pickle error, allow non-falsy num_workers=0#237
gpwolfe wants to merge 1 commit into
openkim:mainfrom
gpwolfe:main

Conversation

@gpwolfe
Copy link
Copy Markdown

@gpwolfe gpwolfe commented May 20, 2026

Summary

Include a summary of major changes in bullet points:

  • deepcopy of pl_model removed. deepcopy fails on e3nn layers with pickle error.
  • Keep model's float type (32 or 64) instead of forcing to RadialGraph C default float64.
  • num_workers == 0 will no longer fall through to num_workers = os.getenv("SLURM_CPUS_PER_TASK", 1) due to evaluation as False. Allows non-multithread support on non-slurm system.

Additional dependencies introduced (if any)

  • No new dependencies

Checklist

Before a pull request can be merged, the following items must be checked:

  • Make sure your code is properly formatted. isort and black are used for this purpose. The simplest way is to use pre-commit. See instructions here.
  • Doc strings have been added in the Google docstring format on your code.
  • Type annotations are highly encouraged. Run mypy to
    type check your code.
  • Tests have been added for any new functionality or bug fixes.
  • [-] All linting and tests pass. Pre-existing error on non-torch installs from type check.

Note that the CI system will run all the above checks. But it will be much more
efficient if you already fix most errors prior to submitting the PR.

@mjwen
Copy link
Copy Markdown
Collaborator

mjwen commented May 21, 2026

@ipcamit can you please review this PR?

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the PyTorch Lightning trainer to (1) avoid deepcopy failures when exporting models that include e3nn layers, (2) preserve/normalize graph tensor dtypes rather than always ending up with float64 from the RadialGraph C-extension, and (3) correctly honor num_workers=0 instead of treating it as falsy and falling back to SLURM defaults.

Changes:

  • Remove deepcopy(self.pl_model) during export and load the best checkpoint directly into self.pl_model before TorchScripting.
  • Add a post-transform dtype normalization step for precomputed graph fingerprints (coords/forces) when default dtype is float32.
  • Fix num_workers resolution so explicit 0 is respected and values are consistently cast to int.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +362 to +374
# of input dtype. Cast fingerprints back to match the model's default dtype.
if torch.get_default_dtype() == torch.float32:
for ds in [self.train_dataset, self.val_dataset]:
if ds is None:
continue
for config in ds:
fp = getattr(config, "fingerprint", None)
if fp is None:
continue
if hasattr(fp, "coords") and fp.coords is not None:
fp.coords = fp.coords.to(torch.float32)
if hasattr(fp, "forces") and fp.forces is not None:
fp.forces = fp.forces.to(torch.float32)
Comment on lines 349 to +365
if self.dataset_manifest["dynamic_loading"]:
transform = self.configuration_transform
else:
transform = None

if not transform:
for config in self.train_dataset:
config.fingerprint = self.configuration_transform(config)
if self.val_dataset:
for config in self.val_dataset:
config.fingerprint = self.configuration_transform(config)

# RadialGraph C extension always upcasts coords/forces to float64 regardless
# of input dtype. Cast fingerprints back to match the model's default dtype.
if torch.get_default_dtype() == torch.float32:
for ds in [self.train_dataset, self.val_dataset]:
if ds is None:
Comment on lines +520 to 526
self.pl_model.load_state_dict(
torch.load(
f"{self.current['run_dir']}/checkpoints/best_model.pth",
weights_only=False,
)
)
try:
# so mutating self.pl_model is safe.
self.pl_model.load_state_dict(
torch.load(
f"{self.current['run_dir']}/checkpoints/best_model.pth",
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants