Skip to content

Commit d2e32e0

Browse files
v1.0.6
1 parent e60df3d commit d2e32e0

4 files changed

Lines changed: 234 additions & 5 deletions

File tree

modelq/app/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,8 @@ def decorator(func):
536536
def wrapper(*args, **kwargs):
537537
# Extract optional custom task_id from kwargs
538538
custom_task_id = kwargs.pop('_task_id', None)
539+
# Extract optional additional_params from kwargs
540+
additional_params = kwargs.pop('additional_params', None)
539541

540542
# --------------------------- PRODUCER-SIDE VALIDATION
541543
if schema is not None: # ▶ pydantic
@@ -562,7 +564,7 @@ def wrapper(*args, **kwargs):
562564
"retries": retries,
563565
}
564566

565-
task = task_class(task_name=func.__name__, payload=payload, task_id=custom_task_id)
567+
task = task_class(task_name=func.__name__, payload=payload, task_id=custom_task_id, additional_params=additional_params)
566568
if stream:
567569
task.stream = True
568570

modelq/app/tasks/base.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Type
1212

1313
class Task:
14-
def __init__(self, task_name: str, payload: dict, timeout: int = 15, task_id: Optional[str] = None):
14+
def __init__(self, task_name: str, payload: dict, timeout: int = 15, task_id: Optional[str] = None, additional_params: Optional[Dict[str, Any]] = None):
1515
self.task_id = task_id if task_id else str(uuid.uuid4())
1616
self.task_name = task_name
1717
self.payload = payload
@@ -29,8 +29,11 @@ def __init__(self, task_name: str, payload: dict, timeout: int = 15, task_id: Op
2929
self.stream = False
3030
self.combined_result = ""
3131

32+
# Additional parameters that will be included in task response
33+
self.additional_params = additional_params or {}
34+
3235
def to_dict(self):
33-
return {
36+
base_dict = {
3437
"task_id": self.task_id,
3538
"task_name": self.task_name,
3639
"payload": self.payload,
@@ -42,10 +45,25 @@ def to_dict(self):
4245
"finished_at": self.finished_at,
4346
"stream": self.stream,
4447
}
48+
# Add additional_params to the base response if they exist
49+
if self.additional_params:
50+
base_dict.update(self.additional_params)
51+
return base_dict
4552

4653
@staticmethod
4754
def from_dict(data: dict) -> "Task":
48-
task = Task(task_name=data["task_name"], payload=data["payload"])
55+
# Extract additional_params from data (any keys not in the standard set)
56+
standard_keys = {
57+
"task_id", "task_name", "payload", "status", "result",
58+
"created_at", "queued_at", "started_at", "finished_at", "stream"
59+
}
60+
additional_params = {k: v for k, v in data.items() if k not in standard_keys}
61+
62+
task = Task(
63+
task_name=data["task_name"],
64+
payload=data["payload"],
65+
additional_params=additional_params if additional_params else None
66+
)
4967
task.task_id = data["task_id"]
5068
task.status = data["status"]
5169
task.result = data.get("result")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "modelq"
3-
version = "1.0.5"
3+
version = "1.0.6"
44
description = "Celery-like task queue for ML inference."
55
authors = ["Tanmaypatil123 <tanmay@modelslab.com>"]
66
readme = "README.md"

tests/test_base.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,212 @@ def test_configurable_history_retention(mock_redis):
619619
# Custom retention of 1 hour
620620
mq = ModelQ(redis_client=mock_redis, task_history_retention=3600)
621621
assert mq.task_history_retention == 3600
622+
623+
624+
# ---------------------------------------------------------------------------
625+
# Additional Params Tests
626+
# ---------------------------------------------------------------------------
627+
628+
def test_task_with_additional_params(modelq_instance):
629+
"""Test that tasks can include additional custom parameters."""
630+
631+
@modelq_instance.task()
632+
def test_task_with_params(x, y):
633+
return x + y
634+
635+
# Call task with additional_params
636+
task = test_task_with_params(
637+
5, 10,
638+
additional_params={
639+
"proxy_links": ["http://proxy1.example.com", "http://proxy2.example.com"],
640+
"public_links": "http://public.example.com",
641+
"custom_field": "custom_value"
642+
}
643+
)
644+
645+
# Verify task has additional_params
646+
assert task.additional_params == {
647+
"proxy_links": ["http://proxy1.example.com", "http://proxy2.example.com"],
648+
"public_links": "http://public.example.com",
649+
"custom_field": "custom_value"
650+
}
651+
652+
# Verify to_dict includes additional_params at root level
653+
task_dict = task.to_dict()
654+
assert task_dict["proxy_links"] == ["http://proxy1.example.com", "http://proxy2.example.com"]
655+
assert task_dict["public_links"] == "http://public.example.com"
656+
assert task_dict["custom_field"] == "custom_value"
657+
658+
# Verify standard fields are still present
659+
assert task_dict["task_id"] == task.task_id
660+
assert task_dict["task_name"] == "test_task_with_params"
661+
assert task_dict["status"] == "queued"
662+
663+
664+
def test_task_without_additional_params(modelq_instance):
665+
"""Test that tasks work normally without additional_params."""
666+
667+
@modelq_instance.task()
668+
def test_task_no_params(x):
669+
return x * 2
670+
671+
# Call task without additional_params
672+
task = test_task_no_params(5)
673+
674+
# Verify task has empty additional_params
675+
assert task.additional_params == {}
676+
677+
# Verify to_dict doesn't include extra fields
678+
task_dict = task.to_dict()
679+
assert "proxy_links" not in task_dict
680+
assert "public_links" not in task_dict
681+
682+
# Verify standard fields are still present
683+
assert task_dict["task_id"] == task.task_id
684+
assert task_dict["task_name"] == "test_task_no_params"
685+
assert task_dict["status"] == "queued"
686+
687+
688+
def test_task_additional_params_in_redis(modelq_instance):
689+
"""Test that additional_params are stored correctly in Redis."""
690+
691+
@modelq_instance.task()
692+
def redis_task(value):
693+
return value
694+
695+
# Create task with additional_params
696+
task = redis_task(
697+
"test",
698+
additional_params={
699+
"metadata": {"key": "value"},
700+
"tags": ["tag1", "tag2"]
701+
}
702+
)
703+
704+
# Retrieve task from Redis
705+
task_json = modelq_instance.redis_client.get(f"task:{task.task_id}")
706+
assert task_json is not None
707+
708+
stored_task = _json_bytes_to_dict(task_json)
709+
710+
# Verify additional_params are stored
711+
assert stored_task["metadata"] == {"key": "value"}
712+
assert stored_task["tags"] == ["tag1", "tag2"]
713+
714+
715+
def test_task_from_dict_with_additional_params(modelq_instance):
716+
"""Test that Task.from_dict correctly restores additional_params."""
717+
from modelq.app.tasks.base import Task
718+
719+
# Create a task dict with additional params
720+
task_dict = {
721+
"task_id": "test_123",
722+
"task_name": "test_task",
723+
"payload": {"data": {}},
724+
"status": "queued",
725+
"result": None,
726+
"created_at": time.time(),
727+
"queued_at": time.time(),
728+
"started_at": None,
729+
"finished_at": None,
730+
"stream": False,
731+
"proxy_links": ["http://proxy.example.com"],
732+
"public_links": "http://public.example.com",
733+
"custom_metadata": {"foo": "bar"}
734+
}
735+
736+
# Restore task from dict
737+
task = Task.from_dict(task_dict)
738+
739+
# Verify additional_params are restored
740+
assert task.additional_params == {
741+
"proxy_links": ["http://proxy.example.com"],
742+
"public_links": "http://public.example.com",
743+
"custom_metadata": {"foo": "bar"}
744+
}
745+
746+
# Verify to_dict produces same output
747+
restored_dict = task.to_dict()
748+
assert restored_dict["proxy_links"] == ["http://proxy.example.com"]
749+
assert restored_dict["public_links"] == "http://public.example.com"
750+
assert restored_dict["custom_metadata"] == {"foo": "bar"}
751+
752+
753+
def test_task_additional_params_empty_dict(modelq_instance):
754+
"""Test that passing empty additional_params works correctly."""
755+
756+
@modelq_instance.task()
757+
def empty_params_task(x):
758+
return x
759+
760+
# Call with empty dict
761+
task = empty_params_task(5, additional_params={})
762+
763+
assert task.additional_params == {}
764+
task_dict = task.to_dict()
765+
766+
# Should only have standard fields
767+
expected_keys = {
768+
"task_id", "task_name", "payload", "status", "result",
769+
"created_at", "queued_at", "started_at", "finished_at", "stream"
770+
}
771+
assert set(task_dict.keys()) == expected_keys
772+
773+
774+
def test_task_additional_params_with_pydantic(mock_redis):
775+
"""Test that additional_params work with Pydantic schema validation."""
776+
from pydantic import BaseModel
777+
778+
class TaskInput(BaseModel):
779+
name: str
780+
value: int
781+
782+
mq = ModelQ(redis_client=mock_redis)
783+
784+
@mq.task(schema=TaskInput)
785+
def pydantic_task(params):
786+
return params.name
787+
788+
# Create task with Pydantic input and additional_params
789+
task = pydantic_task(
790+
TaskInput(name="test", value=42),
791+
additional_params={
792+
"request_id": "req_123",
793+
"priority": "high"
794+
}
795+
)
796+
797+
assert task.additional_params == {
798+
"request_id": "req_123",
799+
"priority": "high"
800+
}
801+
802+
task_dict = task.to_dict()
803+
assert task_dict["request_id"] == "req_123"
804+
assert task_dict["priority"] == "high"
805+
806+
807+
def test_get_task_details_with_additional_params(modelq_instance):
808+
"""Test that get_task_details returns additional_params."""
809+
task_data = {
810+
"task_id": "detail_with_params",
811+
"task_name": "test_task",
812+
"status": "queued",
813+
"created_at": time.time(),
814+
"proxy_links": ["http://proxy.example.com"],
815+
"public_links": "http://public.example.com"
816+
}
817+
818+
modelq_instance.redis_client.set(
819+
"task_history:detail_with_params",
820+
json.dumps(task_data)
821+
)
822+
modelq_instance.redis_client.zadd(
823+
"task_history",
824+
{"detail_with_params": task_data["created_at"]}
825+
)
826+
827+
details = modelq_instance.get_task_details("detail_with_params")
828+
assert details is not None
829+
assert details["proxy_links"] == ["http://proxy.example.com"]
830+
assert details["public_links"] == "http://public.example.com"

0 commit comments

Comments
 (0)