Skip to content

Commit 338cf5e

Browse files
base64 things
1 parent 7d64ddf commit 338cf5e

6 files changed

Lines changed: 1544 additions & 8 deletions

File tree

modelq/app/base.py

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from modelq.exceptions import TaskProcessingError, TaskTimeoutError,RetryTaskException
1616
from modelq.app.middleware import Middleware
1717
from modelq.app.redis_retry import _RedisWithRetry
18+
from modelq.app.utils.compression import compress_base64, decompress_base64
1819

1920
from pydantic import BaseModel, ValidationError
2021
from typing import Optional, Dict, Any, Type
@@ -33,6 +34,9 @@ class ModelQ:
3334
PRUNE_CHECK_INTERVAL = 60 # seconds: how often to check for stale servers
3435
TASK_RESULT_RETENTION = 86400
3536

37+
# Thread-local storage for tracking current task context
38+
_current_task = threading.local()
39+
3640
def __init__(
3741
self,
3842
host: str = "localhost",
@@ -609,6 +613,9 @@ def process_task(self, task: Task) -> None:
609613
f"with args: {call_args}, kwargs: {call_kwargs}"
610614
)
611615

616+
# Set current task context for this thread
617+
self._current_task.task_id = task.task_id
618+
612619
if stream:
613620
# Stream results
614621
for result in task_function(*call_args, **call_kwargs):
@@ -683,17 +690,35 @@ def process_task(self, task: Task) -> None:
683690

684691
finally:
685692
self.redis_client.srem("processing_tasks", task.task_id)
693+
# Clear current task context for this thread
694+
if hasattr(self._current_task, 'task_id'):
695+
delattr(self._current_task, 'task_id')
686696

687697

688698
def _store_final_task_state(self, task: Task, success: bool):
689699
"""
690700
Persists the final status/result of the task in Redis, adding finished_at.
701+
Preserves any base64_output that was stored during task execution.
691702
"""
703+
# Get existing task data to preserve base64_output if it exists
704+
existing_data = self.redis_client.get(f"task:{task.task_id}")
705+
existing_base64_output = None
706+
if existing_data:
707+
try:
708+
existing_dict = json.loads(existing_data)
709+
existing_base64_output = existing_dict.get("base64_output")
710+
except:
711+
pass
712+
692713
task_dict = task.to_dict()
693714

