# Two directory levels up from this file; the paths below hang off it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Directory globbed for ``bench_*.py`` files when no other one is supplied.
BENCH_DIR = PROJECT_ROOT / "benchmarks"
# Default value of the ``-o/--output`` JSON path.
DEFAULT_OUTPUT = PROJECT_ROOT / "results-python.json"
# Default prefix of the module name under which a benchmark file is imported
# (the file stem is appended to it).
DEFAULT_MODULE_NAME_PREFIX = "cuda_bindings_bench"
# Env var used to propagate the --benchmark filter from the parent to pyperf
# worker subprocesses. pyperf reconstructs worker argv from scratch and drops
# custom flags like --benchmark, so without this the worker would register the
# full bench list and pyperf would run the wrong bench by task index.
DEFAULT_BENCH_FILTER_ENV_VAR = "CUDA_BINDINGS_BENCH_FILTER"

# Environment variables forwarded to pyperf workers whenever they are set in
# the parent process. The benchmark-filter variable is intentionally not part
# of this tuple: ``ensure_pyperf_worker_env`` appends it through its
# ``extra_env_vars`` parameter so that its name can be overridden.
BASE_PYPERF_INHERITED_ENV_VARS = (
    "CUDA_HOME",
    "CUDA_PATH",
    "CUDA_VISIBLE_DEVICES",
    "LD_LIBRARY_PATH",
    "NVIDIA_VISIBLE_DEVICES",
)
# Cache consulted by ``load_module``; keys are resolved module file paths.
_MODULE_CACHE: dict[Path, ModuleType] = {}
3434
3535
36- def load_module (module_path : Path ) -> ModuleType :
36+ def load_module (module_path : Path , module_name_prefix : str = DEFAULT_MODULE_NAME_PREFIX ) -> ModuleType :
3737 module_path = module_path .resolve ()
3838 cached_module = _MODULE_CACHE .get (module_path )
3939 if cached_module is not None :
4040 return cached_module
4141
42- module_name = f"cuda_bindings_bench_ { module_path .stem } "
42+ module_name = f"{ module_name_prefix } _ { module_path .stem } "
4343 spec = importlib .util .spec_from_file_location (module_name , module_path )
4444 if spec is None or spec .loader is None :
4545 raise RuntimeError (f"Failed to load benchmark module: { module_path } " )
@@ -64,13 +64,17 @@ def _discover_module_functions(module_path: Path) -> list[str]:
6464 ]
6565
6666
67- def _lazy_benchmark (module_path : Path , function_name : str ) -> Callable [[int ], float ]:
67+ def _lazy_benchmark (
68+ module_path : Path ,
69+ function_name : str ,
70+ module_name_prefix : str = DEFAULT_MODULE_NAME_PREFIX ,
71+ ) -> Callable [[int ], float ]:
6872 loaded_function : Callable [[int ], float ] | None = None
6973
7074 def run (loops : int ) -> float :
7175 nonlocal loaded_function
7276 if loaded_function is None :
73- module = load_module (module_path )
77+ module = load_module (module_path , module_name_prefix = module_name_prefix )
7478 loaded_function = getattr (module , function_name )
7579 return loaded_function (loops )
7680
@@ -86,6 +90,7 @@ def run(loops: int) -> float:
8690def _collect_skipped_benchmarks (
8791 bench_ids : list [str ],
8892 registry : dict [str , Callable [[int ], float ]],
93+ module_name_prefix : str = DEFAULT_MODULE_NAME_PREFIX ,
8994) -> set [str ]:
9095 """Return bench IDs that the owning module has marked as unsupported.
9196
@@ -106,29 +111,37 @@ def _collect_skipped_benchmarks(
106111 continue
107112 module = loaded_modules .get (module_path )
108113 if module is None :
109- module = load_module (module_path )
114+ module = load_module (module_path , module_name_prefix = module_name_prefix )
110115 loaded_modules [module_path ] = module
111116 module_skip = getattr (module , "SKIPPED_BENCHMARKS" , None )
112117 if module_skip and function_name in module_skip :
113118 skipped .add (bench_id )
114119 return skipped
115120
116121
def discover_benchmarks(
    bench_dir: Path | None = None,
    module_name_prefix: str = DEFAULT_MODULE_NAME_PREFIX,
) -> dict[str, Callable[[int], float]]:
    """Discover bench_ functions.

    Each bench_ function must have the signature: bench_*(loops: int) -> float
    where it calls the operation `loops` times and returns the total elapsed
    time in seconds (using time.perf_counter).

    Args:
        bench_dir: Directory whose ``bench_*.py`` files are scanned
            (non-recursively). ``None`` selects the module-level
            ``BENCH_DIR`` as it is at call time.
        module_name_prefix: Forwarded to ``_lazy_benchmark`` for every
            discovered function.

    Returns:
        A mapping from benchmark ID to the callable built by
        ``_lazy_benchmark`` for that function. Files are visited in sorted
        path order.

    Raises:
        ValueError: If two discovered functions produce the same benchmark ID.
    """
    # Resolve the default inside the call so tests (and embedders) can
    # monkeypatch ``BENCH_DIR`` at the module level — Python binds default
    # args at def-time, so a ``bench_dir=BENCH_DIR`` default would ignore
    # later patches.
    if bench_dir is None:
        bench_dir = BENCH_DIR

    registry: dict[str, Callable[[int], float]] = {}
    for module_path in sorted(bench_dir.glob("bench_*.py")):
        module_name = module_path.stem
        for function_name in _discover_module_functions(module_path):
            bench_id = benchmark_id(module_name, function_name)
            if bench_id in registry:
                raise ValueError(f"Duplicate benchmark ID discovered: {bench_id}")
            registry[bench_id] = _lazy_benchmark(
                module_path,
                function_name,
                module_name_prefix=module_name_prefix,
            )
    return registry
133146
134147
@@ -152,7 +165,10 @@ def _split_env_vars(arg_value: str) -> list[str]:
152165 return [env_var for env_var in arg_value .split ("," ) if env_var ]
153166
154167
155- def ensure_pyperf_worker_env (argv : list [str ]) -> list [str ]:
168+ def ensure_pyperf_worker_env (
169+ argv : list [str ],
170+ extra_env_vars : tuple [str , ...] = (DEFAULT_BENCH_FILTER_ENV_VAR ,),
171+ ) -> list [str ]:
156172 if "--copy-env" in argv :
157173 return list (argv )
158174
@@ -175,7 +191,7 @@ def ensure_pyperf_worker_env(argv: list[str]) -> list[str]:
175191 if skip_next :
176192 raise ValueError ("Missing value for --inherit-environ" )
177193
178- for env_var in PYPERF_INHERITED_ENV_VARS :
194+ for env_var in ( * BASE_PYPERF_INHERITED_ENV_VARS , * extra_env_vars ) :
179195 if env_var in os .environ :
180196 inherited_env .append (env_var )
181197
@@ -190,7 +206,7 @@ def ensure_pyperf_worker_env(argv: list[str]) -> list[str]:
190206 return cleaned
191207
192208
193- def parse_args (argv : list [str ]) -> tuple [argparse .Namespace , list [str ]]:
209+ def parse_args (argv : list [str ], default_output : Path = DEFAULT_OUTPUT ) -> tuple [argparse .Namespace , list [str ]]:
194210 parser = argparse .ArgumentParser (add_help = False )
195211 parser .add_argument (
196212 "--benchmark" ,
@@ -207,19 +223,25 @@ def parse_args(argv: list[str]) -> tuple[argparse.Namespace, list[str]]:
207223 "-o" ,
208224 "--output" ,
209225 type = Path ,
210- default = DEFAULT_OUTPUT ,
211- help = f"JSON output file path (default: { DEFAULT_OUTPUT .name } )" ,
226+ default = default_output ,
227+ help = f"JSON output file path (default: { default_output .name } )" ,
212228 )
213229 parsed , remaining = parser .parse_known_args (argv )
214230 return parsed , remaining
215231
216232
217- def main () -> None :
218- parsed , remaining_argv = parse_args (sys .argv [1 :])
233+ def main (
234+ * ,
235+ bench_dir : Path = BENCH_DIR ,
236+ default_output : Path = DEFAULT_OUTPUT ,
237+ module_name_prefix : str = DEFAULT_MODULE_NAME_PREFIX ,
238+ bench_filter_env_var : str = DEFAULT_BENCH_FILTER_ENV_VAR ,
239+ ) -> None :
240+ parsed , remaining_argv = parse_args (sys .argv [1 :], default_output = default_output )
219241
220- registry = discover_benchmarks ()
242+ registry = discover_benchmarks (bench_dir = bench_dir , module_name_prefix = module_name_prefix )
221243 if not registry :
222- raise RuntimeError (f"No benchmark functions found in { BENCH_DIR } " )
244+ raise RuntimeError (f"No benchmark functions found in { bench_dir } " )
223245
224246 if parsed .list :
225247 for bench_id in sorted (registry ):
@@ -231,7 +253,7 @@ def main() -> None:
231253 # the wrong bench. pyperf drops unknown CLI flags when spawning workers,
232254 # so fall back to an env var carrying the filter.
233255 requested = list (parsed .benchmark )
234- env_filter = os .environ .get (BENCH_FILTER_ENV_VAR , "" )
256+ env_filter = os .environ .get (bench_filter_env_var , "" )
235257 if not requested and env_filter :
236258 requested = [bid for bid in env_filter .split ("," ) if bid ]
237259
@@ -243,21 +265,21 @@ def main() -> None:
243265 raise ValueError (f"Unknown benchmark(s): { unknown } . Known benchmarks: { known } " )
244266 benchmark_ids = requested
245267 # Propagate to any pyperf worker we're about to spawn.
246- os .environ [BENCH_FILTER_ENV_VAR ] = "," .join (benchmark_ids )
268+ os .environ [bench_filter_env_var ] = "," .join (benchmark_ids )
247269 else :
248270 benchmark_ids = sorted (registry )
249271
250272 # Strip any --output args to avoid conflicts with our output handling.
251273 output_path = parsed .output .resolve ()
252274 remaining_argv = strip_pyperf_output_args (remaining_argv )
253- remaining_argv = ensure_pyperf_worker_env (remaining_argv )
275+ remaining_argv = ensure_pyperf_worker_env (remaining_argv , extra_env_vars = ( bench_filter_env_var ,) )
254276 is_worker = "--worker" in remaining_argv
255277
256278 # Drop benchmarks that the owning module has marked as unavailable on
257279 # this driver/device. Without this step a single unsupported bench
258280 # (e.g. TMA on a pre-Hopper GPU) would abort the whole pyperf run,
259281 # since pyperf treats a raised exception as a fatal worker failure.
260- skipped = _collect_skipped_benchmarks (benchmark_ids , registry )
282+ skipped = _collect_skipped_benchmarks (benchmark_ids , registry , module_name_prefix = module_name_prefix )
261283 if skipped and not is_worker :
262284 for bench_id in sorted (skipped ):
263285 print (f"Skipping { bench_id } : unsupported on this driver/device" , file = sys .stderr )
0 commit comments