Skip to content

Commit 6406fc7

Browse files
committed
add: support sampling and warmup instrumentation policies
1 parent 53d90ca commit 6406fc7

9 files changed

Lines changed: 271 additions & 7 deletions

File tree

traincheck/collect_trace.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,20 @@ def main():
380380
help="Indicate wthether use torch.compile to speed the model, necessary to realize compatibility",
381381
)
382382

383+
## instrumentation policy configs
384+
parser.add_argument(
385+
"--sampling-interval",
386+
type=int,
387+
default=None,
388+
help="Interval of steps to instrument (e.g., 10 for every 10th step).",
389+
)
390+
parser.add_argument(
391+
"--warm-up-steps",
392+
type=int,
393+
default=0,
394+
help="Number of initial steps to always instrument.",
395+
)
396+
383397
args = parser.parse_args()
384398

385399
# read the configuration file
@@ -508,6 +522,8 @@ def main():
508522
instr_descriptors=args.instr_descriptors,
509523
no_auto_var_instr=args.no_auto_var_instr,
510524
use_torch_compile=args.use_torch_compile,
525+
sampling_interval=args.sampling_interval,
526+
warm_up_steps=args.warm_up_steps,
511527
)
512528

513529
if args.copy_all_files:

traincheck/config/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@
9595
TYPE_ERR_THRESHOLD = 3
9696
RECURSION_ERR_THRESHOLD = 5
9797

98+
# Step-sampling policy read by traincheck.instrumentor.control at the start of
# each step to decide whether instrumentation stays enabled for that step.
INSTRUMENTATION_POLICY = {
    # After warm-up, only every `interval`-th step is instrumented.
    "interval": 1,
    # Defaults to 1 so that the first step is always instrumented: while the
    # warm-up budget lasts every step is instrumented (as if interval == 1);
    # once it is used up, the configured interval applies.
    "warm_up": 1,
}

# Global switch toggled by the step-control hooks; while True, the proxy
# dump/observe paths return early without recording anything.
DISABLE_WRAPPER = False
104+
98105

99106
class InstrOpt:
100107
def __init__(

traincheck/developer/annotations.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import traincheck.config.config as config
12
import traincheck.instrumentor.tracer as tracer
23
from traincheck.config.config import ALL_STAGE_NAMES
34
from traincheck.instrumentor import META_VARS
@@ -16,8 +17,13 @@ def annotate_stage(stage_name: str):
1617
stage_name in ALL_STAGE_NAMES
1718
), f"Invalid stage name: {stage_name}, valid ones are {ALL_STAGE_NAMES}"
1819

20+
old_stage = META_VARS.get("stage", None)
1921
META_VARS["stage"] = stage_name
2022

23+
# We always reset the wrapper when stage changes, and let the policy decide later if we should skip
24+
if old_stage != stage_name:
25+
config.DISABLE_WRAPPER = False
26+
2127

