Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions openfeature/_event_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import threading
import typing
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger

from openfeature.event import (
EventDetails,
Expand All @@ -16,6 +18,8 @@
from openfeature.client import OpenFeatureClient


_logger = getLogger(__name__)

_global_lock = threading.RLock()
_global_handlers: dict[ProviderEvent, list[EventHandler]] = defaultdict(list)

Expand All @@ -24,19 +28,39 @@
defaultdict(lambda: defaultdict(list))
)

_executor_lock = threading.RLock()
_handler_executor = ThreadPoolExecutor(thread_name_prefix="openfeature-event-handler")


def _run_handler(handler: EventHandler, details: EventDetails) -> None:
try:
handler(details)
except Exception:
_logger.exception("OpenFeature event handler raised an exception")


def _submit_handler(handler: EventHandler, details: EventDetails) -> None:
with _executor_lock:
_handler_executor.submit(_run_handler, handler, details)


def _run_handlers(handlers: list[EventHandler], details: EventDetails) -> None:
for handler in handlers:
_submit_handler(handler, details)


def run_client_handlers(
client: OpenFeatureClient, event: ProviderEvent, details: EventDetails
) -> None:
with _client_lock:
for handler in _client_handlers[client][event]:
handler(details)
handlers = list(_client_handlers[client][event])
_run_handlers(handlers, details)


def run_global_handlers(event: ProviderEvent, details: EventDetails) -> None:
with _global_lock:
for handler in _global_handlers[event]:
handler(details)
handlers = list(_global_handlers[event])
_run_handlers(handlers, details)


def add_client_handler(
Expand Down Expand Up @@ -83,9 +107,17 @@ def run_handlers_for_provider(
run_global_handlers(event, details)
# run the handlers for clients associated to this provider
with _client_lock:
for client in _client_handlers:
if client.provider == provider:
run_client_handlers(client, event, details)
client_handlers_snapshot = [
(client, list(event_handlers[event]))
for client, event_handlers in _client_handlers.items()
]
handlers = [
handler
for client, event_list in client_handlers_snapshot
if client.provider == provider
for handler in event_list
]
_run_handlers(handlers, details)
Comment on lines 109 to +120
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Accessing client.provider (which queries the provider registry and may acquire registry-level locks) while holding _client_lock introduces a risk of lock inversion and potential deadlocks if another thread holding the registry lock attempts to acquire _client_lock. To prevent this, copy the client handlers list under _client_lock first, and then perform the provider filtering outside of the lock.

    with _client_lock:
        client_handlers_snapshot = [
            (client, list(event_handlers[event]))
            for client, event_handlers in _client_handlers.items()
        ]
    handlers = [
        handler
        for client, event_list in client_handlers_snapshot
        if client.provider == provider
        for handler in event_list
    ]
    _run_handlers(handlers, details)



def _run_immediate_handler(
Expand All @@ -98,11 +130,21 @@ def _run_immediate_handler(
ProviderStatus.STALE: ProviderEvent.PROVIDER_STALE,
}
if event == status_to_event.get(client.get_provider_status()):
handler(EventDetails(provider_name=client.provider.get_metadata().name))
_submit_handler(
handler,
EventDetails(provider_name=client.provider.get_metadata().name),
)


def clear() -> None:
global _handler_executor
with _global_lock:
_global_handlers.clear()
with _client_lock:
_client_handlers.clear()
with _executor_lock:
old_executor = _handler_executor
_handler_executor = ThreadPoolExecutor(
thread_name_prefix="openfeature-event-handler"
)
Comment on lines +145 to +149
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling _handler_executor.shutdown(wait=True) while holding _executor_lock blocks any other thread trying to submit tasks or interact with the executor until all currently running tasks complete. To avoid blocking other threads unnecessarily, swap the executor reference inside the lock and call shutdown on the old executor outside of the lock.

Suggested change
with _executor_lock:
_handler_executor.shutdown(wait=True, cancel_futures=False)
_handler_executor = ThreadPoolExecutor(
thread_name_prefix="openfeature-event-handler"
)
with _executor_lock:
old_executor = _handler_executor
_handler_executor = ThreadPoolExecutor(
thread_name_prefix="openfeature-event-handler"
)
old_executor.shutdown(wait=True, cancel_futures=False)

old_executor.shutdown(wait=True, cancel_futures=False)
81 changes: 73 additions & 8 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from unittest.mock import MagicMock
import time
from threading import Event
from unittest.mock import MagicMock, call

import pytest

Expand Down Expand Up @@ -30,6 +32,19 @@
)


def _wait_for_call(mock: MagicMock, *args):
deadline = time.monotonic() + 1
expected_call = call(*args)
while time.monotonic() < deadline:
if mock.call_count == 1 and (not args or mock.call_args == expected_call):
return
time.sleep(0.01)
if args:
mock.assert_called_once_with(*args)
else:
mock.assert_called_once()


def test_should_not_raise_exception_with_noop_client():
# Given
# No provider has been set
Expand Down Expand Up @@ -293,10 +308,60 @@ def test_provider_events():

# Then
# NOTE: provider_ready is called immediately after adding the handler
spy.provider_ready.assert_called_once()
spy.provider_configuration_changed.assert_called_once_with(details)
spy.provider_error.assert_called_once_with(details)
spy.provider_stale.assert_called_once_with(details)
_wait_for_call(spy.provider_ready)
_wait_for_call(spy.provider_configuration_changed, details)
_wait_for_call(spy.provider_error, details)
_wait_for_call(spy.provider_stale, details)


def test_event_handler_error_does_not_stop_other_handlers():
# Given
provider = NoOpProvider()
set_provider(provider)
called = set()
second_handler_called = Event()

def raising_handler(details):
called.add("raising")
raise RuntimeError("boom")

def second_handler(details):
called.add("second")
second_handler_called.set()

add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, raising_handler)
add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, second_handler)

