diff --git a/hazelcast/asyncio/__init__.py b/hazelcast/asyncio/__init__.py index cc94195cc0..98bb081f28 100644 --- a/hazelcast/asyncio/__init__.py +++ b/hazelcast/asyncio/__init__.py @@ -2,6 +2,7 @@ "AtomicLong", "AtomicReference", "CPSubsystem", + "CountDownLatch", "EntryEventCallable", "Executor", "HazelcastClient", @@ -33,3 +34,4 @@ from hazelcast.internal.asyncio_proxy.cp_manager import CPSubsystem from hazelcast.internal.asyncio_proxy.atomic_long import AtomicLong from hazelcast.internal.asyncio_proxy.atomic_reference import AtomicReference +from hazelcast.internal.asyncio_proxy.countdown_latch import CountDownLatch diff --git a/hazelcast/internal/asyncio_proxy/countdown_latch.py b/hazelcast/internal/asyncio_proxy/countdown_latch.py new file mode 100644 index 0000000000..28a7956adf --- /dev/null +++ b/hazelcast/internal/asyncio_proxy/countdown_latch.py @@ -0,0 +1,143 @@ +import uuid + +from hazelcast.errors import OperationTimeoutError +from hazelcast.internal.asyncio_proxy.cp import BaseCPProxy +from hazelcast.protocol.codec import ( + count_down_latch_await_codec, + count_down_latch_get_round_codec, + count_down_latch_count_down_codec, + count_down_latch_get_count_codec, + count_down_latch_try_set_count_codec, +) +from hazelcast.util import check_is_number, to_millis, check_is_int, check_true + + +class CountDownLatch(BaseCPProxy): + """A distributed, concurrent countdown latch data structure. + + CountDownLatch is a cluster-wide synchronization aid + that allows one or more callers to wait until a set of operations being + performed in other callers completes. + + CountDownLatch count can be reset using ``try_set_count()`` method after + a countdown has finished but not during an active count. This allows + the same latch instance to be reused. + + There is no ``await_latch()`` method to wait indefinitely since this is + undesirable in a distributed application: for example, a cluster can split + or the master and replicas could all terminate. In most cases, it is best + to configure an explicit timeout, so you have the ability to deal with + these situations. + + All the API methods in the CountDownLatch offer the exactly-once + execution semantics. For instance, even if a ``count_down()`` call is + internally retried because of crashed Hazelcast member, the counter + value is decremented only once. + """ + + async def await_latch(self, timeout: float) -> bool: + """Causes the current thread to wait until the latch has counted down to + zero, or an exception is thrown, or the specified waiting time elapses. + + If the current count is zero then this method returns ``True``. + + If the current count is greater than zero, then the current + thread becomes disabled for thread scheduling purposes and lies + dormant until one of the following things happen: + + - The count reaches zero due to invocations of the ``count_down()`` + method + - This CountDownLatch instance is destroyed + - The countdown owner becomes disconnected + - The specified waiting time elapses + + If the count reaches zero, then the method returns with the + value ``True``. + + If the specified waiting time elapses then the value ``False`` + is returned. If the time is less than or equal to zero, the method + will not wait at all. + + Args: + timeout: The maximum time to wait in seconds + + Returns: + ``True`` if the count reached zero, ``False`` if the waiting time + elapsed before the count reached zero + Raises: + IllegalStateError: If the Hazelcast instance was shut down while + waiting. + """ + check_is_number(timeout) + timeout = max(0.0, timeout) + invocation_uuid = uuid.uuid4() + codec = count_down_latch_await_codec + request = codec.encode_request( + self._group_id, self._object_name, invocation_uuid, to_millis(timeout) + ) + return await self._ainvoke(request, codec.decode_response) + + async def count_down(self) -> None: + """Decrements the count of the latch, releasing all waiting threads if + the count reaches zero. + + If the current count is greater than zero, then it is decremented. + If the new count is zero: + + - All waiting threads are re-enabled for thread scheduling purposes + - Countdown owner is set to ``None``. + + If the current count equals zero, then nothing happens. + """ + invocation_uuid = uuid.uuid4() + res = await self._get_round() + return await self._do_count_down(res, invocation_uuid) + + async def get_count(self) -> int: + """Returns the current count. + + Returns: + The current count. + """ + codec = count_down_latch_get_count_codec + request = codec.encode_request(self._group_id, self._object_name) + return await self._ainvoke(request, codec.decode_response) + + async def try_set_count(self, count: int) -> bool: + """Sets the count to the given value if the current count is zero. + + If count is not zero, then this method does nothing and returns + ``False``. + + Args: + count: The number of times ``count_down()`` must be invoked before + callers can pass through ``await_latch()``. + + Returns: + ``True`` if the new count was set, ``False`` if the current count + is not zero. + """ + check_is_int(count) + check_true(count > 0, "Count must be positive") + codec = count_down_latch_try_set_count_codec + request = codec.encode_request(self._group_id, self._object_name, count) + return await self._ainvoke(request, codec.decode_response) + + async def _do_count_down(self, expected_round, invocation_uuid): + try: + return await self._request_count_down(expected_round, invocation_uuid) + except OperationTimeoutError: + # we can retry safely because the retry is idempotent + return await self._do_count_down(expected_round, invocation_uuid) + + async def _get_round(self): + codec = count_down_latch_get_round_codec + request = codec.encode_request(self._group_id, self._object_name) + return await self._ainvoke(request, codec.decode_response) + + async def _request_count_down(self, expected_round, invocation_uuid): + codec = count_down_latch_count_down_codec + request = codec.encode_request( + self._group_id, self._object_name, invocation_uuid, expected_round + ) + return await self._ainvoke(request) diff --git a/hazelcast/internal/asyncio_proxy/cp_manager.py b/hazelcast/internal/asyncio_proxy/cp_manager.py index bf89cb1bc2..022427f324 100644 --- a/hazelcast/internal/asyncio_proxy/cp_manager.py +++ b/hazelcast/internal/asyncio_proxy/cp_manager.py @@ -3,10 +3,12 @@ _get_object_name_for_proxy, ATOMIC_LONG_SERVICE, ATOMIC_REFERENCE_SERVICE, + COUNT_DOWN_LATCH_SERVICE, ) from hazelcast.internal.asyncio_invocation import Invocation from hazelcast.internal.asyncio_proxy.atomic_long import AtomicLong from hazelcast.internal.asyncio_proxy.atomic_reference import AtomicReference +from hazelcast.internal.asyncio_proxy.countdown_latch import CountDownLatch from hazelcast.protocol.codec import cp_group_create_cp_group_codec @@ -78,6 +80,26 @@ async def get_atomic_reference(self, name: str) -> AtomicReference: """ return await self._proxy_manager.get_or_create(ATOMIC_REFERENCE_SERVICE, name) + async def get_count_down_latch(self, name: str) -> CountDownLatch: + """Returns the distributed CountDownLatch instance with given name. + + The instance is created on CP Subsystem. + + If no group name is given within the ``name`` argument, then the + CountDownLatch instance will be created on the DEFAULT CP group. + If a group name is given, like + ``.get_count_down_latch("myLatch@group1")``, the given group will be + initialized first, if not initialized already, and then the instance + will be created on this group. + + Args: + name: Name of the CountDownLatch. + + Returns: + The CountDownLatch proxy for the given name. + """ + return await self._proxy_manager.get_or_create(COUNT_DOWN_LATCH_SERVICE, name) + class CPProxyManager: def __init__(self, context): @@ -92,6 +114,8 @@ async def get_or_create(self, service_name, proxy_name): return AtomicLong(self._context, group_id, service_name, proxy_name, object_name) elif service_name == ATOMIC_REFERENCE_SERVICE: return AtomicReference(self._context, group_id, service_name, proxy_name, object_name) + elif service_name == COUNT_DOWN_LATCH_SERVICE: + return CountDownLatch(self._context, group_id, service_name, proxy_name, object_name) raise ValueError("Unknown service name: %s" % service_name) diff --git a/tests/integration/asyncio/proxy/countdown_latch_test.py b/tests/integration/asyncio/proxy/countdown_latch_test.py new file mode 100644 index 0000000000..74ec5f11bb --- /dev/null +++ b/tests/integration/asyncio/proxy/countdown_latch_test.py @@ -0,0 +1,158 @@ +import asyncio +import os + +import pytest + +from hazelcast.errors import DistributedObjectDestroyedError, OperationTimeoutError +from hazelcast.util import AtomicInteger +from tests.integration.asyncio.base import CPTestCase +from tests.util import random_string, get_current_timestamp + +inf = 2**31 - 1 + + +@pytest.mark.enterprise +class CountDownLatchTest(CPTestCase): + async def test_latch_in_another_group(self): + latch = await self.get_latch() + another_latch = await self.client.cp_subsystem.get_count_down_latch( + latch._proxy_name + "@another" + ) + await another_latch.try_set_count(42) + self.assertEqual(42, await another_latch.get_count()) + self.assertNotEqual(42, await latch.get_count()) + + async def test_use_after_destroy(self): + latch = await self.get_latch() + await latch.destroy() + # the next destroy call should be ignored + await latch.destroy() + + try: + await latch.get_count() + except DistributedObjectDestroyedError: + pass + else: + self.fail("expected DistributedObjectDestroyedError to be raised") + + latch2 = await self.client.cp_subsystem.get_count_down_latch(latch._proxy_name) + + try: + await latch2.get_count() + except DistributedObjectDestroyedError: + pass + else: + self.fail("expected DistributedObjectDestroyedError to be raised") + + async def test_await_latch_negative_timeout(self): + latch = await self.get_latch(1) + self.assertFalse(await latch.await_latch(-1)) + + async def test_await_latch_zero_timeout(self): + latch = await self.get_latch(1) + self.assertFalse(await latch.await_latch(0)) + + async def test_await_latch_with_timeout(self): + timeout = 1 + latch = await self.get_latch(1) + start = get_current_timestamp() + self.assertFalse(await latch.await_latch(timeout)) + time_passed = get_current_timestamp() - start + expected_time_passed = timeout + if os.name == "nt": + # On Windows, we were getting random test failures due to expected + # time passed being slightly less than the timeout. This is due to + # the low time resolution there (15-16ms). If we are on Windows, we + # lower our expectations and settle for a slightly lower value. + expected_time_passed *= 0.95 + + self.assertTrue( + time_passed >= expected_time_passed, + "Time passed is less than %s, which is %s" % (expected_time_passed, time_passed), + ) + + async def test_await_latch_multiple_waiters(self): + latch = await self.get_latch(1) + # TODO: replace the following with the asyncio variant when implemented + completed = AtomicInteger() + + async def run(): + await latch.await_latch(inf) + completed.get_and_increment() + + count = 10 + tasks = [] + for _ in range(count): + tasks.append(asyncio.create_task(run())) + + await latch.count_down() + + def assertion(): + self.assertEqual(count, completed.get()) + + await self.assertTrueEventually(assertion) + + async def test_await_latch_response_on_count_down(self): + latch = await self.get_latch() + self.assertTrue(await latch.await_latch(inf)) + self.assertTrue(await latch.try_set_count(1)) + # make a non-blocking request + future = asyncio.create_task(latch.await_latch(inf)) + asyncio.create_task(latch.count_down()) + self.assertTrue(await future) + + async def test_count_down(self): + latch = await self.get_latch(10) + + for i in range(9, -1, -1): + self.assertIsNone(await latch.count_down()) + self.assertEqual(i, await latch.get_count()) + + async def test_count_down_retry_on_timeout(self): + latch = await self.get_latch(1) + original = latch._request_count_down + # TODO: replace the following with the asyncio variant when implemented + called_count = AtomicInteger() + + async def mock(expected_round, invocation_uuid): + if called_count.get_and_increment() < 2: + raise OperationTimeoutError("xx") + return await original(expected_round, invocation_uuid) + + latch._request_count_down = mock + await latch.count_down() + # Will resolve on it's third call. First 2 throws timeout error + self.assertEqual(3, called_count.get()) + self.assertEqual(0, await latch.get_count()) + + async def test_get_count(self): + latch = await self.get_latch(1) + self.assertEqual(1, await latch.get_count()) + await latch.count_down() + self.assertEqual(0, await latch.get_count()) + await latch.try_set_count(10) + self.assertEqual(10, await latch.get_count()) + + async def test_try_set_count(self): + latch = await self.get_latch() + self.assertTrue(await latch.try_set_count(3)) + self.assertEqual(3, await latch.get_count()) + + async def test_try_set_count_when_count_is_already_set(self): + latch = await self.get_latch(1) + self.assertFalse(await latch.try_set_count(10)) + self.assertFalse(await latch.try_set_count(20)) + self.assertEqual(1, await latch.get_count()) + + async def test_try_set_count_when_count_goes_to_zero(self): + latch = await self.get_latch(1) + await latch.count_down() + self.assertEqual(0, await latch.get_count()) + self.assertTrue(await latch.try_set_count(3)) + self.assertEqual(3, await latch.get_count()) + + async def get_latch(self, initial_count=None): + latch = await self.client.cp_subsystem.get_count_down_latch("latch-" + random_string()) + if initial_count is not None: + self.assertTrue(await latch.try_set_count(initial_count)) + return latch