2228
def annotate_answer_start_token_ids(
2329
answer_start_token_id: int, include_start_token: bool = False

traincheck/instrumentor/control.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import logging
2+
3+
from traincheck.config import config
4+
from traincheck.instrumentor.caches import META_VARS
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def start_step():
    """Advance the training step counter and apply the instrumentation policy.

    Intended to be called at the start of every training iteration. When
    ``META_VARS["stage"]`` is set to anything other than ``"training"`` the
    call is a no-op: the step counter is not advanced and
    ``config.DISABLE_WRAPPER`` is left untouched, so that loops running in
    other stages do not consume the training sampling schedule.
    """
    stage = META_VARS.get("stage")
    if stage and stage != "training":
        return

    META_VARS["step"] += 1
    _apply_policy(META_VARS["step"])


def start_eval_step():
    """Advance the evaluation step counter and apply the instrumentation policy.

    Evaluation uses its own counter, ``META_VARS["eval_step"]`` (created on
    first use), so evaluation loops do not disturb the training counter.
    """
    META_VARS["eval_step"] = META_VARS.get("eval_step", 0) + 1
    _apply_policy(META_VARS["eval_step"], label="Eval: ")


def _should_instrument(current_step, warm_up, interval):
    """Return True when ``current_step`` must be instrumented.

    Steps below ``warm_up`` are always instrumented; from ``warm_up`` onwards
    every ``interval``-th step is. A missing or non-positive ``interval``
    (e.g. ``None`` when no sampling interval was configured) is treated as 1,
    i.e. "instrument every step", instead of failing in the modulo.
    """
    if not interval or interval < 1:
        interval = 1
    warm_up = warm_up or 0

    if current_step < warm_up:
        return True
    return (current_step - warm_up) % interval == 0


def _apply_policy(current_step, label=""):
    """Set ``config.DISABLE_WRAPPER`` for ``current_step`` from the policy.

    With an empty/absent ``config.INSTRUMENTATION_POLICY`` instrumentation is
    always enabled.
    """
    policy = config.INSTRUMENTATION_POLICY
    if not policy:
        config.DISABLE_WRAPPER = False
        return

    instrument = _should_instrument(
        current_step, policy.get("warm_up", 0), policy.get("interval", 1)
    )
    # Runs once per iteration: keep it at debug level rather than printing to
    # the user's stdout on every step.
    logger.debug(
        "%s%s step %d",
        label,
        "Instrumenting" if instrument else "Skipping",
        current_step,
    )
    config.DISABLE_WRAPPER = not instrument

traincheck/instrumentor/proxy_wrapper/proxy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010

11+
import traincheck.config.config as config
1112
import traincheck.instrumentor.proxy_wrapper.proxy_config as proxy_config # HACK: cannot directly import config variables as then they would be local variables
1213
import traincheck.instrumentor.proxy_wrapper.proxy_methods as proxy_methods
1314
from traincheck.config.config import should_disable_proxy_dumping
@@ -158,6 +159,9 @@ def __deepcopy__(self, memo):
158159
return new_copy
159160

160161
def dump_trace(self, phase, dump_loc):
162+
if config.DISABLE_WRAPPER:
163+
return
164+
161165
obj = self._obj
162166
var_name = self.__dict__["var_name"]
163167
assert var_name is not None # '' is allowed as a var_name (root object)

traincheck/instrumentor/proxy_wrapper/proxy_observer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import typing
33

4+
import traincheck.config.config as config
45
from traincheck.config.config import should_disable_proxy_dumping
56
from traincheck.instrumentor.proxy_wrapper.subclass import ProxyParameter
67
from traincheck.utils import typename
@@ -21,6 +22,8 @@ def observe_proxy_var(
2122
phase,
2223
observe_api_name: str,
2324
):
25+
if config.DISABLE_WRAPPER:
26+
return
2427

2528
# update the proxy object's timestamp
2629
var.update_timestamp()

traincheck/instrumentor/proxy_wrapper/subclass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from torch import nn
88

9+
import traincheck.config.config as config
910
from traincheck.config.config import should_disable_proxy_dumping
1011
from traincheck.instrumentor.dumper import dump_trace_VAR
1112
from traincheck.instrumentor.proxy_wrapper.dumper import dump_attributes, get_meta_vars
@@ -178,6 +179,9 @@ def register_object(self):
178179
)
179180

180181
def dump_trace(self, phase, dump_loc):
182+
if config.DISABLE_WRAPPER:
183+
return
184+
181185
# TODO
182186
var_name = self.__dict__["var_name"]
183187
# assert var_name is not None # '' is allowed as a var_name (root object)

