Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hazelcast/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"ReliableTopic",
"ReplicatedMap",
"Ringbuffer",
"Semaphore",
"Set",
"VectorCollection",
]
Expand All @@ -35,3 +36,4 @@
from hazelcast.internal.asyncio_proxy.atomic_long import AtomicLong
from hazelcast.internal.asyncio_proxy.atomic_reference import AtomicReference
from hazelcast.internal.asyncio_proxy.countdown_latch import CountDownLatch
from hazelcast.internal.asyncio_proxy.semaphore import Semaphore
7 changes: 7 additions & 0 deletions hazelcast/internal/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from hazelcast.discovery import HazelcastCloudAddressProvider
from hazelcast.errors import IllegalStateError, InvalidConfigurationError
from hazelcast.internal.asyncio_invocation import InvocationService, Invocation
from hazelcast.internal.asyncio_proxy.cp import ProxySessionManager
from hazelcast.internal.asyncio_proxy.cp_manager import CPSubsystem
from hazelcast.internal.asyncio_proxy.pn_counter import PNCounter
from hazelcast.internal.asyncio_proxy.vector_collection import VectorCollection
Expand Down Expand Up @@ -209,6 +210,7 @@ def __init__(self, config: Config | None = None, **kwargs):
)
self._proxy_manager = ProxyManager(self._context)
self._cp_subsystem = CPSubsystem(self._context)
self._proxy_session_manager = ProxySessionManager(self._context)
self._lock_reference_id_generator = AtomicInteger(1)
self._statistics = Statistics(
self,
Expand Down Expand Up @@ -248,6 +250,7 @@ def _init_context(self):
self._near_cache_manager,
self._lock_reference_id_generator,
self._name,
self._proxy_session_manager,
self._reactor,
self._compact_schema_service,
)
Expand Down Expand Up @@ -493,6 +496,7 @@ async def shutdown(self) -> None:
if self._internal_lifecycle_service.running:
self._internal_lifecycle_service.fire_lifecycle_event(LifecycleState.SHUTTING_DOWN)
self._internal_lifecycle_service.shutdown()
await self._proxy_session_manager.shutdown()
self._near_cache_manager.destroy_near_caches()
await self._connection_manager.shutdown()
self._invocation_service.shutdown()
Expand Down Expand Up @@ -590,6 +594,7 @@ def __init__(self):
self.near_cache_manager = None
self.lock_reference_id_generator = None
self.name = None
self.proxy_session_manager = None
self.reactor = None
self.compact_schema_service = None

