Skip to content

Commit a5cf631

Browse files
some basic restructure
1 parent 5555dfe commit a5cf631

6 files changed

Lines changed: 645 additions & 867 deletions

File tree

modelq/app/backends/base.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional, Dict, Any
3+
4+
5+
class QueueBackend(ABC):
6+
"""Abstract interface for ModelQ queue backends."""
7+
8+
# ─── Task Enqueue/Dequeue ──────────────────────────────────────────────
9+
10+
@abstractmethod
11+
def enqueue_task(self, task_data: dict) -> None:
12+
"""Push a new task to the task queue."""
13+
pass
14+
15+
@abstractmethod
16+
def dequeue_task(self, timeout: Optional[int] = None) -> Optional[dict]:
17+
"""Pop the next task from the queue (blocking or timed)."""
18+
pass
19+
20+
@abstractmethod
21+
def requeue_task(self, task_data: dict) -> None:
22+
"""Re-queue an existing task (e.g., after failure or rejection)."""
23+
pass
24+
25+
@abstractmethod
26+
def enqueue_delayed_task(self, task_data: dict, delay_seconds: int) -> None:
27+
"""Push task to delayed queue (sorted by run timestamp)."""
28+
pass
29+
30+
@abstractmethod
31+
def dequeue_ready_delayed_tasks(self) -> list:
32+
"""Get all delayed tasks ready to run now (score <= time.time())."""
33+
pass
34+
35+
@abstractmethod
36+
def flush_queue(self) -> None:
37+
"""Empty all tasks from the main task queue (for tests/dev reset)."""
38+
pass
39+
40+
# ─── Task Status Management ────────────────────────────────────────────
41+
42+
@abstractmethod
43+
def save_task_state(self, task_id: str, task_data: dict, result: bool) -> None:
44+
"""Save or update the final state of a task (completed/failed/etc)."""
45+
pass
46+
47+
@abstractmethod
48+
def load_task_state(self, task_id: str) -> Optional[dict]:
49+
"""Fetch a task's full state from storage."""
50+
pass
51+
52+
@abstractmethod
53+
def remove_task_from_queue(self, task_id: str) -> bool:
54+
"""Remove task from queue if still queued."""
55+
pass
56+
57+
@abstractmethod
58+
def mark_processing(self, task_id: str) -> bool:
59+
"""Add task to 'processing' set; return False if already processing."""
60+
pass
61+
62+
@abstractmethod
63+
def unmark_processing(self, task_id: str) -> None:
64+
"""Remove task from processing set."""
65+
pass
66+
67+
@abstractmethod
68+
def get_all_processing_tasks(self) -> list:
69+
"""Return list of currently 'processing' task IDs."""
70+
pass
71+
72+
@abstractmethod
73+
def get_all_queued_tasks(self) -> list:
74+
"""Return list of all tasks in the main queue."""
75+
pass
76+
77+
# ─── Server Registry ───────────────────────────────────────────────────
78+
79+
@abstractmethod
80+
def register_server(self, server_id: str, task_names: list) -> None:
81+
"""Register a worker with allowed task names and heartbeat."""
82+
pass
83+
84+
@abstractmethod
85+
def update_server_status(self, server_id: str, status: str) -> None:
86+
"""Update current server status and heartbeat time."""
87+
pass
88+
89+
@abstractmethod
90+
def get_all_server_ids(self) -> list:
91+
"""Return all currently registered server IDs."""
92+
pass
93+
94+
@abstractmethod
95+
def get_server_data(self, server_id: str) -> Optional[dict]:
96+
"""Get full data object for a server."""
97+
pass
98+
99+
@abstractmethod
100+
def prune_dead_servers(self, timeout: int) -> list:
101+
"""Remove any servers whose heartbeat is older than `timeout` seconds."""
102+
pass
103+
104+
# ─── Metrics + Maintenance ─────────────────────────────────────────────
105+
106+
@abstractmethod
107+
def prune_old_results(self, older_than_seconds: int) -> int:
108+
"""Delete old task results beyond TTL."""
109+
pass
110+
111+
@abstractmethod
112+
def queue_length(self) -> int:
113+
"""Return the length of the main task queue."""
114+
pass
115+
116+
@abstractmethod
117+
def cleanup_dlq(self) -> None:
118+
"""Clear all items from dead letter queue."""
119+
pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from modelq.app.backends.redis.backend import RedisQueueBackend
2+
3+
__all__ = ["RedisQueueBackend"]
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import time
2+
import json
3+
import redis
4+
from typing import Optional
5+
from modelq.app.backends.base import QueueBackend
6+
7+
8+
class RedisQueueBackend(QueueBackend):
9+
def __init__(self, redis_client: redis.Redis):
10+
self.redis = redis_client
11+
12+
# ─────────────────────── Task Queue ───────────────────────────────
13+
14+
def enqueue_task(self, task_data: dict) -> None:
15+
task_data["status"] = "queued"
16+
self.redis.rpush("ml_tasks", json.dumps(task_data))
17+
self.redis.zadd("queued_requests", {task_data["task_id"]: task_data["queued_at"]})
18+
19+
def dequeue_task(self, timeout: Optional[int] = None) -> Optional[dict]:
20+
data = self.redis.blpop("ml_tasks", timeout=timeout or 5)
21+
if data:
22+
_, task_json = data
23+
return json.loads(task_json)
24+
return None
25+
26+
def requeue_task(self, task_data: dict) -> None:
27+
self.redis.rpush("ml_tasks", json.dumps(task_data))
28+
29+
def enqueue_delayed_task(self, task_data: dict, delay_seconds: int) -> None:
30+
run_at = time.time() + delay_seconds
31+
self.redis.zadd("delayed_tasks", {json.dumps(task_data): run_at})
32+
33+
def dequeue_ready_delayed_tasks(self) -> list:
34+
now = time.time()
35+
tasks = self.redis.zrangebyscore("delayed_tasks", 0, now)
36+
for task_json in tasks:
37+
self.redis.zrem("delayed_tasks", task_json)
38+
self.redis.lpush("ml_tasks", task_json)
39+
return [json.loads(t) for t in tasks]
40+
41+
def flush_queue(self) -> None:
42+
self.redis.ltrim("ml_tasks", 1, 0)
43+
44+
# ─────────────────────── Task State ───────────────────────────────
45+
46+
def save_task_state(self, task_id: str, task_data: dict, result: bool) -> None:
47+
task_data["finished_at"] = time.time()
48+
self.redis.set(f"task_result:{task_id}", json.dumps(task_data), ex=3600)
49+
self.redis.set(f"task:{task_id}", json.dumps(task_data), ex=86400)
50+
51+
def load_task_state(self, task_id: str) -> Optional[dict]:
52+
data = self.redis.get(f"task:{task_id}")
53+
return json.loads(data) if data else None
54+
55+
def remove_task_from_queue(self, task_id: str) -> bool:
56+
tasks = self.redis.lrange("ml_tasks", 0, -1)
57+
for task_json in tasks:
58+
task_dict = json.loads(task_json)
59+
if task_dict.get("task_id") == task_id:
60+
self.redis.lrem("ml_tasks", 1, task_json)
61+
self.redis.zrem("queued_requests", task_id)
62+
return True
63+
return False
64+
65+
def mark_processing(self, task_id: str) -> bool:
66+
return self.redis.sadd("processing_tasks", task_id) == 1
67+
68+
def unmark_processing(self, task_id: str) -> None:
69+
self.redis.srem("processing_tasks", task_id)
70+
71+
def get_all_processing_tasks(self) -> list:
72+
return [pid.decode() for pid in self.redis.smembers("processing_tasks")]
73+
74+
def get_all_queued_tasks(self) -> list:
75+
raw = self.redis.lrange("ml_tasks", 0, -1)
76+
return [json.loads(task) for task in raw if json.loads(task).get("status") == "queued"]
77+
78+
# ─────────────────────── Server State ───────────────────────────────
79+
80+
def register_server(self, server_id: str, task_names: list) -> None:
81+
self.redis.hset("servers", server_id, json.dumps({
82+
"allowed_tasks": task_names,
83+
"status": "idle",
84+
"last_heartbeat": time.time()
85+
}))
86+
87+
def update_server_status(self, server_id: str, status: str) -> None:
88+
raw = self.redis.hget("servers", server_id)
89+
if raw:
90+
data = json.loads(raw)
91+
data["status"] = status
92+
data["last_heartbeat"] = time.time()
93+
self.redis.hset("servers", server_id, json.dumps(data))
94+
95+
def get_all_server_ids(self) -> list:
96+
return [k.decode("utf-8") for k in self.redis.hkeys("servers")]
97+
98+
def get_server_data(self, server_id: str) -> Optional[dict]:
99+
raw = self.redis.hget("servers", server_id)
100+
return json.loads(raw) if raw else None
101+
102+
def prune_dead_servers(self, timeout: int) -> list:
103+
now = time.time()
104+
pruned = []
105+
for sid, raw in self.redis.hgetall("servers").items():
106+
try:
107+
sid_str = sid.decode()
108+
data = json.loads(raw)
109+
if now - data.get("last_heartbeat", 0) > timeout:
110+
self.redis.hdel("servers", sid_str)
111+
pruned.append(sid_str)
112+
except:
113+
continue
114+
return pruned
115+
116+
# ─────────────────────── Miscellaneous ─────────────────────────────
117+
118+
def prune_old_results(self, older_than_seconds: int) -> int:
119+
now = time.time()
120+
deleted = 0
121+
for key in self.redis.scan_iter("task_result:*"):
122+
raw = self.redis.get(key)
123+
if not raw:
124+
continue
125+
data = json.loads(raw)
126+
timestamp = data.get("finished_at") or data.get("started_at")
127+
if timestamp and now - timestamp > older_than_seconds:
128+
task_id = key.decode().split(":")[-1]
129+
self.redis.delete(key)
130+
self.redis.delete(f"task:{task_id}")
131+
deleted += 1
132+
return deleted
133+
134+
def queue_length(self) -> int:
135+
return self.redis.llen("ml_tasks")
136+
137+
def cleanup_dlq(self) -> None:
138+
self.redis.delete("dlq")

0 commit comments

Comments
 (0)