traincheck/instrumentor/source_file.py

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def __init__(
3333
use_full_instr: bool,
3434
funcs_to_instr: list[str] | None,
3535
API_dump_stack_trace: bool,
36+
sampling_interval: int,
37+
warm_up_steps: int,
3638
):
3739
super().__init__()
3840
if not modules_to_instr:
@@ -44,10 +46,27 @@ def __init__(
4446
self.use_full_instr = use_full_instr
4547
self.funcs_to_instr = funcs_to_instr
4648
self.API_dump_stack_trace = API_dump_stack_trace
49+
self.sampling_interval = sampling_interval
50+
self.warm_up_steps = warm_up_steps
51+
self.current_function = None
52+
53+
def visit_FunctionDef(self, node):
    """Visit a function definition while tracking its name.

    ``self.current_function`` holds the name of the function being traversed
    for the duration of the visit and is restored to its previous value
    afterwards, so nested definitions unwind correctly. Returns ``node``.
    """
    enclosing, self.current_function = self.current_function, node.name
    self.generic_visit(node)
    self.current_function = enclosing
    return node
59+
60+
def visit_AsyncFunctionDef(self, node):
    """Visit an ``async def`` while tracking its name.

    ``self.current_function`` holds the coroutine's name for the duration of
    the visit and is restored to its previous value afterwards.
    Returns ``node``.
    """
    enclosing, self.current_function = self.current_function, node.name
    self.generic_visit(node)
    self.current_function = enclosing
    return node
4766

4867
def get_instrument_node(self, module_name: str):
4968
return ast.parse(
50-
f"from traincheck.instrumentor.tracer import Instrumentor; Instrumentor({module_name}, scan_proxy_in_args={self.scan_proxy_in_args}, use_full_instr={self.use_full_instr}, funcs_to_instr={str(self.funcs_to_instr)}, API_dump_stack_trace={self.API_dump_stack_trace}).instrument()"
69+
f"from traincheck.instrumentor.tracer import Instrumentor; Instrumentor({module_name}, scan_proxy_in_args={self.scan_proxy_in_args}, use_full_instr={self.use_full_instr}, funcs_to_instr={str(self.funcs_to_instr)}, API_dump_stack_trace={self.API_dump_stack_trace}, sampling_interval={str(self.sampling_interval)}, warm_up_steps={str(self.warm_up_steps)}).instrument()"
5170
).body
5271

5372
def visit_Import(self, node):
@@ -65,8 +84,6 @@ def visit_Import(self, node):
6584
instrument_nodes.append(self.get_instrument_node(n.asname))
6685
else:
6786
instrument_nodes.append(self.get_instrument_node(n.name))
68-
# let's see if there are aliases, if yes, use them
69-
# if not, let's use the module name directly
7087
return [node] + instrument_nodes
7188

7289
def visit_ImportFrom(self, node):
@@ -87,6 +104,105 @@ def visit_ImportFrom(self, node):
87104
instrument_nodes.append(self.get_instrument_node(n.name))
88105
return [node] + instrument_nodes
89106

107+
def _get_loop_context(self, node):
    """Classify a loop node as a training loop, an evaluation loop, or neither.

    Heuristic:

    * ``"training"`` -- the loop itself (not a loop nested inside it) contains
      a method call named ``.step()`` or ``.backward()``.
    * ``"eval"`` -- no such call exists anywhere in the loop and the enclosing
      function's name contains ``test``, ``eval`` or ``valid``.
    * ``None`` -- otherwise.

    The training signal is scoped to the loop's own body on purpose: in the
    common ``for epoch: for batch: ... optimizer.step()`` shape only the inner
    loop is the per-iteration loop. Treating the outer loop as "training" as
    well would inject the step hook at both levels and advance the step
    counter once per epoch in addition to once per batch.
    """
    training_attrs = ("step", "backward")
    loop_types = (ast.For, ast.AsyncFor, ast.While)

    direct_signal = False
    nested_signal = False
    # Iterative walk that remembers whether we have crossed into a nested loop.
    pending = [(child, False) for child in ast.iter_child_nodes(node)]
    while pending:
        current, inside_nested_loop = pending.pop()
        if (
            isinstance(current, ast.Call)
            and isinstance(current.func, ast.Attribute)
            and current.func.attr in training_attrs
        ):
            if not inside_nested_loop:
                direct_signal = True
                break
            nested_signal = True
        below_nested = inside_nested_loop or isinstance(current, loop_types)
        pending.extend(
            (child, below_nested) for child in ast.iter_child_nodes(current)
        )

    if direct_signal:
        return "training"
    if nested_signal:
        # Wraps a training loop (e.g. an epoch loop): the inner loop gets the
        # hook, and this one must not be mistaken for an eval loop either.
        return None

    # No training signal at all: fall back to the enclosing function's name.
    if self.current_function:
        name_lower = self.current_function.lower()
        if any(tag in name_lower for tag in ("test", "eval", "valid")):
            return "eval"

    return None
127+
128+
def _inject_call(self, node, func_name):
    """Prepend a call to a control hook to the body of ``node``.

    Two statements are inserted at the top of ``node.body``:
    ``from traincheck.instrumentor.control import <func_name>`` followed by
    ``<func_name>()``. ``node`` is modified in place and returned.
    """
    hook_import = ast.ImportFrom(
        module="traincheck.instrumentor.control",
        names=[ast.alias(name=func_name, asname=None)],
        level=0,
    )
    hook_call = ast.Expr(
        value=ast.Call(
            func=ast.Name(id=func_name, ctx=ast.Load()), args=[], keywords=[]
        )
    )
    # Import first, then the call, ahead of the original statements.
    node.body[:0] = [hook_import, hook_call]
    return node
142+
143+
def visit_For(self, node):
    """Visit a ``for`` loop and inject the matching step hook, if any.

    Children are visited first; the loop is then classified by
    ``_get_loop_context`` and, for ``"training"`` / ``"eval"``, the
    ``start_step`` / ``start_eval_step`` hook is injected at the top of its
    body. Loops with any other classification are returned unchanged.
    """
    self.generic_visit(node)
    hook_by_context = {"training": "start_step", "eval": "start_eval_step"}
    hook = hook_by_context.get(self._get_loop_context(node))
    if hook is None:
        return node
    return self._inject_call(node, hook)
151+
152+
def visit_While(self, node):
    """Visit a ``while`` loop and inject the matching step hook, if any.

    Children are visited first; the loop is then classified by
    ``_get_loop_context`` and, for ``"training"`` / ``"eval"``, the
    ``start_step`` / ``start_eval_step`` hook is injected at the top of its
    body. Loops with any other classification are returned unchanged.
    """
    self.generic_visit(node)
    hook_by_context = {"training": "start_step", "eval": "start_eval_step"}
    hook = hook_by_context.get(self._get_loop_context(node))
    if hook is None:
        return node
    return self._inject_call(node, hook)
160+
161+
def _should_inject_control(self, node):
    """Return True when ``node`` looks like a training loop.

    Heuristic: anywhere beneath ``node`` there is a method-style call whose
    attribute name is ``step`` or ``backward`` (e.g. ``optimizer.step()``,
    ``loss.backward()``).
    """
    training_attrs = ["step", "backward"]
    return any(
        isinstance(sub, ast.Call)
        and isinstance(sub.func, ast.Attribute)
        and sub.func.attr in training_attrs
        for sub in ast.walk(node)
    )
170+
171+
def _inject_start_step(self, node):
    """Prepend ``start_step()`` (and its import) to the body of ``node``.

    The inserted statements are::

        from traincheck.instrumentor.control import start_step
        start_step()

    The import is placed inside the loop body, right before the call, rather
    than at module level: this method only has the loop node at hand, and an
    import is a valid statement anywhere. Re-executing it every iteration is
    slightly wasteful but harmless. ``node`` is modified in place and
    returned.
    """
    hook = "start_step"
    hook_import = ast.ImportFrom(
        module="traincheck.instrumentor.control",
        names=[ast.alias(name=hook, asname=None)],
        level=0,
    )
    hook_call = ast.Expr(
        value=ast.Call(func=ast.Name(id=hook, ctx=ast.Load()), args=[], keywords=[])
    )
    # Import first, then the call, ahead of the original statements.
    node.body[:0] = [hook_import, hook_call]
    return node
205+
90206

91207
def instrument_library(
92208
source: str,
@@ -95,6 +211,8 @@ def instrument_library(
95211
use_full_instr: bool,
96212
funcs_to_instr: list[str] | None,
97213
API_dump_stack_trace: bool,
214+
sampling_interval: int,
215+
warm_up_steps: int,
98216
) -> str:
99217
"""
100218
Instruments the given source code and returns the instrumented source code.
@@ -116,6 +234,8 @@ def instrument_library(
116234
use_full_instr,
117235
funcs_to_instr,
118236
API_dump_stack_trace,
237+
sampling_interval,
238+
warm_up_steps,
119239
)
120240
root = visitor.visit(root)
121241
source = ast.unparse(root)
@@ -811,6 +931,8 @@ def instrument_file(
811931
instr_descriptors: bool,
812932
no_auto_var_instr: bool,
813933
use_torch_compile: bool,
934+
sampling_interval: int = 1,
935+
warm_up_steps: int = 0,
814936
) -> str:
815937
"""
816938
Instruments the given file and returns the instrumented source code.
@@ -827,6 +949,8 @@ def instrument_file(
827949
use_full_instr,
828950
funcs_to_instr,
829951
API_dump_stack_trace,
952+
sampling_interval,
953+
warm_up_steps,
830954
)
831955
# annotate stages
832956
instrumented_source = annotate_stage(instrumented_source)

0 commit comments

Comments
 (0)