44import abc
55import json
66import time
7+ import asyncio
78import inspect
89import argparse
9- from typing import Any , Dict , Tuple
10+ from typing import Any , Dict , List , Tuple , Optional
1011from dataclasses import field , dataclass
1112
1213import pandas as pd
@@ -36,6 +37,12 @@ class OpenlayerModel(abc.ABC):
3637
3738 It is more conventional to implement the `run` method.
3839
40+ ``run`` may be defined as either ``def run`` (called sequentially per row)
41+ or ``async def run``. When ``run`` is async, ``run_batch_from_df`` will drive
42+ rows concurrently with ``asyncio.gather``; pass ``max_workers > 1`` to enable
43+ concurrent execution. Use async-native I/O (``httpx``, ``openai-async``, etc.)
44+ inside an async ``run`` to actually benefit from concurrency.
45+
3946 Refer to Openlayer's templates for examples of how to implement this class.
4047 """
4148
@@ -59,6 +66,15 @@ def run_from_cli(self) -> None:
5966 required = False ,
6067 help = "Custom arguments in format 'key1=value1,key2=value2'" ,
6168 )
69+ parser .add_argument (
70+ "--max-workers" ,
71+ type = int ,
72+ default = None ,
73+ help = (
74+ "Max concurrent rows when run() is async. "
75+ "Defaults to 4 for async run, 1 for sync run."
76+ ),
77+ )
6278
6379 # Parse the arguments
6480 args = parser .parse_args ()
@@ -76,9 +92,12 @@ def run_from_cli(self) -> None:
7692 return self .batch (
7793 dataset_path = args .dataset_path ,
7894 output_dir = args .output_dir ,
95+ max_workers = args .max_workers ,
7996 )
8097
81- def batch (self , dataset_path : str , output_dir : str ) -> None :
98+ def batch (
99+ self , dataset_path : str , output_dir : str , max_workers : Optional [int ] = None
100+ ) -> None :
82101 """Reads the dataset from a file and runs the model on it."""
83102 # Load the dataset into a pandas DataFrame
84103 fmt = "csv"
@@ -91,50 +110,125 @@ def batch(self, dataset_path: str, output_dir: str) -> None:
91110 raise ValueError (f"Unsupported dataset format: { dataset_path } " )
92111
93112 # Call the model's run_batch method, passing in the DataFrame
94- output_df , config = self .run_batch_from_df (df )
113+ output_df , config = self .run_batch_from_df (df , max_workers = max_workers )
95114 self .write_output_to_directory (output_df , config , output_dir , fmt )
96115
97- def run_batch_from_df (self , df : pd .DataFrame ) -> Tuple [pd .DataFrame , dict ]:
98- """Function that runs the model and returns the result."""
99- # Ensure the 'output' column exists
100- if "output" not in df .columns :
101- df ["output" ] = None
116+ def run_batch_from_df (
117+ self , df : pd .DataFrame , max_workers : Optional [int ] = None
118+ ) -> Tuple [pd .DataFrame , dict ]:
119+ """Function that runs the model and returns the result.
102120
103- # Get the signature of the 'run' method
121+ If ``run`` is defined as ``async def run(...)``, rows are dispatched
122+ concurrently with ``asyncio.gather`` gated by ``asyncio.Semaphore(max_workers)``.
123+ ``max_workers`` defaults to 4 for an async ``run`` (writing `async def`
124+ is the opt-in signal that interleaving is safe). For a synchronous
125+ ``run``, rows are processed sequentially and ``max_workers`` must be 1.
126+
127+ A row's exception propagates and aborts the batch. For the async path,
128+ ``asyncio.gather`` cancels in-flight siblings before re-raising.
129+ """
104130 run_signature = inspect .signature (self .run )
131+ valid_params = set (run_signature .parameters )
132+ is_async = inspect .iscoroutinefunction (self .run )
133+
134+ if max_workers is None :
135+ max_workers = 4 if is_async else 1
136+ elif max_workers < 1 :
137+ raise ValueError ("max_workers must be >= 1" )
138+
139+ if max_workers > 1 and not is_async :
140+ raise ValueError (
141+ "max_workers > 1 requires an async `run` method. "
142+ "Define `run` as `async def run(self, ...)` to enable "
143+ "concurrent execution."
144+ )
145+
146+ for col in ("output" , "steps" , "latency" , "cost" , "tokens" , "context" ):
147+ if col not in df .columns :
148+ df [col ] = None
149+
150+ rows = [
151+ (
152+ idx ,
153+ {k : v for k , v in row .to_dict ().items () if k in valid_params },
154+ )
155+ for idx , row in df .iterrows ()
156+ ]
157+
158+ if is_async :
159+ try :
160+ asyncio .get_running_loop ()
161+ except RuntimeError :
162+ pass
163+ else :
164+ raise RuntimeError (
165+ "run_batch_from_df was called from inside a running event "
166+ "loop. Call `await self._run_rows_async(...)` directly "
167+ "from async code."
168+ )
169+ results = asyncio .run (self ._run_rows_async (rows , max_workers ))
170+ else :
171+ results = [
172+ (idx , self .run (** kwargs ), tracer .get_current_trace ())
173+ for idx , kwargs in rows
174+ ]
175+
176+ for index , output , trace in results :
177+ self ._apply_row_result (df , index , output , trace )
105178
106- for index , row in df .iterrows ():
107- # Filter row_dict to only include keys that are valid parameters
108- # for the 'run' method
109- row_dict = row .to_dict ()
110- filtered_kwargs = {
111- k : v for k , v in row_dict .items () if k in run_signature .parameters
112- }
113-
114- # Call the run method with filtered kwargs
115- output = self .run (** filtered_kwargs )
116-
117- df .at [index , "output" ] = output .output
118-
119- for k , v in output .other_fields .items ():
120- if k not in df .columns :
121- df [k ] = None
122- df .at [index , k ] = v
123-
124- trace = tracer .get_current_trace ()
125- if trace :
126- processed_trace , _ = tracer .post_process_trace (trace_obj = trace )
127- df .at [index , "steps" ] = trace .to_dict ()
128- if "latency" in processed_trace :
129- df .at [index , "latency" ] = processed_trace ["latency" ]
130- if "cost" in processed_trace :
131- df .at [index , "cost" ] = processed_trace ["cost" ]
132- if "tokens" in processed_trace :
133- df .at [index , "tokens" ] = processed_trace ["tokens" ]
134- if "context" in processed_trace :
135- df .at [index , "context" ] = processed_trace ["context" ]
136-
137- config = {
179+ return df , self ._build_config (run_signature , df )
180+
181+ async def _run_rows_async (
182+ self ,
183+ rows : List [Tuple [Any , Dict [str , Any ]]],
184+ max_workers : int ,
185+ ) -> List [Tuple [Any , RunReturn , Optional [Any ]]]:
186+ """Drive an async ``run`` over all rows with bounded concurrency.
187+
188+ The first row to raise causes ``asyncio.gather`` to cancel in-flight
189+ siblings and re-raise the original exception.
190+ """
191+ sem = asyncio .Semaphore (max_workers )
192+
193+ async def _one (index : Any , kwargs : Dict [str , Any ]):
194+ async with sem :
195+ output = await self .run (** kwargs )
196+ return index , output , tracer .get_current_trace ()
197+
198+ return await asyncio .gather (* (_one (i , k ) for i , k in rows ))
199+
200+ def _apply_row_result (
201+ self ,
202+ df : pd .DataFrame ,
203+ index : Any ,
204+ output : RunReturn ,
205+ trace : Optional [Any ],
206+ ) -> None :
207+ """Write a single row's output and trace fields into ``df`` in place."""
208+ df .at [index , "output" ] = output .output
209+
210+ for k , v in output .other_fields .items ():
211+ if k not in df .columns :
212+ df [k ] = None
213+ df .at [index , k ] = v
214+
215+ if trace :
216+ processed_trace , _ = tracer .post_process_trace (trace_obj = trace )
217+ df .at [index , "steps" ] = trace .to_dict ()
218+ if "latency" in processed_trace :
219+ df .at [index , "latency" ] = processed_trace ["latency" ]
220+ if "cost" in processed_trace :
221+ df .at [index , "cost" ] = processed_trace ["cost" ]
222+ if "tokens" in processed_trace :
223+ df .at [index , "tokens" ] = processed_trace ["tokens" ]
224+ if "context" in processed_trace :
225+ df .at [index , "context" ] = processed_trace ["context" ]
226+
227+ def _build_config (
228+ self , run_signature : inspect .Signature , df : pd .DataFrame
229+ ) -> Dict [str , Any ]:
230+ """Build the config dict returned alongside the output DataFrame."""
231+ config : Dict [str , Any ] = {
138232 "outputColumnName" : "output" ,
139233 "inputVariableNames" : list (run_signature .parameters .keys ()),
140234 "metadata" : {
@@ -154,7 +248,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
154248 for k , v in self .custom_args .items ():
155249 config ["metadata" ][k ] = v
156250
157- return df , config
251+ return config
158252
159253 def write_output_to_directory (
160254 self ,
0 commit comments