-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathupdate_old_checkpoints.py
More file actions
156 lines (123 loc) · 5.73 KB
/
update_old_checkpoints.py
File metadata and controls
156 lines (123 loc) · 5.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from __future__ import annotations
import logging
import os
from contextlib import contextmanager
from pathlib import Path
import torch
import yaml
from modalities.config.config import load_app_config_dict
# Recursive type alias for a parsed YAML configuration: string keys that map
# either to a string or to another nested configuration of the same shape.
config_type = dict[str, "str | config_type"]
def update_model(old_model_config: str, new_model_config: str, new_checkpoint_path: str | None):
    """
    Convert an old configuration file and, optionally, its model checkpoint.

    Args:
        old_model_config (str): Path to the old configuration file.
        new_model_config (str): Path to save the converted configuration file.
        new_checkpoint_path (str | None): Path to save the converted checkpoint.
            If None, only the configuration file is converted.

    Raises:
        SystemExit: With exit code 1 if a checkpoint conversion was requested but
            no usable checkpoint path was found in the old configuration.
    """
    old_checkpoint_path = update_config(old_model_config, new_model_config, new_checkpoint_path)
    # Load the freshly written config once before touching the checkpoint
    # (presumably this raises if the converted config is invalid).
    test_loading_config(new_model_config)

    if new_checkpoint_path is None:
        # Config-only conversion requested; nothing more to do.
        return

    if not old_checkpoint_path:
        logging.error("No valid checkpoint path found in config file!")
        # Raise SystemExit directly: the `exit` builtin is injected by the `site`
        # module for interactive use and is not guaranteed to exist in programs.
        raise SystemExit(1)

    update_model_state_dict(old_checkpoint_path, new_checkpoint_path)
def update_config(old_path: str, new_path: str, new_checkpoint_path: str | None) -> str | None:
    """
    Convert a configuration file from an old format to a new format.

    The old file is parsed as YAML, the checkpoint path is replaced (if requested),
    new keys are added, obsolete keys are removed, renamed keys are migrated, and
    the result is written to ``new_path``.

    Args:
        old_path (str): Path to the old configuration file.
        new_path (str): Path to save the new configuration file.
        new_checkpoint_path (str | None): Path to the new checkpoint file, if applicable.

    Returns:
        str | None: The checkpoint path previously stored in the configuration if
            ``new_checkpoint_path`` was given, otherwise None.
    """
    # Read explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so non-ASCII content is decoded the same way on every system.
    with open(old_path, "r", encoding="utf-8") as old_file:
        config: config_type = yaml.safe_load(old_file)

    # The checkpoint path is swapped first so the old value can be returned.
    old_checkpoint_path = update_checkpoint_path(config, new_checkpoint_path)
    add_new_keys(config)
    remove_keys(config)
    rename_keys(config)

    with open(new_path, "w", encoding="utf-8") as new_file:
        yaml.dump(config, new_file)
    return old_checkpoint_path
def update_checkpoint_path(config: config_type, new_checkpoint_path: str | None) -> str | None:
    """
    Replace the checkpoint path stored in the configuration and return the old one.

    Args:
        config (dict): The configuration dictionary; modified in place.
        new_checkpoint_path (str | None): The checkpoint path to store. If None,
            the configuration is left untouched.

    Returns:
        str | None: The checkpoint path that was stored before the update, or None
            if ``new_checkpoint_path`` is None.

    Raises:
        SystemExit: With exit code 1 if ``new_checkpoint_path`` is set but the
            configuration has no 'checkpointed_model' key.
    """
    if new_checkpoint_path is None:
        return None

    if "checkpointed_model" not in config:
        logging.error("'new_checkpoint_path' is set but no 'checkpointed_model' key found in configuration.")
        # Raise SystemExit directly: the `exit` builtin is injected by the `site`
        # module for interactive use and is not guaranteed to exist in programs.
        raise SystemExit(1)

    checkpoint_config = config["checkpointed_model"]["config"]
    old_path = checkpoint_config["checkpoint_path"]
    checkpoint_config["checkpoint_path"] = new_checkpoint_path
    return old_path
def rename_keys(config: config_type):
    """
    Rename the normalization entries of the model configuration in place.

    The model configuration is taken from 'model_raw' if that key exists,
    otherwise from 'model'. The keys 'attention_norm', 'ffn_norm' and
    'lm_head_norm' receive a '_config' suffix, and inside each renamed entry
    'variant_key' is renamed to 'norm_type'.

    Args:
        config (dict): The configuration dictionary; modified in place.
    """
    model_config = config["model_raw" if "model_raw" in config else "model"]["config"]
    old_norm_keys = ["attention_norm", "ffn_norm", "lm_head_norm"]
    new_norm_keys = ["attention_norm_config", "ffn_norm_config", "lm_head_norm_config"]
    for old_key, new_key in zip(old_norm_keys, new_norm_keys):
        rename_config_key(model_config, old_key, new_key)
        # If the norm entry was absent, the call above only logged a warning and
        # new_key does not exist; skip the nested rename instead of raising KeyError.
        if new_key in model_config:
            rename_config_key(model_config[new_key], "variant_key", "norm_type")
