Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hazelcast/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"AtomicLong",
"AtomicReference",
"CPSubsystem",
"CountDownLatch",
"EntryEventCallable",
"Executor",
"HazelcastClient",
Expand Down Expand Up @@ -33,3 +34,4 @@
from hazelcast.internal.asyncio_proxy.cp_manager import CPSubsystem
from hazelcast.internal.asyncio_proxy.atomic_long import AtomicLong
from hazelcast.internal.asyncio_proxy.atomic_reference import AtomicReference
from hazelcast.internal.asyncio_proxy.countdown_latch import CountDownLatch
143 changes: 143 additions & 0 deletions hazelcast/internal/asyncio_proxy/countdown_latch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import uuid

from hazelcast.errors import OperationTimeoutError
from hazelcast.internal.asyncio_proxy.cp import BaseCPProxy
from hazelcast.protocol.codec import (
count_down_latch_await_codec,
count_down_latch_get_round_codec,
count_down_latch_count_down_codec,
count_down_latch_get_count_codec,
count_down_latch_try_set_count_codec,
)
from hazelcast.util import check_is_number, to_millis, check_is_int, check_true


class CountDownLatch(BaseCPProxy):
    """A distributed, concurrent countdown latch data structure.

    CountDownLatch is a cluster-wide synchronization aid
    that allows one or more callers to wait until a set of operations being
    performed in other callers completes.

    CountDownLatch count can be reset using ``try_set_count()`` method after
    a countdown has finished but not during an active count. This allows
    the same latch instance to be reused.

    There is no ``await_latch()`` variant that waits indefinitely since this
    is undesirable in a distributed application: for example, a cluster can
    split or the master and replicas could all terminate. In most cases, it is
    best to configure an explicit timeout, so you have the ability to deal
    with these situations.

    All the API methods in the CountDownLatch offer the exactly-once
    execution semantics. For instance, even if a ``count_down()`` call is
    internally retried because of crashed Hazelcast member, the counter
    value is decremented only once.

    Every public method is a coroutine and must be awaited.
    """

    async def await_latch(self, timeout: float) -> bool:
        """Waits until the latch has counted down to zero, an exception is
        raised, or the specified waiting time elapses.

        If the current count is zero then this method returns ``True``.

        If the current count is greater than zero, then the awaiting caller
        is suspended until one of the following things happen:

        - The count reaches zero due to invocations of the ``count_down()``
          method
        - This CountDownLatch instance is destroyed
        - The countdown owner becomes disconnected
        - The specified waiting time elapses

        If the count reaches zero, then the method returns with the
        value ``True``.

        If the specified waiting time elapses then the value ``False``
        is returned. If the time is less than or equal to zero, the method
        will not wait at all.

        Args:
            timeout: The maximum time to wait in seconds. Negative values
                are treated as zero.

        Returns:
            ``True`` if the count reached zero, ``False`` if the waiting time
            elapsed before the count reached zero.

        Raises:
            IllegalStateError: If the Hazelcast instance was shut down while
                waiting.
        """
        check_is_number(timeout)
        # Clamp negative timeouts so the request never carries a negative wait.
        timeout = max(0.0, timeout)
        invocation_uuid = uuid.uuid4()
        codec = count_down_latch_await_codec
        request = codec.encode_request(
            self._group_id, self._object_name, invocation_uuid, to_millis(timeout)
        )
        return await self._ainvoke(request, codec.decode_response)

    async def count_down(self) -> None:
        """Decrements the count of the latch, releasing all waiting callers if
        the count reaches zero.

        If the current count is greater than zero, then it is decremented.
        If the new count is zero:

        - All callers waiting in ``await_latch()`` are released
        - Countdown owner is set to ``None``.

        If the current count equals zero, then nothing happens.
        """
        # The UUID is generated once per logical call and reused by every
        # retry issued from ``_do_count_down``.
        invocation_uuid = uuid.uuid4()
        current_round = await self._get_round()
        return await self._do_count_down(current_round, invocation_uuid)

    async def get_count(self) -> int:
        """Returns the current count.

        Returns:
            The current count.
        """
        codec = count_down_latch_get_count_codec
        request = codec.encode_request(self._group_id, self._object_name)
        return await self._ainvoke(request, codec.decode_response)

    async def try_set_count(self, count: int) -> bool:
        """Sets the count to the given value if the current count is zero.

        If count is not zero, then this method does nothing and returns
        ``False``.

        The argument is validated locally (it must be an integer greater
        than zero) before any request is sent.

        Args:
            count: The number of times ``count_down()`` must be invoked before
                callers can pass through ``await_latch()``. Must be positive.

        Returns:
            ``True`` if the new count was set, ``False`` if the current count
            is not zero.
        """
        check_is_int(count)
        check_true(count > 0, "Count must be positive")
        codec = count_down_latch_try_set_count_codec
        request = codec.encode_request(self._group_id, self._object_name, count)
        return await self._ainvoke(request, codec.decode_response)

    async def _do_count_down(self, expected_round, invocation_uuid: uuid.UUID):
        """Sends the count-down request, retrying for as long as it times out.

        Args:
            expected_round: The round value obtained from ``_get_round()``.
            invocation_uuid: The UUID identifying this logical count-down;
                the same value is sent on every retry.

        Returns:
            Whatever ``_request_count_down()`` returns once it succeeds.
        """
        # Iterate instead of recursing: a long streak of timeouts must not
        # deepen the coroutine chain until it hits RecursionError.
        while True:
            try:
                return await self._request_count_down(expected_round, invocation_uuid)
            except OperationTimeoutError:
                # we can retry safely because the retry is idempotent
                continue

    async def _get_round(self):
        """Fetches the current round of the latch from the cluster."""
        codec = count_down_latch_get_round_codec
        request = codec.encode_request(self._group_id, self._object_name)
        return await self._ainvoke(request, codec.decode_response)

    async def _request_count_down(self, expected_round, invocation_uuid: uuid.UUID):
        """Issues a single count-down request for the given round.

        No response decoder is passed to ``_ainvoke``, so the result is
        whatever ``_ainvoke`` yields without one.
        """
        codec = count_down_latch_count_down_codec
        request = codec.encode_request(
            self._group_id, self._object_name, invocation_uuid, expected_round
        )
        return await self._ainvoke(request)
