adding model-only restoration to RL path#107
Conversation
Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
|
/build_and_test |
|
🚀 Build workflow triggered! View run |
…sing load model Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
|
/build_and_test |
|
🚀 Build workflow triggered! View run |
|
✅ Build workflow passed! View run |
|
✅ Build workflow passed! View run |
romerojosh
left a comment
There was a problem hiding this comment.
Looks mostly good to me. Just left a small comment about the added tests.
| CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(full.c_str(), full_ready_after)); | ||
| EXPECT_TRUE(full_ready_after); | ||
|
|
||
| // ---------------- weights-only restore (load_model) ---------------- |
There was a problem hiding this comment.
It would be good to check that training works after the weights-only reload, in particular, that the optimizer is reattached to the right weight tensors. Running a training step after the weights are loaded and checking that the output is modified I think would be sufficient for this.
I am missing load/save_model tests altogether for the supervised case so that is something I'll need to address in a different PR.
There was a problem hiding this comment.
I had some device placement bugs there, all fixed I think. Added the train tests too but I need to do more than one step because some weights only update after a number of steps due to policy lag, but that still verifies that the weights are actually changing.
|
Another thing that came to mind here is that you'll want to add your new test_checkpoint binary to https://github.com/NVIDIA/TorchFort/blob/tkurth/inference-fixes/.github/scripts/run_ci_tests.sh. Annoyingly, this won't impact the workflow until this change is merged into master but we will want it in there. |
…irectory path error in test Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
Signed-off-by: Thorsten Kurth <tkurth@nvidia.com>
romerojosh
left a comment
There was a problem hiding this comment.
Updates look good. You still need to add test_checkpoint_rl to .github/scripts/run_ci_tests.sh so this test gets run in the CI, but otherwise, we are good to go.
This MR adds/changes the following things:
Weights-only model restore for RL (fine-tuning support)
Adds a coarse "load just the network weights" path for RL systems, mirroring the supervised load_model / load_checkpoint split, so a pretrained model can seed a new training run (e.g. modified reward or new
data) without dragging along the old optimizer state, replay/rollout buffer, normalizers, or step counters.
New API
Semantics
verbatim for faithful resume.)
Tests
a full restore but false after weights-only (buffer not restored) — the first RL checkpoint round-trip coverage in the repo.
Test infrastructure (future-proofing)
directory / either tree.
Cleanup
the single-param-group assumption via THROW_NOT_SUPPORTED.