Expand All @@ -607,6 +612,7 @@ def init_context(
near_cache_manager,
lock_reference_id_generator,
name,
proxy_session_manager,
reactor,
compact_schema_service,
):
Expand All @@ -622,5 +628,6 @@ def init_context(
self.near_cache_manager = near_cache_manager
self.lock_reference_id_generator = lock_reference_id_generator
self.name = name
self.proxy_session_manager = proxy_session_manager
self.reactor = reactor
self.compact_schema_service = compact_schema_service
197 changes: 196 additions & 1 deletion hazelcast/internal/asyncio_proxy/cp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
import abc
import asyncio

from hazelcast.cp import _SessionState
from hazelcast.errors import (
HazelcastClientNotActiveError,
SessionExpiredError,
CPGroupDestroyedError,
)
from hazelcast.internal.asyncio_invocation import Invocation
from hazelcast.protocol.codec import cp_group_destroy_cp_object_codec
from hazelcast.protocol import RaftGroupId
from hazelcast.protocol.codec import (
cp_group_destroy_cp_object_codec,
cp_session_generate_thread_id_codec,
cp_session_close_session_codec,
cp_session_create_session_codec,
cp_session_heartbeat_session_codec,
)


def _no_op_response_handler(_):
Expand Down Expand Up @@ -32,3 +48,182 @@ def _invoke(self, request, response_handler=_no_op_response_handler):
async def _ainvoke(self, request, response_handler=_no_op_response_handler):
fut = self._invoke(request, response_handler)
return await fut


class SessionAwareCPProxy(BaseCPProxy, abc.ABC):
def __init__(self, context, group_id, service_name, proxy_name, object_name):
super(SessionAwareCPProxy, self).__init__(
context, group_id, service_name, proxy_name, object_name
)
self._session_manager = context.proxy_session_manager

def get_group_id(self) -> RaftGroupId:
"""
Returns:
Id of the CP group that runs this proxy.
"""
return self._group_id

def _get_session_id(self) -> int:
return self._session_manager.get_session_id(self._group_id)

async def _acquire_session(self, count: int = 1) -> int:
return await self._session_manager.acquire_session(self._group_id, count)

def _release_session(self, session_id: int, count: int = 1) -> None:
self._session_manager.release_session(self._group_id, session_id, count)

def _invalidate_session(self, session_id: int) -> None:
self._session_manager.invalidate_session(self._group_id, session_id)


_NO_SESSION_ID = -1


class ProxySessionManager:
def __init__(self, context):
self._context = context
self._mutexes = dict() # RaftGroupId to asyncio.Lock
self._sessions = dict() # RaftGroupId to SessionState
self._thread_ids = dict() # (RaftGroupId, thread_id) to global thread id
self._heartbeat_task = None
self._shutdown = False
self._lock = asyncio.Lock()

def get_session_id(self, group_id):
session = self._sessions.get(group_id, None)
if session is None:
return _NO_SESSION_ID
return session.id

async def acquire_session(self, group_id, count):
state = await self._get_or_create_session(group_id)
return state.acquire(count)

def release_session(self, group_id, session_id, count):
session = self._sessions.get(group_id, None)
if session and session.id == session_id:
session.release(count)

def invalidate_session(self, group_id, session_id):
session = self._sessions.get(group_id, None)
if session and session.id == session_id:
self._sessions.pop(group_id, None)

async def get_or_create_unique_thread_id(self, group_id):
async with self._lock:
if self._shutdown:
raise HazelcastClientNotActiveError("Session manager is already shut down!")

# TODO: replace 0 with the lock context once implemented
key = (group_id, 0)
global_thread_id = self._thread_ids.get(key)
if global_thread_id:
return global_thread_id

tid = await self._request_generate_thread_id(group_id)
return self._thread_ids.setdefault(key, tid)

async def shutdown(self):
async with self._lock:
if self._shutdown:
return None

self._shutdown = True
if self._heartbeat_task:
self._heartbeat_task.cancel()

tasks = []
async with asyncio.TaskGroup() as tg:
for session in list(self._sessions.values()):
tasks.append(
tg.create_task(self._request_close_session(session.group_id, session.id))
)

self._sessions.clear()
self._mutexes.clear()
self._thread_ids.clear()

async def _request_generate_thread_id(self, group_id):
codec = cp_session_generate_thread_id_codec
request = codec.encode_request(group_id)
invocation = Invocation(request, response_handler=codec.decode_response)
return await self._context.invocation_service.ainvoke(invocation)

async def _request_close_session(self, group_id, session_id):
codec = cp_session_close_session_codec
request = codec.encode_request(group_id, session_id)
invocation = Invocation(request, response_handler=codec.decode_response)
return await self._context.invocation_service.ainvoke(invocation)

async def _get_or_create_session(self, group_id):
async with self._lock:
if self._shutdown:
raise HazelcastClientNotActiveError("Session manager is already shut down!")

session = self._sessions.get(group_id, None)
if session is None or not session.is_valid():
async with self._mutex(group_id):
session = self._sessions.get(group_id)
if session is None or not session.is_valid():
return await self._create_new_session(group_id)
return session

async def _create_new_session(self, group_id):
response = await self._request_new_session(group_id)
return self._do_create_new_session(response, group_id)

def _do_create_new_session(self, response, group_id):
session = _SessionState(response["session_id"], group_id, response["ttl_millis"] / 1000.0)
self._sessions[group_id] = session
self._start_heartbeat_timer(response["heartbeat_millis"] / 1000.0)
return session

async def _request_new_session(self, group_id):
codec = cp_session_create_session_codec
request = codec.encode_request(group_id, self._context.name)
invocation = Invocation(request, response_handler=codec.decode_response)
return await self._context.invocation_service.ainvoke(invocation)

def _mutex(self, group_id) -> asyncio.Lock:
mutex = self._mutexes.get(group_id, None)
if mutex is not None:
return mutex

mutex = asyncio.Lock()
current = self._mutexes.setdefault(group_id, mutex)
return current

def _start_heartbeat_timer(self, period):
if self._heartbeat_task is not None:
return

async def heartbeat():
await asyncio.sleep(period)
if self._shutdown:
return

for session in list(self._sessions.values()):
if session.is_in_use():

def cb(heartbeat_future: asyncio.Future, session=session):
error = heartbeat_future.exception()
if error is None:
return

if isinstance(error, (SessionExpiredError, CPGroupDestroyedError)):
self.invalidate_session(session.group_id, session.id)

f = self._request_heartbeat(session.group_id, session.id)
f.add_done_callback(cb)

self._heartbeat_task = asyncio.create_task(heartbeat())

self._heartbeat_task = asyncio.create_task(heartbeat())

def _request_heartbeat(self, group_id, session_id) -> asyncio.Future:
codec = cp_session_heartbeat_session_codec
request = codec.encode_request(group_id, session_id)
invocation = Invocation(request)
self._context.invocation_service.invoke(invocation)
return invocation.future
41 changes: 40 additions & 1 deletion hazelcast/internal/asyncio_proxy/cp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@
ATOMIC_LONG_SERVICE,
ATOMIC_REFERENCE_SERVICE,
COUNT_DOWN_LATCH_SERVICE,
SEMAPHORE_SERVICE,
)
from hazelcast.internal.asyncio_invocation import Invocation
from hazelcast.internal.asyncio_proxy.atomic_long import AtomicLong
from hazelcast.internal.asyncio_proxy.atomic_reference import AtomicReference
from hazelcast.internal.asyncio_proxy.countdown_latch import CountDownLatch
from hazelcast.protocol.codec import cp_group_create_cp_group_codec
from hazelcast.internal.asyncio_proxy.semaphore import (
Semaphore,
SessionAwareSemaphore,
SessionlessSemaphore,
)
from hazelcast.protocol.codec import (
cp_group_create_cp_group_codec,
semaphore_get_semaphore_type_codec,
)


