diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index 24db9cd..05ec0c2 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -27,6 +27,7 @@ from torch.utils.data import Dataset as TorchDataset # Create training and validation datasets. +from ajet.backbone.warm_up import warm_up_process from ajet.task_reader import RouterTaskReader, task_to_standard_dataset from ajet.utils.process_dataset import create_rl_sampler from ajet.utils.core_env_vars import get_runtime_env @@ -116,6 +117,7 @@ def run(self, config): from loguru import logger from omegaconf import OmegaConf from verl.utils.fs import copy_to_local + warm_up_process(config) logger.info(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") pprint(OmegaConf.to_container(config, resolve=True))