715+
# Preserve base64_output if it was stored during task execution
716+
if existing_base64_output is not None:
717+
task_dict["base64_output"] = existing_base64_output
718+
694719
# Mark finished_at
695720
task_dict["finished_at"] = time.time()
696-
721+
697722
self.redis_client.set(
698723
f"task_result:{task.task_id}",
699724
json.dumps(task_dict),
@@ -740,6 +765,148 @@ def get_task_status(self, task_id: str) -> Optional[str]:
740765
return json.loads(task_data).get("status")
741766
return None
742767

768+
def store_base64_output(
769+
self,
770+
base64_output: str,
771+
task_id: Optional[str] = None,
772+
compress: bool = True,
773+
compression_method: str = "zlib",
774+
compression_level: int = 6
775+
) -> bool:
776+
"""
777+
Store base64 output for a task with optional compression.
778+
Automatically detects the current task_id if called from within a task function.
779+
780+
Args:
781+
base64_output: The base64 encoded output (image, video, etc.)
782+
task_id: Optional task ID. If not provided, uses the current task being processed
783+
compress: Whether to compress the base64 output (default: True)
784+
compression_method: Compression algorithm to use (default: "zlib")
785+
Options: "zlib", "gzip", "bz2", "brotli", "lz4"
786+
compression_level: Compression level 0-9 (default: 6)
787+
Higher = better compression but slower
788+
789+
Returns:
790+
True if storage was successful, False otherwise
791+
792+
Example:
793+
@modelq.task()
794+
def generate_image(params):
795+
# ... generate image and encode to base64
796+
base64_image = "data:image/png;base64,..."
797+
798+
# Store with default compression (zlib, level 6)
799+
modelq.store_base64_output(base64_image, compress=True)
800+
801+
# Store with maximum compression using brotli
802+
modelq.store_base64_output(base64_image, compress=True,
803+
compression_method="brotli",
804+
compression_level=11)
805+
806+
# Return regular result
807+
return {"status": "success"}
808+
"""
809+
try:
810+
# Auto-detect task_id from current thread context if not provided
811+
if task_id is None:
812+
if hasattr(self._current_task, 'task_id'):
813+
task_id = self._current_task.task_id
814+
else:
815+
logger.error("store_base64_output called without task_id and no task context found")
816+
return False
817+
818+
# Get the existing task data
819+
task_data = self.redis_client.get(f"task:{task_id}")
820+
if not task_data:
821+
logger.warning(f"Task {task_id} not found when trying to store base64 output")
822+
return False
823+
824+
task_dict = json.loads(task_data)
825+
826+
# Compress if requested
827+
if compress:
828+
stored_output = compress_base64(
829+
base64_output,
830+
compression_level=compression_level,
831+
method=compression_method
832+
)
833+
else:
834+
stored_output = base64_output
835+
836+
# Update the task dict with base64_output
837+
task_dict["base64_output"] = stored_output
838+
839+
# Store back to Redis with same expiry times
840+
self.redis_client.set(
841+
f"task:{task_id}",
842+
json.dumps(task_dict),
843+
ex=86400 # 24 hours
844+
)
845+
846+
# Also update task_result if it exists
847+
task_result_data = self.redis_client.get(f"task_result:{task_id}")
848+
if task_result_data:
849+
result_dict = json.loads(task_result_data)
850+
result_dict["base64_output"] = stored_output
851+
self.redis_client.set(
852+
f"task_result:{task_id}",
853+
json.dumps(result_dict),
854+
ex=3600 # 1 hour
855+
)
856+
857+
logger.info(f"Stored base64 output for task {task_id} (compressed: {compress}, method: {compression_method})")
858+
return True
859+
860+
except Exception as e:
861+
logger.error(f"Failed to store base64 output for task {task_id}: {e}")
862+
return False
863+
864+
def get_task_base64(self, task_id: str, decompress: bool = True) -> Optional[str]:
865+
"""
866+
Retrieve the base64 output for a task.
867+
868+
Args:
869+
task_id: The task ID to retrieve the output for
870+
decompress: Whether to decompress the output if it was compressed (default: True)
871+
872+
Returns:
873+
The base64 output string, or None if not found
874+
875+
Example:
876+
# Get the base64 output for a completed task
877+
base64_image = modelq.get_task_base64(task.task_id)
878+
if base64_image:
879+
# Use the base64 image
880+
pass
881+
"""
882+
try:
883+
# Try to get from task_result first (most recent)
884+
task_data = self.redis_client.get(f"task_result:{task_id}")
885+
886+
# Fall back to task key if not in result
887+
if not task_data:
888+
task_data = self.redis_client.get(f"task:{task_id}")
889+
890+
if not task_data:
891+
logger.warning(f"Task {task_id} not found when trying to retrieve base64 output")
892+
return None
893+
894+
task_dict = json.loads(task_data)
895+
base64_output = task_dict.get("base64_output")
896+
897+
if base64_output is None:
898+
return None
899+
900+
# Decompress if requested
901+
if decompress:
902+
return decompress_base64(base64_output)
903+
else:
904+
return base64_output
905+
906+
except Exception as e:
907+
logger.error(f"Failed to retrieve base64 output for task {task_id}: {e}")
908+
return None
909+
743910
def log_task_error_to_file(self, task: Task, exc: Exception, file_path="modelq_errors.log"):
744911
"""
745912
Logs detailed error info to a specified file, with dashes before and after.

modelq/app/tasks/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def __init__(self, task_name: str, payload: dict, timeout: int = 15):
1818
self.original_payload = copy.deepcopy(payload)
1919
self.status = "queued"
2020
self.result = None
21-
21+
self.base64_output = None # New field for storing compressed base64 outputs
22+
2223
# New timestamps:
2324
self.created_at = time.time() # When Task object is instantiated
2425
self.queued_at = None # When task is enqueued in Redis
@@ -36,6 +37,7 @@ def to_dict(self):
3637
"payload": self.payload,
3738
"status": self.status,
3839
"result": self.result,
40+
"base64_output": self.base64_output,
3941
"created_at": self.created_at,
4042
"queued_at": self.queued_at,
4143
"started_at": self.started_at,
@@ -49,6 +51,7 @@ def from_dict(data: dict) -> "Task":
4951
task.task_id = data["task_id"]
5052
task.status = data["status"]
5153
task.result = data.get("result")
54+
task.base64_output = data.get("base64_output")
5255

5356
# Load timestamps if present
5457
task.created_at = data.get("created_at")

0 commit comments

Comments
 (0)