class CPSubsystem:
Expand Down Expand Up @@ -100,6 +109,25 @@ async def get_count_down_latch(self, name: str) -> CountDownLatch:
"""
return await self._proxy_manager.get_or_create(COUNT_DOWN_LATCH_SERVICE, name)

async def get_semaphore(self, name: str) -> Semaphore:
"""Returns the distributed Semaphore instance with given name.

The instance is created on CP Subsystem.

If no group name is given within the ``name`` argument, then the
Semaphore instance will be created on the DEFAULT CP group.
If a group name is given, like ``.get_semaphore("mySemaphore@group1")``,
the given group will be initialized first, if not initialized
already, and then the instance will be created on this group.

Args:
name: Name of the Semaphore

Returns:
The Semaphore proxy for the given name.
"""
return await self._proxy_manager.get_or_create(SEMAPHORE_SERVICE, name)


class CPProxyManager:
def __init__(self, context):
Expand All @@ -116,9 +144,20 @@ async def get_or_create(self, service_name, proxy_name):
return AtomicReference(self._context, group_id, service_name, proxy_name, object_name)
elif service_name == COUNT_DOWN_LATCH_SERVICE:
return CountDownLatch(self._context, group_id, service_name, proxy_name, object_name)
elif service_name == SEMAPHORE_SERVICE:
return await self._create_semaphore(group_id, proxy_name, object_name)

raise ValueError("Unknown service name: %s" % service_name)

async def _create_semaphore(self, group_id, proxy_name, object_name):
codec = semaphore_get_semaphore_type_codec
request = codec.encode_request(proxy_name)
invocation = Invocation(request, response_handler=codec.decode_response)
invocation_service = self._context.invocation_service
jdk_compatible = await invocation_service.ainvoke(invocation)
kls = SessionlessSemaphore if jdk_compatible else SessionAwareSemaphore
return kls(self._context, group_id, SEMAPHORE_SERVICE, proxy_name, object_name)

async def _get_group_id(self, proxy_name):
codec = cp_group_create_cp_group_codec
request = codec.encode_request(proxy_name)
Expand Down
Loading
Loading