Skip to content

Commit 9802e2e

Browse files
committed
issue/282 - restructured oom exception handling
1 parent ef68739 commit 9802e2e

2 files changed

Lines changed: 85 additions & 70 deletions

File tree

python/infinilm/exception_utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import os
2+
import logging
3+
from typing import Iterator
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
def _iter_exception_chain(
9+
e: BaseException, *, max_depth: int = 6
10+
) -> Iterator[BaseException]:
11+
"""Iterate through exception chain with depth limit."""
12+
cur: BaseException | None = e
13+
depth = 0
14+
seen: set[int] = set()
15+
while cur is not None and depth < max_depth:
16+
cur_id = id(cur)
17+
if cur_id in seen:
18+
break
19+
seen.add(cur_id)
20+
yield cur
21+
depth += 1
22+
cur = cur.__cause__ or cur.__context__
23+
24+
25+
def is_oom_exception(e: BaseException) -> bool:
    """
    Conservative OOM detector for MetaX allocator failures and CUDA/PyTorch OOMs.
    Checks exception type (when available) and message substrings across chained exceptions.
    """

    def walk(root: BaseException):
        # Follow __cause__/__context__ with a depth limit and a cycle guard
        # (same traversal as _iter_exception_chain, inlined so this check
        # stands alone).
        seen_ids: set[int] = set()
        node: BaseException | None = root
        hops = 0
        while node is not None and hops < 6 and id(node) not in seen_ids:
            seen_ids.add(id(node))
            yield node
            hops += 1
            node = node.__cause__ or node.__context__

    # Type-based check: torch.OutOfMemoryError, only when torch is importable
    # in this environment. Any failure here is treated as "torch absent" —
    # this detector must never raise.
    try:
        import torch  # type: ignore

        torch_oom = getattr(torch, "OutOfMemoryError", None)
        if torch_oom is not None and any(
            isinstance(ex, torch_oom) for ex in walk(e)
        ):
            return True
    except Exception:
        pass

    # Common patterns observed for allocator failures.
    # Keep this allowlist small to avoid hard-exiting on unrelated errors.
    patterns = (
        # MetaX / infinirt allocator
        "hcmalloc",
        "infinirtmalloc",
        "out of memory",
        # CUDA / driver / runtime alloc failures
        "cuda out of memory",
        "cumemalloc",
        "cublas_status_alloc_failed",
        "cudnn_status_alloc_failed",
    )

    # Message-based check across the whole chain (case-insensitive substring).
    return any(
        needle in str(ex).lower() for ex in walk(e) for needle in patterns
    )
64+
65+
66+
def handle_oom_and_exit(e: BaseException, exit_code: int = 137) -> None:
    """Hard-exit the process when *e* looks like an out-of-memory failure.

    If :func:`is_oom_exception` classifies *e* as OOM, log it and terminate
    immediately via ``os._exit`` — bypassing normal interpreter cleanup,
    which may itself be unreliable under memory pressure. Otherwise this is
    a no-op and the caller is expected to re-raise. The default exit code
    137 (128 + SIGKILL) matches the status supervisors commonly associate
    with OOM-killed workers.
    """
    if not is_oom_exception(e):
        return
    logger.error(
        "OOM-like exception: exiting worker with code %d: %r",
        exit_code,
        e,
        exc_info=False,
    )
    os._exit(exit_code)

python/infinilm/infer_engine.py

Lines changed: 10 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import logging
2-
import os
31
import time
42
from dataclasses import dataclass
53

@@ -11,63 +9,7 @@
119
from infinilm.lib import _infinilm
1210

1311
from .modeling_utils import parse_dtype
14-
15-
16-
logger = logging.getLogger(__name__)
17-
18-
def _iter_exception_chain(e: BaseException, *, max_depth: int = 6):
19-
cur: BaseException | None = e
20-
depth = 0
21-
seen: set[int] = set()
22-
while cur is not None and depth < max_depth:
23-
cur_id = id(cur)
24-
if cur_id in seen:
25-
break
26-
seen.add(cur_id)
27-
yield cur
28-
depth += 1
29-
cur = cur.__cause__ or cur.__context__
30-
31-
32-
def is_oom_exception(e: BaseException) -> bool:
33-
"""
34-
Conservative OOM detector for MetaX allocator failures and CUDA/PyTorch OOMs.
35-
Checks exception type (when available) and message substrings across chained exceptions.
36-
"""
37-
# PyTorch OOM exception type (only if torch is present in this environment)
38-
try:
39-
import torch # type: ignore
40-
41-
oom_type = getattr(torch, "OutOfMemoryError", None)
42-
if oom_type is not None:
43-
for ex in _iter_exception_chain(e):
44-
if isinstance(ex, oom_type):
45-
return True
46-
except Exception:
47-
pass
48-
49-
# Common patterns observed for allocator failures.
50-
# Keep this allowlist small to avoid hard-exiting on unrelated errors.
51-
patterns = (
52-
# MetaX / infinirt allocator
53-
"hcmalloc",
54-
"infinirtmalloc",
55-
"out of memory",
56-
# CUDA / driver / runtime alloc failures
57-
"cuda out of memory",
58-
"cumemalloc",
59-
"cublas_status_alloc_failed",
60-
"cudnn_status_alloc_failed",
61-
)
62-
63-
for ex in _iter_exception_chain(e):
64-
msg = str(ex)
65-
if not msg:
66-
continue
67-
msg_l = msg.lower()
68-
if any(p in msg_l for p in patterns):
69-
return True
70-
return False
12+
from .exception_utils import handle_oom_and_exit
7113

7214

7315
@dataclass
@@ -105,9 +47,11 @@ def __init__(
10547
cache_config,
10648
enable_graph_compiling,
10749
attention_backend,
108-
parse_dtype(kv_cache_dtype)._underlying
109-
if kv_cache_dtype is not None
110-
else None,
50+
(
51+
parse_dtype(kv_cache_dtype)._underlying
52+
if kv_cache_dtype is not None
53+
else None
54+
),
11155
)
11256
self.use_cache = False
11357

@@ -134,7 +78,9 @@ def forward(
13478
try:
13579
# TODO: Remove `_underlying` and simplify the corresponding code.
13680
input_ids = input_ids._underlying if input_ids is not None else None
137-
position_ids = position_ids._underlying if position_ids is not None else None
81+
position_ids = (
82+
position_ids._underlying if position_ids is not None else None
83+
)
13884
past_kv_lengths = (
13985
past_kv_lengths._underlying if past_kv_lengths is not None else None
14086
)
@@ -172,13 +118,7 @@ def forward(
172118
.output_ids
173119
)
174120
except BaseException as e:
175-
if is_oom_exception(e):
176-
logger.error(
177-
"OOM-like exception: exiting worker with code 137: %r",
178-
e,
179-
exc_info=False,
180-
)
181-
os._exit(137)
121+
handle_oom_and_exit(e)
182122
raise
183123

184124
def generate(

0 commit comments

Comments
 (0)