Skip to content

Commit 394081e

Browse files
feat(closes OPEN-10341): add native async runner for testset batches
1 parent 52a313b commit 394081e

1 file changed

Lines changed: 136 additions & 42 deletions

File tree

src/openlayer/lib/core/base_model.py

Lines changed: 136 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import abc
55
import json
66
import time
7+
import asyncio
78
import inspect
89
import argparse
9-
from typing import Any, Dict, Tuple
10+
from typing import Any, Dict, List, Tuple, Optional
1011
from dataclasses import field, dataclass
1112

1213
import pandas as pd
@@ -36,6 +37,12 @@ class OpenlayerModel(abc.ABC):
3637
3738
It is more conventional to implement the `run` method.
3839
40+
``run`` may be defined as either ``def run`` (called sequentially per row)
41+
or ``async def run``. When ``run`` is async, ``run_batch_from_df`` will drive
42+
rows concurrently with ``asyncio.gather``; pass ``max_workers > 1`` to enable
43+
concurrent execution. Use async-native I/O (``httpx``, ``openai-async``, etc.)
44+
inside an async ``run`` to actually benefit from concurrency.
45+
3946
Refer to Openlayer's templates for examples of how to implement this class.
4047
"""
4148

@@ -59,6 +66,15 @@ def run_from_cli(self) -> None:
5966
required=False,
6067
help="Custom arguments in format 'key1=value1,key2=value2'",
6168
)
69+
parser.add_argument(
70+
"--max-workers",
71+
type=int,
72+
default=None,
73+
help=(
74+
"Max concurrent rows when run() is async. "
75+
"Defaults to 4 for async run, 1 for sync run."
76+
),
77+
)
6278

6379
# Parse the arguments
6480
args = parser.parse_args()
@@ -76,9 +92,12 @@ def run_from_cli(self) -> None:
7692
return self.batch(
7793
dataset_path=args.dataset_path,
7894
output_dir=args.output_dir,
95+
max_workers=args.max_workers,
7996
)
8097

81-
def batch(self, dataset_path: str, output_dir: str) -> None:
98+
def batch(
99+
self, dataset_path: str, output_dir: str, max_workers: Optional[int] = None
100+
) -> None:
82101
"""Reads the dataset from a file and runs the model on it."""
83102
# Load the dataset into a pandas DataFrame
84103
fmt = "csv"
@@ -91,50 +110,125 @@ def batch(self, dataset_path: str, output_dir: str) -> None:
91110
raise ValueError(f"Unsupported dataset format: {dataset_path}")
92111

93112
# Call the model's run_batch method, passing in the DataFrame
94-
output_df, config = self.run_batch_from_df(df)
113+
output_df, config = self.run_batch_from_df(df, max_workers=max_workers)
95114
self.write_output_to_directory(output_df, config, output_dir, fmt)
96115

97-
def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
98-
"""Function that runs the model and returns the result."""
99-
# Ensure the 'output' column exists
100-
if "output" not in df.columns:
101-
df["output"] = None
116+
def run_batch_from_df(
117+
self, df: pd.DataFrame, max_workers: Optional[int] = None
118+
) -> Tuple[pd.DataFrame, dict]:
119+
"""Function that runs the model and returns the result.
102120
103-
# Get the signature of the 'run' method
121+
If ``run`` is defined as ``async def run(...)``, rows are dispatched
122+
concurrently with ``asyncio.gather`` gated by ``asyncio.Semaphore(max_workers)``.
123+
``max_workers`` defaults to 4 for an async ``run`` (writing `async def`
124+
is the opt-in signal that interleaving is safe). For a synchronous
125+
``run``, rows are processed sequentially and ``max_workers`` must be 1.
126+
127+
A row's exception propagates and aborts the batch. For the async path,
128+
``asyncio.gather`` cancels in-flight siblings before re-raising.
129+
"""
104130
run_signature = inspect.signature(self.run)
131+
valid_params = set(run_signature.parameters)
132+
is_async = inspect.iscoroutinefunction(self.run)
133+
134+
if max_workers is None:
135+
max_workers = 4 if is_async else 1
136+
elif max_workers < 1:
137+
raise ValueError("max_workers must be >= 1")
138+
139+
if max_workers > 1 and not is_async:
140+
raise ValueError(
141+
"max_workers > 1 requires an async `run` method. "
142+
"Define `run` as `async def run(self, ...)` to enable "
143+
"concurrent execution."
144+
)
145+
146+
for col in ("output", "steps", "latency", "cost", "tokens", "context"):
147+
if col not in df.columns:
148+
df[col] = None
149+
150+
rows = [
151+
(
152+
idx,
153+
{k: v for k, v in row.to_dict().items() if k in valid_params},
154+
)
155+
for idx, row in df.iterrows()
156+
]
157+
158+
if is_async:
159+
try:
160+
asyncio.get_running_loop()
161+
except RuntimeError:
162+
pass
163+
else:
164+
raise RuntimeError(
165+
"run_batch_from_df was called from inside a running event "
166+
"loop. Call `await self._run_rows_async(...)` directly "
167+
"from async code."
168+
)
169+
results = asyncio.run(self._run_rows_async(rows, max_workers))
170+
else:
171+
results = [
172+
(idx, self.run(**kwargs), tracer.get_current_trace())
173+
for idx, kwargs in rows
174+
]
175+
176+
for index, output, trace in results:
177+
self._apply_row_result(df, index, output, trace)
105178

106-
for index, row in df.iterrows():
107-
# Filter row_dict to only include keys that are valid parameters
108-
# for the 'run' method
109-
row_dict = row.to_dict()
110-
filtered_kwargs = {
111-
k: v for k, v in row_dict.items() if k in run_signature.parameters
112-
}
113-
114-
# Call the run method with filtered kwargs
115-
output = self.run(**filtered_kwargs)
116-
117-
df.at[index, "output"] = output.output
118-
119-
for k, v in output.other_fields.items():
120-
if k not in df.columns:
121-
df[k] = None
122-
df.at[index, k] = v
123-
124-
trace = tracer.get_current_trace()
125-
if trace:
126-
processed_trace, _ = tracer.post_process_trace(trace_obj=trace)
127-
df.at[index, "steps"] = trace.to_dict()
128-
if "latency" in processed_trace:
129-
df.at[index, "latency"] = processed_trace["latency"]
130-
if "cost" in processed_trace:
131-
df.at[index, "cost"] = processed_trace["cost"]
132-
if "tokens" in processed_trace:
133-
df.at[index, "tokens"] = processed_trace["tokens"]
134-
if "context" in processed_trace:
135-
df.at[index, "context"] = processed_trace["context"]
136-
137-
config = {
179+
return df, self._build_config(run_signature, df)
180+
181+
async def _run_rows_async(
182+
self,
183+
rows: List[Tuple[Any, Dict[str, Any]]],
184+
max_workers: int,
185+
) -> List[Tuple[Any, RunReturn, Optional[Any]]]:
186+
"""Drive an async ``run`` over all rows with bounded concurrency.
187+
188+
The first row to raise causes ``asyncio.gather`` to cancel in-flight
189+
siblings and re-raise the original exception.
190+
"""
191+
sem = asyncio.Semaphore(max_workers)
192+
193+
async def _one(index: Any, kwargs: Dict[str, Any]):
194+
async with sem:
195+
output = await self.run(**kwargs)
196+
return index, output, tracer.get_current_trace()
197+
198+
return await asyncio.gather(*(_one(i, k) for i, k in rows))
199+
200+
def _apply_row_result(
201+
self,
202+
df: pd.DataFrame,
203+
index: Any,
204+
output: RunReturn,
205+
trace: Optional[Any],
206+
) -> None:
207+
"""Write a single row's output and trace fields into ``df`` in place."""
208+
df.at[index, "output"] = output.output
209+
210+
for k, v in output.other_fields.items():
211+
if k not in df.columns:
212+
df[k] = None
213+
df.at[index, k] = v
214+
215+
if trace:
216+
processed_trace, _ = tracer.post_process_trace(trace_obj=trace)
217+
df.at[index, "steps"] = trace.to_dict()
218+
if "latency" in processed_trace:
219+
df.at[index, "latency"] = processed_trace["latency"]
220+
if "cost" in processed_trace:
221+
df.at[index, "cost"] = processed_trace["cost"]
222+
if "tokens" in processed_trace:
223+
df.at[index, "tokens"] = processed_trace["tokens"]
224+
if "context" in processed_trace:
225+
df.at[index, "context"] = processed_trace["context"]
226+
227+
def _build_config(
228+
self, run_signature: inspect.Signature, df: pd.DataFrame
229+
) -> Dict[str, Any]:
230+
"""Build the config dict returned alongside the output DataFrame."""
231+
config: Dict[str, Any] = {
138232
"outputColumnName": "output",
139233
"inputVariableNames": list(run_signature.parameters.keys()),
140234
"metadata": {
@@ -154,7 +248,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
154248
for k, v in self.custom_args.items():
155249
config["metadata"][k] = v
156250

157-
return df, config
251+
return config
158252

159253
def write_output_to_directory(
160254
self,

0 commit comments

Comments
 (0)