1- import logging
2- import os
31import time
42from dataclasses import dataclass
53
119from infinilm .lib import _infinilm
1210
1311from .modeling_utils import parse_dtype
14-
15-
16- logger = logging .getLogger (__name__ )
17-
18- def _iter_exception_chain (e : BaseException , * , max_depth : int = 6 ):
19- cur : BaseException | None = e
20- depth = 0
21- seen : set [int ] = set ()
22- while cur is not None and depth < max_depth :
23- cur_id = id (cur )
24- if cur_id in seen :
25- break
26- seen .add (cur_id )
27- yield cur
28- depth += 1
29- cur = cur .__cause__ or cur .__context__
30-
31-
def is_oom_exception(e: BaseException) -> bool:
    """
    Conservative OOM detector for MetaX allocator failures and CUDA/PyTorch OOMs.

    Checks the exception type (when PyTorch is loaded) and message substrings
    across the chained exceptions produced by ``_iter_exception_chain``.

    Args:
        e: Exception to classify; exceptions chained to it are inspected too.

    Returns:
        True if any inspected exception is a PyTorch out-of-memory error or
        has a message containing one of the known allocator-failure patterns
        (matched case-insensitively), otherwise False.
    """
    # Function-scope import so the block does not depend on the file header.
    import sys

    chain = list(_iter_exception_chain(e))

    # An exception can only be an instance of a torch class if torch has
    # already been imported, so look it up instead of importing it here:
    # this avoids loading a heavyweight package inside the failure path.
    torch_mod = sys.modules.get("torch")
    if torch_mod is not None:
        # Look in both places; older PyTorch releases are believed to expose
        # the class only as torch.cuda.OutOfMemoryError -- TODO confirm.
        candidates = (
            getattr(torch_mod, "OutOfMemoryError", None),
            getattr(getattr(torch_mod, "cuda", None), "OutOfMemoryError", None),
        )
        oom_types = tuple(t for t in candidates if isinstance(t, type))
        if oom_types and any(isinstance(ex, oom_types) for ex in chain):
            return True

    # Common patterns observed for allocator failures.
    # Keep this allowlist small to avoid hard-exiting on unrelated errors.
    patterns = (
        # MetaX / infinirt allocator
        "hcmalloc",
        "infinirtmalloc",
        "out of memory",
        # CUDA / driver / runtime alloc failures
        "cuda out of memory",
        "cumemalloc",
        "cublas_status_alloc_failed",
        "cudnn_status_alloc_failed",
    )

    for ex in chain:
        msg = str(ex)
        if not msg:
            continue
        # Patterns are lowercase, so lowercase the message before matching.
        msg_l = msg.lower()
        if any(p in msg_l for p in patterns):
            return True
    return False
12+ from .exception_utils import handle_oom_and_exit
7113
7214
7315@dataclass
@@ -105,9 +47,11 @@ def __init__(
10547 cache_config ,
10648 enable_graph_compiling ,
10749 attention_backend ,
108- parse_dtype (kv_cache_dtype )._underlying
109- if kv_cache_dtype is not None
110- else None ,
50+ (
51+ parse_dtype (kv_cache_dtype )._underlying
52+ if kv_cache_dtype is not None
53+ else None
54+ ),
11155 )
11256 self .use_cache = False
11357
@@ -134,7 +78,9 @@ def forward(
13478 try :
13579 # TODO: Remove `_underlying` and simplify the corresponding code.
13680 input_ids = input_ids ._underlying if input_ids is not None else None
137- position_ids = position_ids ._underlying if position_ids is not None else None
81+ position_ids = (
82+ position_ids ._underlying if position_ids is not None else None
83+ )
13884 past_kv_lengths = (
13985 past_kv_lengths ._underlying if past_kv_lengths is not None else None
14086 )
@@ -172,13 +118,7 @@ def forward(
172118 .output_ids
173119 )
174120 except BaseException as e :
175- if is_oom_exception (e ):
176- logger .error (
177- "OOM-like exception: exiting worker with code 137: %r" ,
178- e ,
179- exc_info = False ,
180- )
181- os ._exit (137 )
121+ handle_oom_and_exit (e )
182122 raise
183123
184124 def generate (
0 commit comments