adding model-only restoration to RL path #107

Draft
azrael417 wants to merge 5 commits into master from tkurth/inference-fixes

@azrael417
Collaborator

This MR adds/changes the following things:

Weights-only model restore for RL (fine-tuning support)

Adds a coarse "load just the network weights" path for RL systems, mirroring the supervised load_model / load_checkpoint split, so a pretrained model can seed a new training run (e.g. modified reward or new
data) without dragging along the old optimizer state, replay/rollout buffer, normalizers, or step counters.

New API

  • torchfort_rl_off_policy_load_model(name, checkpoint_dir) — TD3 / DDPG / SAC
  • torchfort_rl_on_policy_load_model(name, checkpoint_dir) — PPO
  • C (torchfort_rl.h) + Fortran (torchfort_m.F90) bindings, plus docs in c_api.rst, f_api.rst, and a "Fine-Tuning / Transfer Learning" section in usage.rst.

Semantics

  • Restores only the online policy/critic (PPO: actor-critic) weights from a saved checkpoint; optimizers, schedulers, buffers, normalizers and counters stay freshly initialized.
  • Off-policy target networks are re-initialized from the loaded online networks (θ′ ← θ), matching the standard init convention — not restored from the checkpoint. (load_checkpoint still restores everything
    verbatim for faithful resume.)
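
The difference between the two restore paths can be illustrated with a toy model. This is only a sketch under stated assumptions: `ToyOffPolicySystem` and the `toy_*` functions are hypothetical stand-ins invented for illustration, not TorchFort classes or API calls, and plain `std::vector<float>` replaces real network, optimizer, and buffer state.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an off-policy RL system (not a TorchFort type).
struct ToyOffPolicySystem {
  std::vector<float> policy;          // online policy weights (theta)
  std::vector<float> policy_target;   // target policy weights (theta')
  std::vector<float> optimizer_state; // e.g. momentum buffers
  std::vector<float> replay_buffer;   // stored transitions
  long step = 0;                      // training step counter
};

// Full restore: every field is copied verbatim, so training can resume.
void toy_load_checkpoint(ToyOffPolicySystem& dst, const ToyOffPolicySystem& saved) {
  dst = saved;
}

// Weights-only restore: only the online weights come from the checkpoint.
void toy_load_model(ToyOffPolicySystem& dst, const ToyOffPolicySystem& saved) {
  dst.policy = saved.policy;
  // The target is re-initialized from the loaded online weights
  // (theta' <- theta), not read from the checkpoint.
  dst.policy_target = dst.policy;
  // optimizer_state, replay_buffer and step keep their fresh values.
}

// Readiness in this toy depends only on how full the buffer is.
bool toy_is_ready(const ToyOffPolicySystem& s, std::size_t min_buffer_size) {
  return s.replay_buffer.size() >= min_buffer_size;
}
```

In this toy, a system restored with `toy_load_model` reports not ready until its buffer is refilled, which mirrors the `is_ready` behaviour described under Tests.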

Tests

  • New tests/rl/test_checkpoint.cpp (test_checkpoint_rl target) with round-trip tests for all four algorithms, asserting weights match after both load_checkpoint and load_model, and that is_ready is true after
    a full restore but false after weights-only (buffer not restored) — the first RL checkpoint round-trip coverage in the repo.

Test infrastructure (future-proofing)

  • Config files are now resolved relative to the test executable (get_config_path in test_utils.h) instead of the current working directory, plus a build-tree copy of configs/, so RL tests run from any
    directory / either tree.
  • test_utils.h made self-contained.
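
A minimal sketch of resolving a config file relative to the test executable rather than the working directory. The actual signature and implementation of `get_config_path` are not shown in this PR description, so `toy_config_path`, the `configs` subdirectory layout, and the file names used here are assumptions for illustration only. How the executable's own path is obtained is platform-specific and is left to the caller.

```cpp
#include <cassert>
#include <filesystem>
#include <string>

namespace fs = std::filesystem;

// Hypothetical helper: build the config path from the executable's location,
// assuming a "configs" directory sits next to the test binary.
fs::path toy_config_path(const fs::path& executable, const std::string& config_name) {
  return executable.parent_path() / "configs" / config_name;
}
```

Because the result depends only on where the binary lives, the same lookup works from the build tree and the install tree, provided each has a `configs/` copy beside the executable.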

Cleanup

  • Replaced the deprecated torch::optim::Optimizer::parameters() accessor (libtorch W603 warning) with a shared reset_optimizer_parameters() helper across all 9 RL + supervised load sites; the helper asserts
    the single-param-group assumption via THROW_NOT_SUPPORTED.
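
The single-param-group guard can be sketched with toy types. This is not the PR's `reset_optimizer_parameters()` implementation: `ToyParamGroup`, `ToyOptimizer`, and `toy_reset_optimizer_parameters` are hypothetical, and `std::runtime_error` stands in for the `THROW_NOT_SUPPORTED` macro, whose definition is not shown here.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for an optimizer and its parameter groups.
struct ToyParamGroup {
  std::vector<float*> params; // pointers to the weights this group updates
};

struct ToyOptimizer {
  std::vector<ToyParamGroup> groups;
};

// Point the optimizer at a new set of parameters. The helper refuses to guess
// how parameters should be split when more than one group exists.
void toy_reset_optimizer_parameters(ToyOptimizer& opt, const std::vector<float*>& new_params) {
  if (opt.groups.size() != 1) {
    throw std::runtime_error("only a single parameter group is supported");
  }
  opt.groups[0].params = new_params;
}
```

Failing loudly on multiple groups keeps the assumption visible at every load site instead of silently dropping per-group settings.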

Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
@azrael417 azrael417 requested a review from romerojosh June 3, 2026 09:45
@azrael417 azrael417 self-assigned this Jun 3, 2026
@azrael417 azrael417 mentioned this pull request Jun 3, 2026
@azrael417 azrael417 marked this pull request as draft June 3, 2026 09:46
Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
@azrael417
Collaborator Author

/build_and_test

@github-actions

github-actions Bot commented Jun 3, 2026

🚀 Build workflow triggered! View run

…sing load model

Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
@azrael417 azrael417 changed the title from "adding model only restoration to RL path" to "adding model-only restoration to RL path" Jun 3, 2026
@azrael417
Collaborator Author

/build_and_test

@github-actions

github-actions Bot commented Jun 3, 2026

🚀 Build workflow triggered! View run

@github-actions

github-actions Bot commented Jun 3, 2026

✅ Build workflow passed! View run

@github-actions

github-actions Bot commented Jun 3, 2026

✅ Build workflow passed! View run

Collaborator

@romerojosh romerojosh left a comment

Looks mostly good to me. Just left a small comment about the added tests.

CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(full.c_str(), full_ready_after));
EXPECT_TRUE(full_ready_after);

// ---------------- weights-only restore (load_model) ----------------
Collaborator

It would be good to check that training works after the weights-only reload, in particular that the optimizer is reattached to the right weight tensors. I think running a training step after the weights are loaded and checking that the output changes would be sufficient for this.

I am missing load/save_model tests altogether for the supervised case so that is something I'll need to address in a different PR.
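
The failure mode such a test would catch can be sketched with toy types. Everything here is hypothetical (`ToyModel`, `ToySgd`, `toy_reload`) and assumes the reload swaps in new weight storage: an optimizer still holding the old storage keeps updating it, so the model's live weights never move.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

using Weights = std::shared_ptr<std::vector<float>>;

// Hypothetical model that owns a handle to its weight storage.
struct ToyModel {
  Weights w;
};

// Hypothetical SGD optimizer that updates whatever storage it points at.
struct ToySgd {
  Weights target;
  float lr;
  void step(const std::vector<float>& grad) {
    for (std::size_t i = 0; i < target->size(); ++i) {
      (*target)[i] -= lr * grad[i];
    }
  }
};

// Weights-only reload that replaces the storage instead of copying in place.
void toy_reload(ToyModel& m, const std::vector<float>& saved) {
  m.w = std::make_shared<std::vector<float>>(saved);
}
```

After `toy_reload`, a step through the stale optimizer leaves `model.w` untouched. Only once `target` is pointed at the new storage does a step change the weights, which is the property a post-reload training test would assert.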

Collaborator Author

I had some device placement bugs there; I think they are all fixed now. I added the training tests too, but they need to run more than one step because some weights only update after several steps due to policy lag. That still verifies that the weights are actually changing.
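
Why a single step is not enough can be shown with a toy delayed-update loop. This is a sketch only: `ToyDelayedTrainer`, its scalar "weights", and the fixed update size are hypothetical, and the lag of 2 is an assumed value, not one taken from the test configs.

```cpp
#include <cassert>

// Hypothetical trainer where the critic updates every step and the actor
// only every `policy_lag` steps.
struct ToyDelayedTrainer {
  float actor = 1.0f;
  float critic = 1.0f;
  int policy_lag = 2; // assumed value for illustration
  int step_count = 0;

  void train_step() {
    ++step_count;
    critic -= 0.1f; // critic always moves
    if (step_count % policy_lag == 0) {
      actor -= 0.1f; // actor moves only on lagged steps
    }
  }
};
```

A test that compares actor weights after one step would see no change here, so the check has to run at least `policy_lag` steps before asserting that the weights moved.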

@romerojosh
Collaborator

Another thing that came to mind here is that you'll want to add your new test_checkpoint_rl binary to https://github.com/NVIDIA/TorchFort/blob/tkurth/inference-fixes/.github/scripts/run_ci_tests.sh. Annoyingly, this won't affect the workflow until this change is merged into master, but we will want it in there.

azrael417 added 2 commits June 3, 2026 23:24
…irectory path error in test

Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
Collaborator

@romerojosh romerojosh left a comment

Updates look good. You still need to add test_checkpoint_rl to .github/scripts/run_ci_tests.sh so this test gets run in the CI, but otherwise, we are good to go.
