@@ -25,21 +25,45 @@ def disconnect(self, handler):
2525
2626
2727class TestPosthogCeleryIntegration (unittest .TestCase ):
28- def test_instrument_is_idempotent (self ):
28+ SIGNAL_NAMES = (
29+ "task_prerun" ,
30+ "task_success" ,
31+ "task_failure" ,
32+ "task_retry" ,
33+ "before_task_publish" ,
34+ "after_task_publish" ,
35+ "worker_process_shutdown" ,
36+ )
37+
38+ def _build_fake_celery (self ):
2939 fake_signals = SimpleNamespace (
30- task_prerun = FakeSignal (),
31- task_success = FakeSignal (),
32- task_failure = FakeSignal (),
33- task_retry = FakeSignal (),
34- before_task_publish = FakeSignal (),
35- after_task_publish = FakeSignal (),
36- worker_process_shutdown = FakeSignal (),
40+ ** {signal_name : FakeSignal () for signal_name in self .SIGNAL_NAMES }
3741 )
38-
39- integration = PosthogCeleryIntegration ()
4042 fake_celery = ModuleType ("celery" )
4143 fake_celery .signals = fake_signals
4244 fake_celery .__version__ = "5.0.0"
45+ return fake_signals , fake_celery
46+
47+ def _assert_signal_counts (
48+ self , fake_signals , expected_connected , expected_disconnected = None
49+ ):
50+ for signal_name in self .SIGNAL_NAMES :
51+ self .assertEqual (
52+ len (getattr (fake_signals , signal_name ).connected ),
53+ expected_connected ,
54+ f"{ signal_name } connected count mismatch" ,
55+ )
56+ if expected_disconnected is not None :
57+ self .assertEqual (
58+ len (getattr (fake_signals , signal_name ).disconnected ),
59+ expected_disconnected ,
60+ f"{ signal_name } disconnected count mismatch" ,
61+ )
62+
63+ def test_instrument_is_idempotent (self ):
64+ fake_signals , fake_celery = self ._build_fake_celery ()
65+
66+ integration = PosthogCeleryIntegration ()
4367
4468 with (
4569 patch .dict ("sys.modules" , {"celery" : fake_celery }),
@@ -48,43 +72,18 @@ def test_instrument_is_idempotent(self):
4872 integration .instrument ()
4973 integration .instrument ()
5074
51- for sig in [
52- "task_prerun" ,
53- "task_success" ,
54- "task_failure" ,
55- "task_retry" ,
56- "before_task_publish" ,
57- "after_task_publish" ,
58- "worker_process_shutdown" ,
59- ]:
60- self .assertEqual (
61- len (getattr (fake_signals , sig ).connected ),
62- 1 ,
63- f"{ sig } connected multiple times" ,
64- )
75+ self ._assert_signal_counts (fake_signals , expected_connected = 1 )
6576 mock_register .assert_called_once_with (integration .shutdown )
6677 self .assertEqual (
6778 fake_signals .worker_process_shutdown .connected [0 ][0 ],
6879 integration ._on_worker_process_shutdown ,
6980 )
7081
7182 def test_instrument_and_uninstrument_connect_signals (self ):
72- fake_signals = SimpleNamespace (
73- task_prerun = FakeSignal (),
74- task_success = FakeSignal (),
75- task_failure = FakeSignal (),
76- task_retry = FakeSignal (),
77- before_task_publish = FakeSignal (),
78- after_task_publish = FakeSignal (),
79- worker_process_shutdown = FakeSignal (),
80- )
83+ fake_signals , fake_celery = self ._build_fake_celery ()
8184
8285 integration = PosthogCeleryIntegration ()
8386
84- fake_celery = ModuleType ("celery" )
85- fake_celery .signals = fake_signals
86- fake_celery .__version__ = "5.0.0"
87-
8887 with (
8988 patch .dict ("sys.modules" , {"celery" : fake_celery }),
9089 patch ("posthog.integrations.celery.atexit.register" ),
@@ -93,23 +92,9 @@ def test_instrument_and_uninstrument_connect_signals(self):
9392 integration .instrument ()
9493 integration .uninstrument ()
9594
96- for sig in [
97- "task_prerun" ,
98- "task_success" ,
99- "task_failure" ,
100- "task_retry" ,
101- "before_task_publish" ,
102- "after_task_publish" ,
103- "worker_process_shutdown" ,
104- ]:
105- self .assertEqual (
106- len (getattr (fake_signals , sig ).connected ), 1 , f"{ sig } not connected"
107- )
108- self .assertEqual (
109- len (getattr (fake_signals , sig ).disconnected ),
110- 1 ,
111- f"{ sig } not disconnected" ,
112- )
95+ self ._assert_signal_counts (
96+ fake_signals , expected_connected = 1 , expected_disconnected = 1
97+ )
11398 mock_unregister .assert_called_once_with (integration .shutdown )
11499 self .assertEqual (
115100 fake_signals .worker_process_shutdown .disconnected [0 ],
@@ -162,22 +147,11 @@ def test_shutdown_is_idempotent(self):
162147 mock_client .flush .assert_called_once_with ()
163148
164149 def test_shutdown_keeps_atexit_registration_when_flush_fails (self ):
165- fake_signals = SimpleNamespace (
166- task_prerun = FakeSignal (),
167- task_success = FakeSignal (),
168- task_failure = FakeSignal (),
169- task_retry = FakeSignal (),
170- before_task_publish = FakeSignal (),
171- after_task_publish = FakeSignal (),
172- worker_process_shutdown = FakeSignal (),
173- )
150+ fake_signals , fake_celery = self ._build_fake_celery ()
174151
175152 mock_client = Mock ()
176153 mock_client .flush .side_effect = RuntimeError ("flush failed" )
177154 integration = PosthogCeleryIntegration (client = mock_client )
178- fake_celery = ModuleType ("celery" )
179- fake_celery .signals = fake_signals
180- fake_celery .__version__ = "5.0.0"
181155
182156 with (
183157 patch .dict ("sys.modules" , {"celery" : fake_celery }),
@@ -194,21 +168,10 @@ def test_shutdown_keeps_atexit_registration_when_flush_fails(self):
194168 self .assertEqual (len (fake_signals .worker_process_shutdown .disconnected ), 1 )
195169
196170 def test_reinstrument_after_shutdown_allows_shutdown_again (self ):
197- fake_signals = SimpleNamespace (
198- task_prerun = FakeSignal (),
199- task_success = FakeSignal (),
200- task_failure = FakeSignal (),
201- task_retry = FakeSignal (),
202- before_task_publish = FakeSignal (),
203- after_task_publish = FakeSignal (),
204- worker_process_shutdown = FakeSignal (),
205- )
171+ fake_signals , fake_celery = self ._build_fake_celery ()
206172
207173 mock_client = Mock ()
208174 integration = PosthogCeleryIntegration (client = mock_client )
209- fake_celery = ModuleType ("celery" )
210- fake_celery .signals = fake_signals
211- fake_celery .__version__ = "5.0.0"
212175
213176 with (
214177 patch .dict ("sys.modules" , {"celery" : fake_celery }),
@@ -223,23 +186,9 @@ def test_reinstrument_after_shutdown_allows_shutdown_again(self):
223186 self .assertEqual (mock_client .flush .call_count , 2 )
224187 self .assertEqual (mock_register .call_count , 2 )
225188 self .assertEqual (mock_unregister .call_count , 2 )
226- for sig in [
227- "task_prerun" ,
228- "task_success" ,
229- "task_failure" ,
230- "task_retry" ,
231- "before_task_publish" ,
232- "after_task_publish" ,
233- "worker_process_shutdown" ,
234- ]:
235- self .assertEqual (
236- len (getattr (fake_signals , sig ).connected ), 2 , f"{ sig } not reconnected"
237- )
238- self .assertEqual (
239- len (getattr (fake_signals , sig ).disconnected ),
240- 2 ,
241- f"{ sig } not disconnected twice" ,
242- )
189+ self ._assert_signal_counts (
190+ fake_signals , expected_connected = 2 , expected_disconnected = 2
191+ )
243192
244193 def test_worker_process_shutdown_hook_calls_shutdown (self ):
245194 integration = PosthogCeleryIntegration (client = Mock ())
0 commit comments