24 changes: 24 additions & 0 deletions hazelcast/internal/asyncio_proxy/cp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
_get_object_name_for_proxy,
ATOMIC_LONG_SERVICE,
ATOMIC_REFERENCE_SERVICE,
COUNT_DOWN_LATCH_SERVICE,
)
from hazelcast.internal.asyncio_invocation import Invocation
from hazelcast.internal.asyncio_proxy.atomic_long import AtomicLong
from hazelcast.internal.asyncio_proxy.atomic_reference import AtomicReference
from hazelcast.internal.asyncio_proxy.countdown_latch import CountDownLatch
from hazelcast.protocol.codec import cp_group_create_cp_group_codec


Expand Down Expand Up @@ -78,6 +80,26 @@ async def get_atomic_reference(self, name: str) -> AtomicReference:
"""
return await self._proxy_manager.get_or_create(ATOMIC_REFERENCE_SERVICE, name)

async def get_count_down_latch(self, name: str) -> CountDownLatch:
    """Looks up the distributed CountDownLatch registered under ``name``.

    The latch lives on the CP Subsystem.

    ``name`` may carry an optional group suffix. Without one, the latch is
    created on the DEFAULT CP group. With one, e.g.
    ``.get_count_down_latch("myLatch@group1")``, that group is initialized
    first (unless it already is) and the latch is created on it.

    Args:
        name: Name of the CountDownLatch.

    Returns:
        The CountDownLatch proxy for the given name.
    """
    manager = self._proxy_manager
    latch = await manager.get_or_create(COUNT_DOWN_LATCH_SERVICE, name)
    return latch


class CPProxyManager:
def __init__(self, context):
Expand All @@ -92,6 +114,8 @@ async def get_or_create(self, service_name, proxy_name):
return AtomicLong(self._context, group_id, service_name, proxy_name, object_name)
elif service_name == ATOMIC_REFERENCE_SERVICE:
return AtomicReference(self._context, group_id, service_name, proxy_name, object_name)
elif service_name == COUNT_DOWN_LATCH_SERVICE:
return CountDownLatch(self._context, group_id, service_name, proxy_name, object_name)

raise ValueError("Unknown service name: %s" % service_name)

Expand Down
158 changes: 158 additions & 0 deletions tests/integration/asyncio/proxy/countdown_latch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import asyncio
import os

import pytest

from hazelcast.errors import DistributedObjectDestroyedError, OperationTimeoutError
from hazelcast.util import AtomicInteger
from tests.integration.asyncio.base import CPTestCase
from tests.util import random_string, get_current_timestamp

inf = 2**31 - 1


@pytest.mark.enterprise
class CountDownLatchTest(CPTestCase):
    """Integration tests for the asyncio ``CountDownLatch`` proxy obtained
    through ``client.cp_subsystem.get_count_down_latch()``."""

    async def test_latch_in_another_group(self):
        latch = await self.get_latch()
        another_latch = await self.client.cp_subsystem.get_count_down_latch(
            latch._proxy_name + "@another"
        )
        await another_latch.try_set_count(42)
        self.assertEqual(42, await another_latch.get_count())
        # Same proxy name, different group: the counts must be independent.
        self.assertNotEqual(42, await latch.get_count())

    async def test_use_after_destroy(self):
        latch = await self.get_latch()
        await latch.destroy()
        # the next destroy call should be ignored
        await latch.destroy()

        with self.assertRaises(DistributedObjectDestroyedError):
            await latch.get_count()

        # A fresh proxy with the same name must observe the destruction too.
        latch2 = await self.client.cp_subsystem.get_count_down_latch(latch._proxy_name)

        with self.assertRaises(DistributedObjectDestroyedError):
            await latch2.get_count()

    async def test_await_latch_negative_timeout(self):
        latch = await self.get_latch(1)
        self.assertFalse(await latch.await_latch(-1))

    async def test_await_latch_zero_timeout(self):
        latch = await self.get_latch(1)
        self.assertFalse(await latch.await_latch(0))

    async def test_await_latch_with_timeout(self):
        timeout = 1
        latch = await self.get_latch(1)
        start = get_current_timestamp()
        self.assertFalse(await latch.await_latch(timeout))
        time_passed = get_current_timestamp() - start
        expected_time_passed = timeout
        if os.name == "nt":
            # On Windows, we were getting random test failures due to expected
            # time passed being slightly less than the timeout. This is due to
            # the low time resolution there (15-16ms). If we are on Windows, we
            # lower our expectations and settle for a slightly lower value.
            expected_time_passed *= 0.95

        self.assertGreaterEqual(
            time_passed,
            expected_time_passed,
            "Time passed is less than %s, which is %s" % (expected_time_passed, time_passed),
        )

    async def test_await_latch_multiple_waiters(self):
        latch = await self.get_latch(1)
        # TODO: replace the following with the asyncio variant when implemented
        completed = AtomicInteger()

        async def run():
            await latch.await_latch(inf)
            completed.get_and_increment()

        count = 10
        # Keep the task objects referenced so they cannot be collected while
        # still pending.
        tasks = [asyncio.create_task(run()) for _ in range(count)]

        await latch.count_down()

        def assertion():
            self.assertEqual(count, completed.get())

        await self.assertTrueEventually(assertion)
        # Surface any exception raised inside a waiter instead of leaving it
        # as an un-retrieved task exception.
        await asyncio.gather(*tasks)

    async def test_await_latch_response_on_count_down(self):
        latch = await self.get_latch()
        self.assertTrue(await latch.await_latch(inf))
        self.assertTrue(await latch.try_set_count(1))
        # make a non-blocking request
        waiter = asyncio.create_task(latch.await_latch(inf))
        # Hold on to the count-down task: the event loop keeps only weak
        # references to tasks, and awaiting it reports its failures.
        counter = asyncio.create_task(latch.count_down())
        self.assertTrue(await waiter)
        await counter

    async def test_count_down(self):
        latch = await self.get_latch(10)

        for i in range(9, -1, -1):
            self.assertIsNone(await latch.count_down())
            self.assertEqual(i, await latch.get_count())

    async def test_count_down_retry_on_timeout(self):
        latch = await self.get_latch(1)
        original = latch._request_count_down
        # TODO: replace the following with the asyncio variant when implemented
        called_count = AtomicInteger()

        async def mock(expected_round, invocation_uuid):
            if called_count.get_and_increment() < 2:
                raise OperationTimeoutError("xx")
            return await original(expected_round, invocation_uuid)

        latch._request_count_down = mock
        await latch.count_down()
        # Will resolve on its third call. The first 2 calls raise a timeout error
        self.assertEqual(3, called_count.get())
        self.assertEqual(0, await latch.get_count())

    async def test_get_count(self):
        latch = await self.get_latch(1)
        self.assertEqual(1, await latch.get_count())
        await latch.count_down()
        self.assertEqual(0, await latch.get_count())
        await latch.try_set_count(10)
        self.assertEqual(10, await latch.get_count())

    async def test_try_set_count(self):
        latch = await self.get_latch()
        self.assertTrue(await latch.try_set_count(3))
        self.assertEqual(3, await latch.get_count())

    async def test_try_set_count_when_count_is_already_set(self):
        latch = await self.get_latch(1)
        self.assertFalse(await latch.try_set_count(10))
        self.assertFalse(await latch.try_set_count(20))
        self.assertEqual(1, await latch.get_count())

    async def test_try_set_count_when_count_goes_to_zero(self):
        latch = await self.get_latch(1)
        await latch.count_down()
        self.assertEqual(0, await latch.get_count())
        self.assertTrue(await latch.try_set_count(3))
        self.assertEqual(3, await latch.get_count())

    async def get_latch(self, initial_count=None):
        """Creates a latch with a random name, optionally setting its count.

        Args:
            initial_count: When not ``None``, the count to set on the new
                latch; the helper asserts that setting it succeeds.

        Returns:
            The CountDownLatch proxy.
        """
        latch = await self.client.cp_subsystem.get_count_down_latch("latch-" + random_string())
        if initial_count is not None:
            self.assertTrue(await latch.try_set_count(initial_count))
        return latch
Loading