# When
provider.emit_provider_configuration_changed(ProviderEventDetails())

# Then
assert second_handler_called.wait(timeout=1)
assert called == {"raising", "second"}


def test_event_handlers_do_not_block_event_emitter():
# Given
provider = NoOpProvider()
set_provider(provider)
handler_started = Event()
release_handler = Event()

def slow_handler(details):
handler_started.set()
release_handler.wait(timeout=1)

add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, slow_handler)

# When
start = time.monotonic()
provider.emit_provider_configuration_changed(ProviderEventDetails())
elapsed = time.monotonic() - start

# Then
assert elapsed < 0.1
assert handler_started.wait(timeout=1)
release_handler.set()


def test_add_remove_event_handler():
Expand Down Expand Up @@ -333,7 +398,7 @@ def test_handlers_attached_to_provider_already_in_associated_state_should_run_im
add_handler(ProviderEvent.PROVIDER_READY, spy.provider_ready)

# Then
spy.provider_ready.assert_called_once()
_wait_for_call(spy.provider_ready)


def test_provider_ready_handlers_run_if_provider_initialize_function_terminates_normally():
Expand All @@ -348,7 +413,7 @@ def test_provider_ready_handlers_run_if_provider_initialize_function_terminates_
set_provider(provider)

# Then
spy.provider_ready.assert_called_once()
_wait_for_call(spy.provider_ready)


def test_provider_error_handlers_run_if_provider_initialize_function_terminates_abnormally():
Expand All @@ -363,7 +428,7 @@ def test_provider_error_handlers_run_if_provider_initialize_function_terminates_
set_provider(provider)

# Then
spy.provider_error.assert_called_once()
_wait_for_call(spy.provider_error)


def test_provider_status_is_updated_after_provider_emits_event():
Expand Down
30 changes: 22 additions & 8 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import types
import uuid
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock
from unittest.mock import MagicMock, call

import pytest

Expand All @@ -29,6 +29,19 @@
from openfeature.transaction_context import ContextVarsTransactionContextPropagator


def _wait_for_call(mock: MagicMock, *args):
deadline = time.monotonic() + 1
expected_call = call(*args)
while time.monotonic() < deadline:
if mock.call_count == 1 and (not args or mock.call_args == expected_call):
return
time.sleep(0.01)
if args:
mock.assert_called_once_with(*args)
else:
mock.assert_called_once()


@pytest.mark.parametrize(
"flag_type, default_value, get_method",
(
Expand Down Expand Up @@ -467,10 +480,10 @@ def emit_all_events(provider):

# Then
# NOTE: provider_ready is called immediately after adding the handler
spy.provider_ready.assert_called_once()
spy.provider_configuration_changed.assert_called_once_with(details)
spy.provider_error.assert_called_once_with(details)
spy.provider_stale.assert_called_once_with(details)
_wait_for_call(spy.provider_ready)
_wait_for_call(spy.provider_configuration_changed, details)
_wait_for_call(spy.provider_error, details)
_wait_for_call(spy.provider_stale, details)


def test_add_remove_event_handler():
Expand Down Expand Up @@ -525,7 +538,7 @@ def test_provider_event_late_binding():
other_provider.emit_provider_configuration_changed(other_provider_details)

# Then
spy.provider_configuration_changed.assert_called_once_with(details)
_wait_for_call(spy.provider_configuration_changed, details)


# Requirement 5.1.4, Requirement 5.1.5
Expand All @@ -545,14 +558,15 @@ def test_provider_event_handler_exception():
)

# Then
spy.provider_error.assert_called_once_with(
_wait_for_call(
spy.provider_error,
EventDetails(
flags_changed=None,
message="some_error",
error_code=ErrorCode.GENERAL,
metadata={},
provider_name="No-op Provider",
)
),
)


Expand Down