def rename_config_key(config: config_type, old_key: str, new_key: str):
    """
    Move the value stored under one key of a configuration dictionary to another key.

    The dictionary is modified in place. If the old key is absent, a warning is
    logged and the dictionary is left unchanged.

    Args:
        config (dict): The configuration dictionary.
        old_key (str): The key whose value should be moved.
        new_key (str): The key under which the value is stored afterwards.
    """
    if old_key not in config:
        logging.warning(f"Key '{old_key}' not found in configuration.")
        return
    moved_value = config.pop(old_key)
    config[new_key] = moved_value
def add_new_keys(config: config_type):
    """
    Add the keys required by the new format to the model configuration in place.

    The model configuration is taken from 'model_raw' if that key exists,
    otherwise from 'model'. Both 'use_weight_tying' and 'use_meta_device' are
    set to False.

    Args:
        config (dict): The configuration dictionary; modified in place.
    """
    model_key = "model_raw" if "model_raw" in config else "model"
    config[model_key]["config"].update(use_weight_tying=False, use_meta_device=False)
def remove_keys(config: config_type):
    """
    Delete the 'experiment_id' entries that the new format no longer contains.

    The key is removed, when present, from the evaluation subscriber config,
    from the settings, and from the checkpoint saving execution config.

    Args:
        config (dict): The configuration dictionary; modified in place.
    """
    obsolete_key = "experiment_id"

    if "evaluation_subscriber" in config:
        subscriber_config = config["evaluation_subscriber"]["config"]
        if obsolete_key in subscriber_config:
            del subscriber_config[obsolete_key]

    if "settings" in config:
        settings = config["settings"]
        if obsolete_key in settings:
            del settings[obsolete_key]

    if "checkpoint_saving" in config:
        saving_config = config["checkpoint_saving"]["config"]
        if "checkpoint_saving_execution" in saving_config:
            execution_config = saving_config["checkpoint_saving_execution"]["config"]
            if obsolete_key in execution_config:
                del execution_config[obsolete_key]
def update_model_state_dict(old_model_path: str, new_model_path: str):
    """
    Load the old model state dictionary, rename its head weight and save the result.

    The entry 'lm_head.weight' is moved to 'transformer.lm_head.weight' and the
    state dictionary is written to the new path. If 'lm_head.weight' is absent,
    an error is logged and nothing is saved.

    Args:
        old_model_path (str): Path to the old model file.
        new_model_path (str): Path to the new model file.
    """
    state_dict = torch.load(old_model_path)

    if "lm_head.weight" not in state_dict:
        logging.error("'lm_head.weight' not found in the model state dictionary.")
        if "transformer.lm_head.weight" in state_dict:
            logging.error("The model state dictionary already seems to be in the updated format.")
        return

    state_dict["transformer.lm_head.weight"] = state_dict.pop("lm_head.weight")
    torch.save(state_dict, new_model_path)
def test_loading_config(new_config_path: str):
    """
    Load the converted configuration file once to check that it can be loaded.

    While loading, the environment variables 'LOCAL_RANK' and 'RANK' are
    temporarily set to "0" (presumably because the config loader resolves them;
    verify against load_app_config_dict).

    Args:
        new_config_path (str): Path to the converted configuration file.
    """
    rank_environment = {"LOCAL_RANK": "0", "RANK": "0"}
    config_file = Path(new_config_path)
    with temporary_environ(rank_environment):
        load_app_config_dict(config_file)
@contextmanager
def temporary_environ(env_vars: dict[str, str]):
    """
    Context manager that temporarily sets environment variables.

    On exit, every variable that was set is restored to its previous value, or
    removed if it did not exist before.

    Args:
        env_vars (dict[str, str]): Mapping of variable names to the values they
            should have inside the ``with`` block.
    """
    previous_values: dict[str, str | None] = {}
    try:
        # Setup happens inside the try block so that a failure while setting one
        # variable still restores the ones that were already changed.
        for key, value in env_vars.items():
            previous_values[key] = os.environ.get(key)
            os.environ[key] = value
        yield
    finally:
        for key, previous in previous_values.items():
            if previous is None:
                # pop() with a default tolerates the body having already removed
                # the variable; `del` would raise KeyError and mask the body's error.
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
if __name__ == "__main__":
    import sys

    # Command-line entry point:
    #   <old_model_config> <new_model_config> [new_checkpoint_path]
    if len(sys.argv) < 3:
        print("Usage: python update_old_checkpoints.py <old_model_config> <new_model_config> [new_checkpoint_path]")
        print("If only a config file conversion is needed, omit the third argument.")
        # sys.exit rather than the `exit` builtin: the latter is injected by the
        # `site` module for interactive use and may be missing in programs.
        sys.exit(1)

    old_model_config = sys.argv[1]
    new_model_config = sys.argv[2]
    # The checkpoint path is optional; None means "convert the config only".
    new_checkpoint_path = sys.argv[3] if len(sys.argv) > 3 else None

    update_model(old_model_config, new_model_config, new_checkpoint_path)