Skip to content

Commit 8c032d6

Browse files
committed
feat: allow custom models
1 parent 48d5430 commit 8c032d6

12 files changed

Lines changed: 475 additions & 64 deletions

File tree

README.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,71 @@ async with DecartClient(api_key=os.getenv("DECART_API_KEY")) as client:
8989
f.write(data)
9090
```
9191

92+
### Custom Models
93+
94+
For preview, experimental, or private models that are not yet in the SDK registry,
95+
create a custom model definition and pass it directly to the matching API.
96+
`models.realtime(...)`, `models.video(...)`, and `models.image(...)` remain registry-only helpers;
97+
use `models.custom(...)` when you need to send an arbitrary model name.
98+
99+
```python
100+
from decart import DecartClient, RealtimeClient, RealtimeConnectOptions, models
101+
102+
# Realtime: default url_path is /v1/stream.
103+
custom_realtime_model = models.custom(
104+
"lucy_2_rt_preview",
105+
fps=20,
106+
width=1280,
107+
height=720,
108+
)
109+
110+
realtime_client = await RealtimeClient.connect(
111+
base_url=client.realtime_base_url,
112+
api_key=client.api_key,
113+
local_track=track,
114+
options=RealtimeConnectOptions(
115+
model=custom_realtime_model,
116+
on_remote_stream=lambda stream: print("remote stream", stream),
117+
),
118+
)
119+
120+
# Process API: use a generation endpoint; the default realtime url_path is
121+
# not valid for client.process().
122+
custom_image_model = models.custom(
123+
"lucy_image_preview",
124+
url_path="/v1/generate/lucy_image_preview",
125+
fps=25,
126+
width=1280,
127+
height=704,
128+
)
129+
130+
image = await client.process({
131+
"model": custom_image_model,
132+
"prompt": "Apply a preview model treatment",
133+
"data": open("input.png", "rb"),
134+
})
135+
136+
# Queue API: add queue_url_path for async jobs.
137+
custom_video_model = models.custom(
138+
"lucy_video_preview",
139+
url_path="/v1/generate/lucy_video_preview",
140+
queue_url_path="/v1/jobs/lucy_video_preview",
141+
fps=20,
142+
width=1280,
143+
height=720,
144+
)
145+
146+
job = await client.queue.submit({
147+
"model": custom_video_model,
148+
"prompt": "Use the preview video model",
149+
"data": open("input.mp4", "rb"),
150+
})
151+
```
152+
153+
If `input_schema` is omitted, custom process and queue inputs are sent through without
154+
client-side schema validation; the backend still validates whether the model name, endpoint,
155+
and inputs are supported.
156+
92157
## Development
93158

94159
### Setup with UV

decart/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
QueueResultError,
1313
TokenCreateError,
1414
)
15-
from .models import models, ModelDefinition, VideoRestyleInput
15+
from .models import models, ModelDefinition, CustomModelDefinition, VideoRestyleInput
1616
from .types import FileInput, ModelState, Prompt
1717
from .queue import (
1818
QueueClient,
@@ -69,6 +69,7 @@
6969
"QueueResultError",
7070
"models",
7171
"ModelDefinition",
72+
"CustomModelDefinition",
7273
"VideoRestyleInput",
7374
"FileInput",
7475
"ModelState",

decart/client.py

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
2+
from types import TracebackType
23
from typing import Any, Optional
34
import aiohttp
4-
from pydantic import ValidationError
5+
from pydantic import BaseModel, ValidationError
56
from .errors import InvalidAPIKeyError, InvalidBaseURLError, InvalidInputError
6-
from .models import ImageModelDefinition, _MODELS
7+
from .models import ModelDefinition, _MODELS
78
from .process.request import send_request
89
from .queue.client import QueueClient
910
from .tokens.client import TokensClient
@@ -128,25 +129,30 @@ async def close(self) -> None:
128129
if self._session and not self._session.closed:
129130
await self._session.close()
130131

131-
async def __aenter__(self):
132+
async def __aenter__(self) -> "DecartClient":
132133
"""Async context manager entry."""
133134
return self
134135

135-
async def __aexit__(self, exc_type, exc_val, exc_tb):
136+
async def __aexit__(
137+
self,
138+
exc_type: type[BaseException] | None,
139+
exc_val: BaseException | None,
140+
exc_tb: TracebackType | None,
141+
) -> None:
136142
"""Async context manager exit."""
137143
await self.close()
138144

139145
async def process(self, options: dict[str, Any]) -> bytes:
140146
"""
141147
Process image editing synchronously.
142-
Only image models support the process API.
148+
Registry image models and custom model definitions with a process url_path support the process API.
143149
144150
For video editing, use the queue API instead:
145151
result = await client.queue.submit_and_poll({...})
146152
147153
Args:
148154
options: Processing options including model and inputs
149-
- model: ImageModelDefinition from models.image()
155+
- model: ImageModelDefinition from models.image(), or a custom definition from models.custom()
150156
- prompt: Text instructions describing the requested edit
151157
- Additional model-specific inputs
152158
@@ -160,14 +166,31 @@ async def process(self, options: dict[str, Any]) -> bytes:
160166
if "model" not in options:
161167
raise InvalidInputError("model is required")
162168

163-
model: ImageModelDefinition = options["model"]
169+
model: ModelDefinition[str] = options["model"]
164170

165-
# Validate that this is an image model (check against registry)
166-
if model.name not in _MODELS["image"]:
171+
# Keep known non-image registry definitions on their intended APIs, but
172+
# allow custom definitions with arbitrary names (even if a custom name
173+
# happens to overlap a registry name).
174+
is_known_image_model = model == _MODELS["image"].get(model.name)
175+
is_known_video_model = model == _MODELS["video"].get(model.name)
176+
is_known_realtime_model = model == _MODELS["realtime"].get(model.name)
177+
178+
if is_known_video_model or is_known_realtime_model:
179+
next_step = (
180+
"For realtime models, use RealtimeClient.connect() instead."
181+
if is_known_realtime_model
182+
else "For video models, use client.queue.submit_and_poll() instead."
183+
)
167184
raise InvalidInputError(
168185
f"Model '{model.name}' is not supported by process(). "
169-
f"Only image models support sync processing. "
170-
f"For video models, use client.queue.submit_and_poll() instead."
186+
f"Only image/custom process models support sync processing. "
187+
f"{next_step}"
188+
)
189+
190+
if not is_known_image_model and model.url_path == "/v1/stream":
191+
raise InvalidInputError(
192+
f"Custom process model '{model.name}' must define a process url_path. "
193+
f"Pass url_path when using models.custom(...) with client.process()."
171194
)
172195

173196
cancel_token = options.get("cancel_token")
@@ -181,22 +204,27 @@ async def process(self, options: dict[str, Any]) -> bytes:
181204
file_inputs = {k: v for k, v in inputs.items() if k in FILE_FIELDS}
182205
non_file_inputs = {k: v for k, v in inputs.items() if k not in FILE_FIELDS}
183206

184-
# Validate non-file inputs and create placeholder for file fields
185-
validation_inputs = {
186-
**non_file_inputs,
187-
**{k: b"" for k in file_inputs.keys()}, # Placeholder bytes for validation
188-
}
189-
190-
try:
191-
validated_inputs = model.input_schema(**validation_inputs)
192-
except ValidationError as e:
193-
raise InvalidInputError(f"Invalid inputs for {model.name}: {str(e)}") from e
194-
195-
# Build final inputs: validated non-file inputs + original file inputs
196-
processed_inputs = {
197-
**validated_inputs.model_dump(exclude_none=True),
198-
**file_inputs, # Override placeholders with actual file data
199-
}
207+
if model.input_schema is BaseModel:
208+
# Custom models can omit an input schema; in that case we pass
209+
# arbitrary fields through and let the backend validate them.
210+
processed_inputs = {k: v for k, v in inputs.items() if v is not None}
211+
else:
212+
# Validate non-file inputs and create placeholder for file fields
213+
validation_inputs = {
214+
**non_file_inputs,
215+
**{k: b"" for k in file_inputs.keys()}, # Placeholder bytes for validation
216+
}
217+
218+
try:
219+
validated_inputs = model.input_schema(**validation_inputs)
220+
except ValidationError as e:
221+
raise InvalidInputError(f"Invalid inputs for {model.name}: {str(e)}") from e
222+
223+
# Build final inputs: validated non-file inputs + original file inputs
224+
processed_inputs = {
225+
**validated_inputs.model_dump(exclude_none=True),
226+
**file_inputs, # Override placeholders with actual file data
227+
}
200228

201229
session = await self._get_session()
202230
response = await send_request(

decart/models.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from .errors import ModelNotFoundError
55
from .types import FileInput, MotionTrajectoryInput
66

7-
87
RealTimeModels = Literal[
98
# Canonical names
109
"lucy",
@@ -89,10 +88,11 @@ class DecartBaseModel(BaseModel):
8988
class ModelDefinition(DecartBaseModel, Generic[ModelT]):
9089
name: ModelT
9190
url_path: str
91+
queue_url_path: Optional[str] = None
9292
fps: int = Field(ge=1)
9393
width: int = Field(ge=1)
9494
height: int = Field(ge=1)
95-
input_schema: type[BaseModel]
95+
input_schema: type[BaseModel] = BaseModel
9696

9797

9898
# Type aliases for model definitions that support specific APIs
@@ -105,6 +105,13 @@ class ModelDefinition(DecartBaseModel, Generic[ModelT]):
105105
RealTimeModelDefinition = ModelDefinition[RealTimeModels]
106106
"""Type alias for model definitions that support realtime streaming."""
107107

108+
CustomModelDefinition = ModelDefinition[str]
109+
"""Type alias for custom model definitions with arbitrary model names.
110+
111+
Useful for preview, experimental, or private models that are not yet
112+
in the SDK's built-in registry.
113+
"""
114+
108115

109116
class VideoToVideoInput(DecartBaseModel):
110117
prompt: str = Field(
@@ -299,6 +306,7 @@ class ImageToImageInput(DecartBaseModel):
299306
"lucy-clip": ModelDefinition(
300307
name="lucy-clip",
301308
url_path="/v1/jobs/lucy-clip",
309+
queue_url_path="/v1/jobs/lucy-clip",
302310
fps=25,
303311
width=1280,
304312
height=704,
@@ -307,6 +315,7 @@ class ImageToImageInput(DecartBaseModel):
307315
"lucy-2.1": ModelDefinition(
308316
name="lucy-2.1",
309317
url_path="/v1/jobs/lucy-2.1",
318+
queue_url_path="/v1/jobs/lucy-2.1",
310319
fps=20,
311320
width=1088,
312321
height=624,
@@ -315,6 +324,7 @@ class ImageToImageInput(DecartBaseModel):
315324
"lucy-2.1-vton": ModelDefinition(
316325
name="lucy-2.1-vton",
317326
url_path="/v1/jobs/lucy-2.1-vton",
327+
queue_url_path="/v1/jobs/lucy-2.1-vton",
318328
fps=20,
319329
width=1088,
320330
height=624,
@@ -323,6 +333,7 @@ class ImageToImageInput(DecartBaseModel):
323333
"lucy-restyle-2": ModelDefinition(
324334
name="lucy-restyle-2",
325335
url_path="/v1/jobs/lucy-restyle-2",
336+
queue_url_path="/v1/jobs/lucy-restyle-2",
326337
fps=22,
327338
width=1280,
328339
height=704,
@@ -331,6 +342,7 @@ class ImageToImageInput(DecartBaseModel):
331342
"lucy-motion": ModelDefinition(
332343
name="lucy-motion",
333344
url_path="/v1/jobs/lucy-motion",
345+
queue_url_path="/v1/jobs/lucy-motion",
334346
fps=25,
335347
width=1280,
336348
height=704,
@@ -340,6 +352,7 @@ class ImageToImageInput(DecartBaseModel):
340352
"lucy-latest": ModelDefinition(
341353
name="lucy-latest",
342354
url_path="/v1/jobs/lucy-latest",
355+
queue_url_path="/v1/jobs/lucy-latest",
343356
fps=20,
344357
width=1088,
345358
height=624,
@@ -348,6 +361,7 @@ class ImageToImageInput(DecartBaseModel):
348361
"lucy-vton-latest": ModelDefinition(
349362
name="lucy-vton-latest",
350363
url_path="/v1/jobs/lucy-vton-latest",
364+
queue_url_path="/v1/jobs/lucy-vton-latest",
351365
fps=20,
352366
width=1088,
353367
height=624,
@@ -356,6 +370,7 @@ class ImageToImageInput(DecartBaseModel):
356370
"lucy-restyle-latest": ModelDefinition(
357371
name="lucy-restyle-latest",
358372
url_path="/v1/jobs/lucy-restyle-latest",
373+
queue_url_path="/v1/jobs/lucy-restyle-latest",
359374
fps=22,
360375
width=1280,
361376
height=704,
@@ -364,6 +379,7 @@ class ImageToImageInput(DecartBaseModel):
364379
"lucy-clip-latest": ModelDefinition(
365380
name="lucy-clip-latest",
366381
url_path="/v1/jobs/lucy-clip-latest",
382+
queue_url_path="/v1/jobs/lucy-clip-latest",
367383
fps=25,
368384
width=1280,
369385
height=704,
@@ -372,6 +388,7 @@ class ImageToImageInput(DecartBaseModel):
372388
"lucy-motion-latest": ModelDefinition(
373389
name="lucy-motion-latest",
374390
url_path="/v1/jobs/lucy-motion-latest",
391+
queue_url_path="/v1/jobs/lucy-motion-latest",
375392
fps=25,
376393
width=1280,
377394
height=704,
@@ -381,6 +398,7 @@ class ImageToImageInput(DecartBaseModel):
381398
"lucy-pro-v2v": ModelDefinition(
382399
name="lucy-pro-v2v",
383400
url_path="/v1/jobs/lucy-pro-v2v",
401+
queue_url_path="/v1/jobs/lucy-pro-v2v",
384402
fps=25,
385403
width=1280,
386404
height=704,
@@ -389,6 +407,7 @@ class ImageToImageInput(DecartBaseModel):
389407
"lucy-restyle-v2v": ModelDefinition(
390408
name="lucy-restyle-v2v",
391409
url_path="/v1/jobs/lucy-restyle-v2v",
410+
queue_url_path="/v1/jobs/lucy-restyle-v2v",
392411
fps=22,
393412
width=1280,
394413
height=704,
@@ -428,6 +447,41 @@ class ImageToImageInput(DecartBaseModel):
428447

429448

430449
class Models:
450+
@staticmethod
451+
def custom(
452+
name: str,
453+
*,
454+
fps: int,
455+
width: int,
456+
height: int,
457+
url_path: str = "/v1/stream",
458+
input_schema: type[BaseModel] = BaseModel,
459+
queue_url_path: Optional[str] = None,
460+
) -> CustomModelDefinition:
461+
"""Create a custom model definition with an arbitrary model name.
462+
463+
This is useful for preview, experimental, or private models that are
464+
not yet in the SDK's built-in registry. Pass the returned definition
465+
directly to the matching client API.
466+
467+
For realtime models, keep the default ``url_path="/v1/stream"``.
468+
For process/custom image models, pass the generation endpoint as
469+
``url_path``; the default realtime stream path is not valid for
470+
``client.process()``. For queue/custom video models, pass
471+
``queue_url_path``.
472+
If ``input_schema`` is omitted, process and queue inputs are sent
473+
through without client-side schema validation.
474+
"""
475+
return CustomModelDefinition(
476+
name=name,
477+
url_path=url_path,
478+
queue_url_path=queue_url_path,
479+
fps=fps,
480+
width=width,
481+
height=height,
482+
input_schema=input_schema,
483+
)
484+
431485
@staticmethod
432486
def realtime(model: RealTimeModels) -> RealTimeModelDefinition:
433487
"""Get a realtime model definition for WebRTC streaming."""

decart/process/request.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def send_request(
8080
session: aiohttp.ClientSession,
8181
base_url: str,
8282
api_key: str,
83-
model: ModelDefinition,
83+
model: ModelDefinition[str],
8484
inputs: dict[str, Any],
8585
cancel_token: Optional[asyncio.Event] = None,
8686
integration: Optional[str] = None,

0 commit comments

Comments
 (0)