From 3d7cef5eaf3417976b130d382224a3346d81e9a2 Mon Sep 17 00:00:00 2001 From: yuj Date: Fri, 24 Apr 2026 16:33:12 +0800 Subject: [PATCH] fix: correct logic bug, f-string typo, and bare except clauses - fix model_type comparison: `== ['MNLI', 'QQP']` always evaluates to False because it compares a string to a list; use `in` instead - fix f-string format: `f{out_pr}` prints a stray literal 'f' before the interpolated value; remove the stray 'f' - replace bare `except:` with `except Exception:` in 3 files to avoid catching SystemExit and KeyboardInterrupt Co-Authored-By: Claude Opus 4.7 --- megatron/fused_kernels/tests/test_fused_kernels.py | 2 +- megatron/initialize.py | 2 +- megatron/testing_utils.py | 4 ++-- tools/merge_mp_partitions.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/megatron/fused_kernels/tests/test_fused_kernels.py b/megatron/fused_kernels/tests/test_fused_kernels.py index f8d5027a1..693d7245a 100644 --- a/megatron/fused_kernels/tests/test_fused_kernels.py +++ b/megatron/fused_kernels/tests/test_fused_kernels.py @@ -290,7 +290,7 @@ def test_layer_norm(): transformers.logging.FATAL, ) - except: + except Exception: print("\n[Fail] Please install `transformers` package to test fused kernels\n") exit(-1) diff --git a/megatron/initialize.py b/megatron/initialize.py index b967a228a..346347606 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -121,7 +121,7 @@ def set_verbosity_transformers(logging_level: str): from transformers.utils import logging as transformers_logging log_level = logging.log_levels[logging_level] logging.set_verbosity(log_level) - except: + except Exception: pass if args.rank == 0: diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index 9521cb361..bcf53c5a1 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -37,7 +37,7 @@ try: import torch _torch_available = True -except: +except Exception: _torch_available = False def is_torch_available(): @@ -312,7 
+312,7 @@ def apply_print_resets(buf): def assert_screenout(out, what): out_pr = apply_print_resets(out).lower() match_str = out_pr.find(what.lower()) - assert match_str != -1, f"expecting to find {what} in output: f{out_pr}" + assert match_str != -1, f"expecting to find {what} in output: {out_pr}" class CaptureStd: diff --git a/tools/merge_mp_partitions.py b/tools/merge_mp_partitions.py index 4dc2d99f8..ca78842fc 100644 --- a/tools/merge_mp_partitions.py +++ b/tools/merge_mp_partitions.py @@ -115,7 +115,7 @@ def get_model(model_type): from pretrain_gpt import model_provider elif model_type == 'RACE': from tasks.race.finetune import model_provider - elif model_type == ['MNLI', 'QQP']: + elif model_type in ['MNLI', 'QQP']: num_classes = 2 if model_type == 'MNLI': num_classes = 3