diff --git a/hazelcast/asyncio/__init__.py b/hazelcast/asyncio/__init__.py index 98bb081f28..0cdefdc6e5 100644 --- a/hazelcast/asyncio/__init__.py +++ b/hazelcast/asyncio/__init__.py @@ -15,6 +15,7 @@ "ReliableTopic", "ReplicatedMap", "Ringbuffer", + "Semaphore", "Set", "VectorCollection", ] @@ -35,3 +36,4 @@ from hazelcast.internal.asyncio_proxy.atomic_long import AtomicLong from hazelcast.internal.asyncio_proxy.atomic_reference import AtomicReference from hazelcast.internal.asyncio_proxy.countdown_latch import CountDownLatch +from hazelcast.internal.asyncio_proxy.semaphore import Semaphore diff --git a/hazelcast/internal/asyncio_client.py b/hazelcast/internal/asyncio_client.py index c266a42a59..e1fdd0cfce 100644 --- a/hazelcast/internal/asyncio_client.py +++ b/hazelcast/internal/asyncio_client.py @@ -11,6 +11,7 @@ from hazelcast.discovery import HazelcastCloudAddressProvider from hazelcast.errors import IllegalStateError, InvalidConfigurationError from hazelcast.internal.asyncio_invocation import InvocationService, Invocation +from hazelcast.internal.asyncio_proxy.cp import ProxySessionManager from hazelcast.internal.asyncio_proxy.cp_manager import CPSubsystem from hazelcast.internal.asyncio_proxy.pn_counter import PNCounter from hazelcast.internal.asyncio_proxy.vector_collection import VectorCollection @@ -209,6 +210,7 @@ def __init__(self, config: Config | None = None, **kwargs): ) self._proxy_manager = ProxyManager(self._context) self._cp_subsystem = CPSubsystem(self._context) + self._proxy_session_manager = ProxySessionManager(self._context) self._lock_reference_id_generator = AtomicInteger(1) self._statistics = Statistics( self, @@ -248,6 +250,7 @@ def _init_context(self): self._near_cache_manager, self._lock_reference_id_generator, self._name, + self._proxy_session_manager, self._reactor, self._compact_schema_service, ) @@ -493,6 +496,7 @@ async def shutdown(self) -> None: if self._internal_lifecycle_service.running: 
self._internal_lifecycle_service.fire_lifecycle_event(LifecycleState.SHUTTING_DOWN) self._internal_lifecycle_service.shutdown() + await self._proxy_session_manager.shutdown() self._near_cache_manager.destroy_near_caches() await self._connection_manager.shutdown() self._invocation_service.shutdown() @@ -590,6 +594,7 @@ def __init__(self): self.near_cache_manager = None self.lock_reference_id_generator = None self.name = None + self.proxy_session_manager = None self.reactor = None self.compact_schema_service = None @@ -607,6 +612,7 @@ def init_context( near_cache_manager, lock_reference_id_generator, name, + proxy_session_manager, reactor, compact_schema_service, ): @@ -622,5 +628,6 @@ def init_context( self.near_cache_manager = near_cache_manager self.lock_reference_id_generator = lock_reference_id_generator self.name = name + self.proxy_session_manager = proxy_session_manager self.reactor = reactor self.compact_schema_service = compact_schema_service diff --git a/hazelcast/internal/asyncio_proxy/cp.py b/hazelcast/internal/asyncio_proxy/cp.py index 333500a641..0138ebcbe6 100644 --- a/hazelcast/internal/asyncio_proxy/cp.py +++ b/hazelcast/internal/asyncio_proxy/cp.py @@ -1,5 +1,21 @@ +import abc +import asyncio + +from hazelcast.cp import _SessionState +from hazelcast.errors import ( + HazelcastClientNotActiveError, + SessionExpiredError, + CPGroupDestroyedError, +) from hazelcast.internal.asyncio_invocation import Invocation -from hazelcast.protocol.codec import cp_group_destroy_cp_object_codec +from hazelcast.protocol import RaftGroupId +from hazelcast.protocol.codec import ( + cp_group_destroy_cp_object_codec, + cp_session_generate_thread_id_codec, + cp_session_close_session_codec, + cp_session_create_session_codec, + cp_session_heartbeat_session_codec, +) def _no_op_response_handler(_): @@ -32,3 +48,182 @@ def _invoke(self, request, response_handler=_no_op_response_handler): async def _ainvoke(self, request, response_handler=_no_op_response_handler): fut = 
# Value reported by ``ProxySessionManager.get_session_id`` when no session is
# tracked for the requested CP group.
_NO_SESSION_ID = -1


class ProxySessionManager:
    """Client-side registry of CP sessions and generated thread ids.

    One session state is kept per CP group. Sessions are created on demand,
    a background task sends heartbeats for the sessions that are in use, and
    ``shutdown()`` closes every session that is still tracked.
    """

    def __init__(self, context):
        self._context = context
        self._mutexes = dict()  # RaftGroupId to asyncio.Lock
        self._sessions = dict()  # RaftGroupId to SessionState
        self._thread_ids = dict()  # (RaftGroupId, thread_id) to global thread id
        self._heartbeat_task = None
        self._shutdown = False
        self._lock = asyncio.Lock()

    def get_session_id(self, group_id):
        """Return the id of the session tracked for ``group_id``.

        Returns:
            The session id, or ``_NO_SESSION_ID`` when no session is tracked
            for the group.
        """
        session = self._sessions.get(group_id, None)
        if session is None:
            return _NO_SESSION_ID
        return session.id

    async def acquire_session(self, group_id, count):
        """Get (or create) the session of ``group_id`` and acquire it.

        Args:
            group_id: The CP group the session belongs to.
            count: Passed to the session state's ``acquire``.

        Returns:
            The result of the session state's ``acquire(count)``; callers in
            this module use it as the session id.

        Raises:
            HazelcastClientNotActiveError: If the manager is shut down.
        """
        state = await self._get_or_create_session(group_id)
        return state.acquire(count)

    def release_session(self, group_id, session_id, count):
        """Release ``count`` acquisitions if ``session_id`` is still tracked."""
        session = self._sessions.get(group_id, None)
        if session and session.id == session_id:
            session.release(count)

    def invalidate_session(self, group_id, session_id):
        """Forget the session of ``group_id`` if it still is ``session_id``."""
        session = self._sessions.get(group_id, None)
        if session and session.id == session_id:
            self._sessions.pop(group_id, None)

    async def get_or_create_unique_thread_id(self, group_id):
        """Return the cached global thread id for ``group_id``.

        A new id is requested from the CP group and cached on first use.

        Raises:
            HazelcastClientNotActiveError: If the manager is shut down.
        """
        async with self._lock:
            if self._shutdown:
                raise HazelcastClientNotActiveError("Session manager is already shut down!")

            # TODO: replace 0 with the lock context once implemented
            key = (group_id, 0)
            global_thread_id = self._thread_ids.get(key)
            # Compare against None explicitly: a cached id of 0 is falsy and
            # would otherwise trigger a needless new request.
            if global_thread_id is not None:
                return global_thread_id

            tid = await self._request_generate_thread_id(group_id)
            return self._thread_ids.setdefault(key, tid)

    async def shutdown(self):
        """Stop heartbeating and close every tracked session.

        Close requests are sent concurrently and on a best-effort basis: a
        failing request must neither prevent the local state from being
        cleared nor abort the caller's own shutdown sequence. Calling this
        method again is a no-op.
        """
        async with self._lock:
            if self._shutdown:
                return None

            self._shutdown = True
            if self._heartbeat_task:
                self._heartbeat_task.cancel()

            close_requests = [
                self._request_close_session(session.group_id, session.id)
                for session in list(self._sessions.values())
            ]
            try:
                # return_exceptions=True: failures are collected instead of
                # raised. NOTE(review): presumably sessions that could not be
                # closed expire on the CP group after their TTL - confirm.
                await asyncio.gather(*close_requests, return_exceptions=True)
            finally:
                self._sessions.clear()
                self._mutexes.clear()
                self._thread_ids.clear()

    async def _request_generate_thread_id(self, group_id):
        """Ask the CP group for a new thread id and return the decoded response."""
        codec = cp_session_generate_thread_id_codec
        request = codec.encode_request(group_id)
        invocation = Invocation(request, response_handler=codec.decode_response)
        return await self._context.invocation_service.ainvoke(invocation)

    async def _request_close_session(self, group_id, session_id):
        """Send a close-session request and return the decoded response."""
        codec = cp_session_close_session_codec
        request = codec.encode_request(group_id, session_id)
        invocation = Invocation(request, response_handler=codec.decode_response)
        return await self._context.invocation_service.ainvoke(invocation)

    async def _get_or_create_session(self, group_id):
        """Return a valid session state for ``group_id``, creating one if needed.

        Raises:
            HazelcastClientNotActiveError: If the manager is shut down.
        """
        async with self._lock:
            if self._shutdown:
                raise HazelcastClientNotActiveError("Session manager is already shut down!")

            session = self._sessions.get(group_id, None)
            if session is None or not session.is_valid():
                async with self._mutex(group_id):
                    # Re-check after taking the per-group mutex before
                    # creating a new session.
                    session = self._sessions.get(group_id)
                    if session is None or not session.is_valid():
                        return await self._create_new_session(group_id)
            return session

    async def _create_new_session(self, group_id):
        """Request a new session from the CP group and start tracking it."""
        response = await self._request_new_session(group_id)
        return self._do_create_new_session(response, group_id)

    def _do_create_new_session(self, response, group_id):
        """Build and register the session state from a create-session response.

        The response's ``ttl_millis`` and ``heartbeat_millis`` values are
        converted to seconds; the heartbeat task is started if it is not
        running yet.
        """
        session = _SessionState(response["session_id"], group_id, response["ttl_millis"] / 1000.0)
        self._sessions[group_id] = session
        self._start_heartbeat_timer(response["heartbeat_millis"] / 1000.0)
        return session

    async def _request_new_session(self, group_id):
        """Send a create-session request carrying the client name."""
        codec = cp_session_create_session_codec
        request = codec.encode_request(group_id, self._context.name)
        invocation = Invocation(request, response_handler=codec.decode_response)
        return await self._context.invocation_service.ainvoke(invocation)

    def _mutex(self, group_id) -> asyncio.Lock:
        """Return the lock dedicated to ``group_id``, creating it on first use."""
        mutex = self._mutexes.get(group_id, None)
        if mutex is not None:
            return mutex

        mutex = asyncio.Lock()
        current = self._mutexes.setdefault(group_id, mutex)
        return current

    def _start_heartbeat_timer(self, period):
        """Start the self-rescheduling heartbeat task unless one already exists.

        Only the first call has an effect, so the ``period`` (in seconds) of
        that first call is the one used for every later round.
        """
        if self._heartbeat_task is not None:
            return

        async def heartbeat():
            await asyncio.sleep(period)
            if self._shutdown:
                return

            for session in list(self._sessions.values()):
                if session.is_in_use():

                    # ``session=session`` binds the current loop value; a
                    # plain closure would see the last session of the loop.
                    def cb(heartbeat_future: asyncio.Future, session=session):
                        # ``exception()`` raises CancelledError on a cancelled
                        # future; a cancelled heartbeat carries no verdict
                        # about the session, so it is ignored.
                        if heartbeat_future.cancelled():
                            return

                        error = heartbeat_future.exception()
                        if error is None:
                            return

                        if isinstance(error, (SessionExpiredError, CPGroupDestroyedError)):
                            self.invalidate_session(session.group_id, session.id)

                    f = self._request_heartbeat(session.group_id, session.id)
                    f.add_done_callback(cb)

            # Schedule the next round.
            self._heartbeat_task = asyncio.create_task(heartbeat())

        self._heartbeat_task = asyncio.create_task(heartbeat())

    def _request_heartbeat(self, group_id, session_id) -> asyncio.Future:
        """Fire a heartbeat request and return the invocation's future."""
        codec = cp_session_heartbeat_session_codec
        request = codec.encode_request(group_id, session_id)
        invocation = Invocation(request)
        self._context.invocation_service.invoke(invocation)
        return invocation.future
+ """ + return await self._proxy_manager.get_or_create(SEMAPHORE_SERVICE, name) + class CPProxyManager: def __init__(self, context): @@ -116,9 +144,20 @@ async def get_or_create(self, service_name, proxy_name): return AtomicReference(self._context, group_id, service_name, proxy_name, object_name) elif service_name == COUNT_DOWN_LATCH_SERVICE: return CountDownLatch(self._context, group_id, service_name, proxy_name, object_name) + elif service_name == SEMAPHORE_SERVICE: + return await self._create_semaphore(group_id, proxy_name, object_name) raise ValueError("Unknown service name: %s" % service_name) + async def _create_semaphore(self, group_id, proxy_name, object_name): + codec = semaphore_get_semaphore_type_codec + request = codec.encode_request(proxy_name) + invocation = Invocation(request, response_handler=codec.decode_response) + invocation_service = self._context.invocation_service + jdk_compatible = await invocation_service.ainvoke(invocation) + kls = SessionlessSemaphore if jdk_compatible else SessionAwareSemaphore + return kls(self._context, group_id, SEMAPHORE_SERVICE, proxy_name, object_name) + async def _get_group_id(self, proxy_name): codec = cp_group_create_cp_group_codec request = codec.encode_request(proxy_name) diff --git a/hazelcast/internal/asyncio_proxy/semaphore.py b/hazelcast/internal/asyncio_proxy/semaphore.py new file mode 100644 index 0000000000..6d46a41f41 --- /dev/null +++ b/hazelcast/internal/asyncio_proxy/semaphore.py @@ -0,0 +1,539 @@ +import time +import uuid + +from hazelcast.errors import IllegalStateError, SessionExpiredError, WaitKeyCancelledError +from hazelcast.internal.asyncio_proxy.cp import BaseCPProxy, SessionAwareCPProxy, _NO_SESSION_ID +from hazelcast.protocol.codec import ( + semaphore_init_codec, + semaphore_available_permits_codec, + semaphore_acquire_codec, + semaphore_change_codec, + semaphore_drain_codec, + semaphore_release_codec, +) +from hazelcast.util import check_not_negative, check_true, to_millis + +# Since a 
# Since a proxy does not know how many permits will be drained on
# the Raft group, it uses this constant to increment its local session
# acquire count. Then, it adjusts the local session acquire count after
# the drain response is returned.
_DRAIN_SESSION_ACQ_COUNT = 1024


class Semaphore(BaseCPProxy):
    """A linearizable, distributed semaphore.

    Semaphores are often used to restrict the number of callers that can access
    some physical or logical resource.

    Semaphore is a cluster-wide counting semaphore. Conceptually, it maintains
    a set of permits. Each ``acquire()`` blocks if necessary until a permit
    is available, and then takes it. Dually, each ``release()`` adds a
    permit, potentially releasing a blocking acquirer. However, no actual permit
    objects are used; the semaphore just keeps a count of the number available
    and acts accordingly.

    Hazelcast's distributed semaphore implementation guarantees that callers
    invoking any of the ``acquire()`` methods are selected to
    obtain permits in the order of their invocations (first-in-first-out; FIFO).
    Note that FIFO ordering implies the order in which the primary replica of a
    Semaphore receives these acquire requests. Therefore, it is
    possible for one member to invoke ``acquire()`` before another member,
    but its request hits the primary replica after the other member.

    This class also provides convenient ways to work with multiple permits at
    once. Beware of the increased risk of indefinite postponement when using the
    multiple-permit acquire. If permits are released one by one, a caller
    waiting for one permit will acquire it before a caller waiting for multiple
    permits regardless of the call order.

    Correct usage of a semaphore is established by programming convention
    in the application.

    It works on top of the Raft consensus algorithm. It offers linearizability
    during crash failures and network partitions. It is CP with respect to the
    CAP principle. If a network partition occurs, it remains available on at
    most one side of the partition.

    It has 2 variations:

    - The default implementation accessed via ``cp_subsystem`` is session-aware.
      In this one, when a caller makes its very first ``acquire()`` call, it
      starts a new CP session with the underlying CP group. Then, liveliness of
      the caller is tracked via this CP session. When the caller fails, permits
      acquired by this caller are automatically and safely released. However,
      the session-aware version comes with a limitation, that is, a client
      cannot release permits before acquiring them first. In other words, a
      client can release only the permits it has acquired earlier. It means, you
      can acquire a permit from one thread and release it from another thread
      using the same Hazelcast client, but not different instances of Hazelcast
      client. You can use the session-aware CP Semaphore implementation by
      disabling JDK compatibility via ``jdk-compatible`` server-side setting.
      Although the session-aware implementation has a minor difference to the
      JDK Semaphore, we think it is a better fit for distributed environments
      because of its safe auto-cleanup mechanism for acquired permits.
    - The second implementation offered by ``cp_subsystem`` is sessionless. This
      implementation does not perform auto-cleanup of acquired permits on
      failures. Acquired permits are not bound to threads and permits can be
      released without acquiring first. However, you need to handle failed
      permit owners on your own. If a Hazelcast server or a client fails while
      holding some permits, they will not be automatically released. You can
      use the sessionless CP Semaphore implementation by enabling JDK
      compatibility via ``jdk-compatible`` server-side setting.

    There is a subtle difference between the lock and semaphore abstractions.
    A lock can be assigned to at most one endpoint at a time, so we have a total
    order among its holders. However, permits of a semaphore can be assigned to
    multiple endpoints at a time, which implies that we may not have a total
    order among permit holders. In fact, permit holders are partially ordered.
    For this reason, the fencing token approach, which is explained in
    :class:`~hazelcast.proxy.cp.fenced_lock.FencedLock`, does not work for the
    semaphore abstraction. Moreover, each permit is an independent entity.
    Multiple permit acquires and reentrant lock acquires of a single endpoint
    are not equivalent. The only case where a semaphore behaves like a lock is
    the binary case, where the semaphore has only 1 permit. In this case, the
    semaphore works like a non-reentrant lock.

    All of the API methods in the new CP Semaphore implementation offer
    the exactly-once execution semantics for the session-aware version.
    For instance, even if a ``release()`` call is internally retried
    because of a crashed Hazelcast member, the permit is released only once.
    However, this guarantee is not given for the sessionless, a.k.a,
    JDK-compatible CP Semaphore.
    """

    async def init(self, permits: int) -> bool:
        """Tries to initialize this Semaphore instance with the given permit
        count.

        Args:
            permits: The given permit count.

        Returns:
            ``True`` if the initialization succeeds, ``False`` if already
            initialized.

        Raises:
            AssertionError: If the ``permits`` is negative.
        """
        check_not_negative(permits, "Permits must be non-negative")
        codec = semaphore_init_codec
        request = codec.encode_request(self._group_id, self._object_name, permits)
        return await self._ainvoke(request, codec.decode_response)

    async def acquire(self, permits: int = 1) -> None:
        """Acquires the given number of permits if they are available,
        and returns immediately, reducing the number of available permits
        by the given amount.

        If insufficient permits are available then this coroutine does not
        complete until one of the following things happens:

        - Some other caller invokes one of the ``release``
          methods for this semaphore, the current caller is next to be assigned
          permits and the number of available permits satisfies this request,
        - This Semaphore instance is destroyed

        Args:
            permits: Optional number of permits to acquire; defaults to ``1``
                when not specified

        Raises:
            AssertionError: If the ``permits`` is not positive.
        """
        # Implemented by the session-aware and sessionless subclasses.
        raise NotImplementedError("acquire")

    async def available_permits(self) -> int:
        """Returns the number of permits currently available in this
        semaphore.

        This method is typically used for debugging and testing purposes.

        Returns:
            The number of permits available in this semaphore.
        """
        codec = semaphore_available_permits_codec
        request = codec.encode_request(self._group_id, self._object_name)
        return await self._ainvoke(request, codec.decode_response)

    async def drain_permits(self) -> int:
        """Acquires and returns all permits that are available at invocation
        time.

        Returns:
            The number of permits drained.
        """
        # Implemented by the session-aware and sessionless subclasses.
        raise NotImplementedError("drain_permits")

    async def reduce_permits(self, reduction: int) -> None:
        """Reduces the number of available permits by the indicated amount.

        This method differs from ``acquire`` as it does not block until permits
        become available. Similarly, if the caller has acquired some permits,
        they are not released with this call.

        Args:
            reduction: The number of permits to reduce.

        Raises:
            AssertionError: If the ``reduction`` is negative.
        """
        check_not_negative(reduction, "Reduction must be non-negative")
        if reduction == 0:
            # Nothing to change; skip the remote call.
            return None

        return await self._do_change_permits(-reduction)

    async def increase_permits(self, increase: int) -> None:
        """Increases the number of available permits by the indicated amount.

        If there are some callers waiting for permits to become available, they
        will be notified. Moreover, if the caller has acquired some permits,
        they are not released with this call.

        Args:
            increase: The number of permits to increase.

        Raises:
            AssertionError: If ``increase`` is negative.
        """
        check_not_negative(increase, "Increase must be non-negative")
        if increase == 0:
            # Nothing to change; skip the remote call.
            return None

        return await self._do_change_permits(increase)

    async def release(self, permits: int = 1) -> None:
        """Releases the given number of permits and increases the number of
        available permits by that amount.

        If some callers in the cluster are blocked for acquiring permits,
        they will be notified.

        If the underlying Semaphore implementation is non-JDK-compatible
        (configured via ``jdk-compatible`` server-side setting), then a
        client can only release a permit which it has acquired before.
        In other words, a client cannot release a permit without acquiring
        it first.

        Otherwise, which means the underlying implementation is JDK compatible
        (configured via ``jdk-compatible`` server-side setting), there is no
        requirement that a client that releases a permit must have acquired
        that permit by calling one of the ``acquire()`` methods. A client can
        freely release a permit without acquiring it first. In this case,
        correct usage of a semaphore is established by programming convention
        in the application.

        Args:
            permits: Optional number of permits to release; defaults to ``1``
                when not specified.

        Raises:
            AssertionError: If the ``permits`` is not positive.
            IllegalStateError: if the Semaphore is non-JDK-compatible and the
                caller does not have a permit
        """
        # Implemented by the session-aware and sessionless subclasses.
        raise NotImplementedError("release")

    async def try_acquire(self, permits: int = 1, timeout: float = 0) -> bool:
        """Acquires the given number of permits and returns ``True``, if they
        become available during the given waiting time.

        If permits are acquired, the number of available permits in the
        Semaphore instance is also reduced by the given amount.

        If no sufficient permits are available, then this coroutine does not
        complete until one of the following things happens:

        - Permits are released by other callers, the current caller is next to
          be assigned permits and the number of available permits satisfies this
          request
        - The specified waiting time elapses

        Args:
            permits: The number of permits to acquire; defaults to ``1`` when
                not specified.
            timeout: Optional timeout in seconds to wait for the permits; when
                it's not specified the operation will return immediately after
                the acquire attempt.

        Returns:
            ``True`` if all permits were acquired, ``False`` if the waiting
            time elapsed before all permits could be acquired

        Raises:
            AssertionError: If the ``permits`` is not positive.
        """
        # Implemented by the session-aware and sessionless subclasses.
        raise NotImplementedError("try_acquire")

    async def _do_change_permits(self, permits):
        """Apply a signed permit delta; implemented by the subclasses.

        ``reduce_permits`` passes a negative value, ``increase_permits`` a
        positive one.
        """
        raise NotImplementedError("_do_change_permits")
class SessionAwareSemaphore(Semaphore, SessionAwareCPProxy):
    """Semaphore proxy that sends every permit operation under a CP session.

    The session is acquired locally before each request and released again
    when the request did not end up holding permits. Requests that fail with
    ``SessionExpiredError`` invalidate the local session; acquire-style
    operations then retry with a new session.
    """

    async def acquire(self, permits=1):
        """Acquire ``permits`` permits, waiting without a time limit.

        Raises:
            IllegalStateError: If the wait is cancelled on the CP group.
        """
        check_true(permits > 0, "Permits must be positive")
        # TODO: replace 0 with the lock context once implemented:
        current_thread_id = 0
        invocation_uuid = uuid.uuid4()
        await self._do_acquire(current_thread_id, invocation_uuid, permits)

    async def drain_permits(self):
        """Acquire all available permits and return how many were drained."""
        # TODO: replace 0 with the lock context once implemented:
        current_thread_id = 0
        invocation_uuid = uuid.uuid4()
        return await self._do_drain(current_thread_id, invocation_uuid)

    async def release(self, permits=1):
        """Release ``permits`` permits held through the current session.

        Raises:
            IllegalStateError: If there is no local session for the group, or
                the session has expired.
        """
        check_true(permits > 0, "Permits must be positive")
        session_id = self._get_session_id()
        if session_id == _NO_SESSION_ID:
            raise self._new_illegal_state_error()

        # TODO: replace 0 with the lock context once implemented:
        current_thread_id = 0
        invocation_uuid = uuid.uuid4()

        try:
            await self._request_release(session_id, current_thread_id, invocation_uuid, permits)
        except SessionExpiredError as e:
            self._invalidate_session(session_id)
            raise self._new_illegal_state_error(e) from e
        finally:
            # Runs on success and on failure alike; it has no effect once the
            # session has been invalidated above.
            self._release_session(session_id, permits)

    async def try_acquire(self, permits=1, timeout=0):
        """Try to acquire ``permits`` permits within ``timeout`` seconds.

        Negative timeouts are treated as ``0``.

        Returns:
            ``True`` if the permits were acquired, ``False`` otherwise.
        """
        check_true(permits > 0, "Permits must be positive")
        timeout = max(0.0, timeout)
        # TODO: replace 0 with the lock context once implemented:
        current_thread_id = 0
        invocation_uuid = uuid.uuid4()
        return await self._do_try_acquire(current_thread_id, invocation_uuid, permits, timeout)

    async def _do_acquire(self, current_thread_id, invocation_uuid, permits):
        """Send the acquire request, retrying with a new session on expiry."""
        # A loop instead of recursion from inside the ``except`` block keeps
        # the stack and the exception context flat across repeated expiries.
        while True:
            session_id = await self._acquire_session(permits)
            try:
                # -1 is sent as the timeout: wait without a time limit.
                await self._request_acquire(
                    session_id, current_thread_id, invocation_uuid, permits, -1
                )
                return
            except SessionExpiredError:
                self._invalidate_session(session_id)
            except WaitKeyCancelledError as e:
                self._release_session(session_id, permits)
                raise IllegalStateError(
                    'Semaphore("%s") not acquired because the acquire call on the CP '
                    "group is cancelled, possibly because of another indeterminate call "
                    "from the same thread." % self._object_name
                ) from e
            except Exception:
                self._release_session(session_id, permits)
                raise

    async def _do_drain(self, current_thread_id, invocation_uuid):
        """Send the drain request, retrying with a new session on expiry."""
        while True:
            # The drained amount is unknown up front, so the session is first
            # acquired _DRAIN_SESSION_ACQ_COUNT times and corrected afterwards.
            session_id = await self._acquire_session(_DRAIN_SESSION_ACQ_COUNT)
            try:
                count = await self._request_drain(session_id, current_thread_id, invocation_uuid)
                self._release_session(session_id, _DRAIN_SESSION_ACQ_COUNT - count)
                return count
            except SessionExpiredError:
                self._invalidate_session(session_id)
            except Exception:
                self._release_session(session_id, _DRAIN_SESSION_ACQ_COUNT)
                raise

    async def _do_change_permits(self, delta):
        """Apply the signed permit ``delta`` under a session.

        Raises:
            IllegalStateError: If the session has expired.
        """
        # TODO: replace 0 with the lock context once implemented:
        current_thread_id = 0
        invocation_uuid = uuid.uuid4()

        session_id = await self._acquire_session()
        try:
            await self._request_change(session_id, current_thread_id, invocation_uuid, delta)
        except SessionExpiredError as e:
            self._invalidate_session(session_id)
            raise self._new_illegal_state_error(e) from e
        finally:
            self._release_session(session_id)

    async def _do_try_acquire(self, current_thread_id, invocation_uuid, permits, timeout):
        """Send the timed acquire request, retrying on session expiry.

        Retries only use the time that is left of the original ``timeout``.
        """
        # A monotonic clock is used so that wall-clock adjustments cannot
        # stretch or shrink the remaining waiting time.
        deadline = time.monotonic() + timeout
        while True:
            session_id = await self._acquire_session(permits)
            try:
                acquired = await self._request_acquire(
                    session_id, current_thread_id, invocation_uuid, permits, timeout
                )
                if not acquired:
                    self._release_session(session_id, permits)
                return acquired
            except SessionExpiredError:
                self._invalidate_session(session_id)
                timeout = deadline - time.monotonic()
                if timeout <= 0:
                    return False
            except WaitKeyCancelledError:
                self._release_session(session_id, permits)
                return False
            except Exception:
                self._release_session(session_id, permits)
                raise

    def _new_illegal_state_error(self, cause=None):
        """Build the error raised when no valid session is available."""
        return IllegalStateError('Semaphore["%s"] has no valid session!' % self._object_name, cause)

    async def _request_acquire(
        self, session_id, current_thread_id, invocation_uuid, permits, timeout
    ):
        """Send an acquire request; ``timeout`` is in seconds.

        Non-negative timeouts are converted to milliseconds; negative values
        are sent unchanged.
        """
        codec = semaphore_acquire_codec
        if timeout >= 0:
            timeout = to_millis(timeout)

        request = codec.encode_request(
            self._group_id,
            self._object_name,
            session_id,
            current_thread_id,
            invocation_uuid,
            permits,
            timeout,
        )
        return await self._ainvoke(request, codec.decode_response)

    async def _request_drain(self, session_id, current_thread_id, invocation_uuid):
        """Send a drain request and return the decoded response."""
        codec = semaphore_drain_codec
        request = codec.encode_request(
            self._group_id, self._object_name, session_id, current_thread_id, invocation_uuid
        )
        return await self._ainvoke(request, codec.decode_response)

    async def _request_change(self, session_id, current_thread_id, invocation_uuid, delta):
        """Send a change-permits request; the response is not decoded."""
        codec = semaphore_change_codec
        request = codec.encode_request(
            self._group_id, self._object_name, session_id, current_thread_id, invocation_uuid, delta
        )
        return await self._ainvoke(request)

    async def _request_release(self, session_id, current_thread_id, invocation_uuid, permits):
        """Send a release request; the response is not decoded."""
        codec = semaphore_release_codec
        request = codec.encode_request(
            self._group_id,
            self._object_name,
            session_id,
            current_thread_id,
            invocation_uuid,
            permits,
        )
        return await self._ainvoke(request)
class SessionlessSemaphore(Semaphore):
    """Semaphore proxy that sends its requests without a CP session.

    Every request carries ``_NO_SESSION_ID`` together with a thread id
    obtained from the session manager.
    """

    def __init__(self, context, group_id, service_name, proxy_name, object_name):
        super().__init__(context, group_id, service_name, proxy_name, object_name)
        self._session_manager = context.proxy_session_manager

    async def acquire(self, permits=1):
        """Acquire ``permits`` permits, waiting without a time limit.

        Raises:
            IllegalStateError: If the wait is cancelled on the CP group.
        """
        check_true(permits > 0, "Permits must be positive")
        tid = await self._get_thread_id()
        # -1 is sent as the timeout: wait without a time limit.
        await self._do_try_acquire(tid, permits, -1)

    async def drain_permits(self):
        """Acquire all available permits and return how many were drained."""
        tid = await self._get_thread_id()
        return await self._do_drain_permits(tid)

    async def release(self, permits=1):
        """Release ``permits`` permits."""
        check_true(permits > 0, "Permits must be positive")
        invocation_uuid = uuid.uuid4()
        tid = await self._get_thread_id()
        return await self._request_release(tid, invocation_uuid, permits)

    async def try_acquire(self, permits=1, timeout=0):
        """Try to acquire ``permits`` permits within ``timeout`` seconds.

        Negative timeouts are treated as ``0``.

        Returns:
            ``True`` if the permits were acquired, ``False`` otherwise.

        Raises:
            IllegalStateError: If the wait is cancelled on the CP group.
        """
        check_true(permits > 0, "Permits must be positive")
        timeout = max(0.0, timeout)
        tid = await self._get_thread_id()
        return await self._do_try_acquire(tid, permits, timeout)

    async def _do_try_acquire(self, global_thread_id, permits, timeout):
        """Send one acquire request; shared by ``acquire`` and ``try_acquire``.

        A ``WaitKeyCancelledError`` is translated to ``IllegalStateError``
        with the original error kept as its cause.
        """
        invocation_uuid = uuid.uuid4()
        try:
            return await self._request_acquire(global_thread_id, invocation_uuid, permits, timeout)
        except WaitKeyCancelledError as e:
            raise IllegalStateError(
                'Semaphore("%s") not acquired because the acquire call on the '
                "CP group is cancelled, possibly because of another indeterminate "
                "call from the same thread." % self._object_name
            ) from e

    async def _do_drain_permits(self, global_thread_id):
        """Send a drain request and return the decoded response."""
        invocation_uuid = uuid.uuid4()
        codec = semaphore_drain_codec
        request = codec.encode_request(
            self._group_id, self._object_name, _NO_SESSION_ID, global_thread_id, invocation_uuid
        )
        return await self._ainvoke(request, codec.decode_response)

    async def _do_change_permits(self, permits):
        """Apply the signed permit delta ``permits``."""
        invocation_uuid = uuid.uuid4()
        tid = await self._get_thread_id()
        return await self._request_change(tid, invocation_uuid, permits)

    async def _request_acquire(self, global_thread_id, invocation_uuid, permits, timeout):
        """Send an acquire request; ``timeout`` is in seconds.

        Non-negative timeouts are converted to milliseconds; negative values
        are sent unchanged.
        """
        codec = semaphore_acquire_codec
        if timeout >= 0:
            timeout = to_millis(timeout)

        request = codec.encode_request(
            self._group_id,
            self._object_name,
            _NO_SESSION_ID,
            global_thread_id,
            invocation_uuid,
            permits,
            timeout,
        )
        return await self._ainvoke(request, codec.decode_response)

    async def _request_change(self, global_thread_id, invocation_uuid, permits):
        """Send a change-permits request; the response is not decoded."""
        codec = semaphore_change_codec
        request = codec.encode_request(
            self._group_id,
            self._object_name,
            _NO_SESSION_ID,
            global_thread_id,
            invocation_uuid,
            permits,
        )
        return await self._ainvoke(request)

    async def _request_release(self, global_thread_id, invocation_uuid, permits):
        """Send a release request; the response is not decoded."""
        codec = semaphore_release_codec
        request = codec.encode_request(
            self._group_id,
            self._object_name,
            _NO_SESSION_ID,
            global_thread_id,
            invocation_uuid,
            permits,
        )
        return await self._ainvoke(request)

    async def _get_thread_id(self):
        """Return the thread id the session manager holds for this group."""
        return await self._session_manager.get_or_create_unique_thread_id(self._group_id)
+from hazelcast.internal.asyncio_client import HazelcastClient
+from tests.integration.asyncio.base import CPTestCase
+from tests.util import random_string, get_current_timestamp
+
+SEMAPHORE_TYPES = [
+    "sessionless",
+    "sessionaware",
+]
+
+
+@pytest.mark.enterprise
+class SemaphoreTest(CPTestCase):
+    """Integration tests for the asyncio CP semaphore proxy.
+
+    Most tests loop over ``SEMAPHORE_TYPES`` with ``subTest``; the type string
+    is used as the prefix of the semaphore name passed to ``get_semaphore``.
+    """
+
+    async def asyncSetUp(self):
+        await super().asyncSetUp()
+        self.semaphore = None
+
+    async def asyncTearDown(self):
+        """Destroy the semaphore last created via ``get_semaphore``, then tear down."""
+        if self.semaphore:
+            # destroy() is a coroutine on the asyncio proxy (every other call
+            # site awaits it); without ``await`` it is never executed.
+            await self.semaphore.destroy()
+        await super().asyncTearDown()
+
+    async def test_semaphore_in_another_group(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 1)
+                another_semaphore = await self.client.cp_subsystem.get_semaphore(
+                    semaphore._proxy_name + "@another"
+                )
+                self.assertEqual(1, await semaphore.available_permits())
+                self.assertEqual(0, await another_semaphore.available_permits())
+                await semaphore.acquire()
+                self.assertEqual(0, await semaphore.available_permits())
+                self.assertEqual(0, await another_semaphore.available_permits())
+
+    async def test_use_after_destroy(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type)
+                await semaphore.destroy()
+                # the next destroy call should be ignored
+                await semaphore.destroy()
+
+                try:
+                    await semaphore.init(1)
+                except DistributedObjectDestroyedError:
+                    pass
+                else:
+                    self.fail("expected DistributedObjectDestroyedError to be raised")
+
+                # A proxy fetched again under the same name must fail the same way.
+                semaphore2 = await self.client.cp_subsystem.get_semaphore(semaphore._proxy_name)
+
+                try:
+                    await semaphore2.init(1)
+                except DistributedObjectDestroyedError:
+                    pass
+                else:
+                    self.fail("expected DistributedObjectDestroyedError to be raised")
+
+    async def test_session_aware_semaphore_after_client_shutdown(self):
+        semaphore = await self.get_semaphore("sessionaware", 1)
+        another_client = await HazelcastClient.create_and_start(cluster_name=self.cluster.id)
+        another_semaphore = await another_client.cp_subsystem.get_semaphore(semaphore._proxy_name)
+        await another_semaphore.acquire(1)
+        self.assertEqual(0, await another_semaphore.available_permits())
+        self.assertEqual(0, await semaphore.available_permits())
+        # Shutting down the holder is the action under test, not cleanup: the
+        # permit it held is expected to become available again.
+        await another_client.shutdown()
+
+        async def assertion():
+            self.assertEqual(1, await semaphore.available_permits())
+
+        await self.assertTrueEventually(assertion)
+
+    async def test_init(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type)
+                self.assertEqual(0, await semaphore.available_permits())
+                self.assertTrue(await semaphore.init(10))
+                self.assertEqual(10, await semaphore.available_permits())
+
+    async def test_init_when_already_initialized(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type)
+                self.assertTrue(await semaphore.init(5))
+                self.assertFalse(await semaphore.init(7))
+                self.assertEqual(5, await semaphore.available_permits())
+
+    async def test_acquire(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 42)
+                self.assertIsNone(await semaphore.acquire(2))
+                self.assertEqual(40, await semaphore.available_permits())
+                self.assertIsNone(await semaphore.acquire())
+                self.assertEqual(39, await semaphore.available_permits())
+
+    async def test_acquire_when_not_enough_permits(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 5)
+                f = asyncio.create_task(semaphore.acquire(10))
+                self.assertFalse(f.done())
+                await asyncio.sleep(2)
+                # Still pending after the sleep: acquire() has not completed.
+                self.assertFalse(f.done())
+                await semaphore.destroy()
+
+                try:
+                    await f
+                except DistributedObjectDestroyedError:
+                    pass
+                else:
+                    self.fail("expected DistributedObjectDestroyedError to be raised")
+
+    # TODO: Implement test_acquire_blocks_until_someone_releases after lock context is implemented
+    # TODO: test_acquire_blocks_until_semaphore_is_destroyed after lock context is implemented
+
+    async def test_available_permits(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type)
+                self.assertEqual(0, await semaphore.available_permits())
+                await semaphore.init(5)
+                self.assertEqual(5, await semaphore.available_permits())
+                await semaphore.acquire(3)
+                self.assertEqual(2, await semaphore.available_permits())
+
+    async def test_drain_permits(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 20)
+                await semaphore.acquire(5)
+                self.assertEqual(15, await semaphore.drain_permits())
+                self.assertEqual(0, await semaphore.available_permits())
+
+    async def test_drain_permits_when_no_permits(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 0)
+                self.assertEqual(0, await semaphore.drain_permits())
+
+    async def test_reduce_permits(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 10)
+                self.assertIsNone(await semaphore.reduce_permits(5))
+                self.assertEqual(5, await semaphore.available_permits())
+                self.assertIsNone(await semaphore.reduce_permits(0))
+                self.assertEqual(5, await semaphore.available_permits())
+
+    async def test_reduce_permits_on_negative_permits_counter_sessionless(self):
+        semaphore = await self.get_semaphore("sessionless", 10)
+        await semaphore.reduce_permits(15)
+        self.assertEqual(-5, await semaphore.available_permits())
+        await semaphore.release(10)
+        self.assertEqual(5, await semaphore.available_permits())
+
+    async def test_reduce_permits_on_negative_permits_counter_juc_sessionless(self):
+        semaphore = await self.get_semaphore("sessionless", 0)
+        await semaphore.reduce_permits(100)
+        await semaphore.release(10)
+        self.assertEqual(-90, await semaphore.available_permits())
+        self.assertEqual(-90, await semaphore.drain_permits())
+        await semaphore.release(10)
+        self.assertEqual(10, await semaphore.available_permits())
+        self.assertEqual(10, await semaphore.drain_permits())
+
+    async def test_reduce_permits_on_negative_permits_counter_session_aware(self):
+        semaphore = await self.get_semaphore("sessionaware", 10)
+        await semaphore.reduce_permits(15)
+        self.assertEqual(-5, await semaphore.available_permits())
+
+    async def test_reduce_permits_on_negative_permits_counter_juc_session_aware(self):
+        semaphore = await self.get_semaphore("sessionaware", 0)
+        await semaphore.reduce_permits(100)
+        self.assertEqual(-100, await semaphore.available_permits())
+        self.assertEqual(-100, await semaphore.drain_permits())
+
+    async def test_increase_permits(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 10)
+                self.assertEqual(10, await semaphore.available_permits())
+                self.assertIsNone(await semaphore.increase_permits(100))
+                self.assertEqual(110, await semaphore.available_permits())
+                self.assertIsNone(await semaphore.increase_permits(0))
+                self.assertEqual(110, await semaphore.available_permits())
+
+    async def test_release(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 2)
+                await semaphore.acquire(2)
+                self.assertIsNone(await semaphore.release(2))
+                self.assertEqual(2, await semaphore.available_permits())
+
+    async def test_release_when_acquired_by_another_client_sessionless(self):
+        semaphore = await self.get_semaphore("sessionless")
+        another_client = await HazelcastClient.create_and_start(cluster_name=self.cluster.id)
+        another_semaphore = await another_client.cp_subsystem.get_semaphore(semaphore._proxy_name)
+        self.assertTrue(await another_semaphore.init(1))
+        await another_semaphore.acquire()
+
+        try:
+            await semaphore.release(1)
+            self.assertEqual(1, await semaphore.available_permits())
+        finally:
+            # Always stop the extra client, even if the assertions fail.
+            await another_client.shutdown()
+
+    async def test_release_when_not_acquired_session_aware(self):
+        semaphore = await self.get_semaphore("sessionaware", 3)
+        await semaphore.acquire(1)
+
+        try:
+            await semaphore.release(2)
+        except IllegalStateError:
+            pass
+        else:
+            self.fail("expected IllegalStateError to be raised")
+
+    async def test_release_when_there_is_no_session_session_aware(self):
+        semaphore = await self.get_semaphore("sessionaware", 3)
+
+        try:
+            await semaphore.release()
+        except IllegalStateError:
+            pass
+        else:
+            self.fail("expected IllegalStateError to be raised")
+
+    async def test_try_acquire(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 5)
+                self.assertTrue(await semaphore.try_acquire())
+                self.assertEqual(4, await semaphore.available_permits())
+
+    async def test_try_acquire_with_given_permits(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 5)
+                self.assertTrue(await semaphore.try_acquire(3))
+                self.assertEqual(2, await semaphore.available_permits())
+
+    async def test_try_acquire_when_not_enough_permits(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 1)
+                self.assertFalse(await semaphore.try_acquire(2))
+                self.assertEqual(1, await semaphore.available_permits())
+
+    async def test_try_acquire_when_not_enough_permits_with_timeout(self):
+        for semaphore_type in SEMAPHORE_TYPES:
+            with self.subTest(semaphore_type, semaphore_type=semaphore_type):
+                semaphore = await self.get_semaphore(semaphore_type, 1)
+                start = get_current_timestamp()
+                self.assertFalse(await semaphore.try_acquire(2, 1))
+                self.assertGreaterEqual(get_current_timestamp() - start, 1)
+                self.assertEqual(1, await semaphore.available_permits())
+
+    async def get_semaphore(self, semaphore_type, initialize_with=None):
+        """Create a semaphore named ``semaphore_type`` plus a random suffix.
+
+        When ``initialize_with`` is not ``None`` the semaphore is initialised
+        with that value. The proxy is stored on ``self.semaphore`` (read by
+        ``asyncTearDown``) and returned.
+        """
+        semaphore = await self.client.cp_subsystem.get_semaphore(semaphore_type + random_string())
+        if initialize_with is not None:
+            await semaphore.init(initialize_with)
+        self.semaphore = semaphore
+        return semaphore
diff --git a/tests/integration/backward_compatible/proxy/cp/semaphore_test.py b/tests/integration/backward_compatible/proxy/cp/semaphore_test.py
index 0c59ea633d..cb48cf2afc 100644
--- a/tests/integration/backward_compatible/proxy/cp/semaphore_test.py
+++ b/tests/integration/backward_compatible/proxy/cp/semaphore_test.py
@@ -38,7 +38,7 @@ def test_semaphore_in_another_group(self, semaphore_type):
         self.assertEqual(0, another_semaphore.available_permits())
         semaphore.acquire()
         self.assertEqual(0, semaphore.available_permits())
-        self.assertEqual(0, semaphore.available_permits())
+        self.assertEqual(0, another_semaphore.available_permits())
 
     @parameterized.expand(SEMAPHORE_TYPES)
     def test_use_after_destroy(self, semaphore_type):
@@ -258,7 +258,7 @@ def test_release_when_there_is_no_session_session_aware(self):
             semaphore.release()
 
     @parameterized.expand(SEMAPHORE_TYPES)
-    def test_test_try_acquire(self, semaphore_type):
+    def test_try_acquire(self, semaphore_type):
         semaphore = self.get_semaphore(semaphore_type, 5)
         self.assertTrue(semaphore.try_acquire())
         self.assertEqual(4, semaphore.available_permits())