11import logging
2- from typing import Any
2+ from typing import Any , Mapping
33
44import nanoid # type: ignore # type: ignore
55from pydantic import ValidationError
2727 InvalidMessageException ,
2828 SessionStateMismatchException ,
2929)
30+ from replit_river .server_session import ServerSession
3031from replit_river .session import Session
3132from replit_river .transport import Transport
3233from replit_river .transport_options import TransportOptions
3536
3637
3738class ServerTransport (Transport ):
39+ _sessions : dict [str , ServerSession ]
40+
3841 def __init__ (
3942 self ,
4043 transport_id : str ,
@@ -45,11 +48,12 @@ def __init__(
4548 transport_options = transport_options ,
4649 is_server = True ,
4750 )
51+ self ._sessions = {}
4852
4953 async def handshake_to_get_session (
5054 self ,
5155 websocket : WebSocketServerProtocol ,
52- ) -> Session :
56+ ) -> ServerSession :
5357 async for message in websocket :
5458 try :
5559 msg = parse_transport_msg (message , self ._transport_options )
@@ -88,23 +92,23 @@ async def handshake_to_get_session(
8892 raise WebsocketClosedException ("No handshake message received" )
8993
9094 async def close (self ) -> None :
91- await self ._close_all_sessions ()
95+ await self ._close_all_sessions (self . _get_all_sessions )
9296
9397 async def _get_or_create_session (
9498 self ,
9599 transport_id : str ,
96100 to_id : str ,
97101 session_id : str ,
98102 websocket : WebSocketCommonProtocol ,
99- ) -> Session :
103+ ) -> ServerSession :
100104 async with self ._session_lock :
101105 session_to_close : Session | None = None
102- new_session : Session | None = None
106+ new_session : ServerSession | None = None
103107 if to_id not in self ._sessions :
104108 logger .info (
105109 'Creating new session with "%s" using ws: %s' , to_id , websocket .id
106110 )
107- new_session = Session (
111+ new_session = ServerSession (
108112 transport_id ,
109113 to_id ,
110114 session_id ,
@@ -125,7 +129,7 @@ async def _get_or_create_session(
125129 old_session .session_id ,
126130 )
127131 session_to_close = old_session
128- new_session = Session (
132+ new_session = ServerSession (
129133 transport_id ,
130134 to_id ,
131135 session_id ,
@@ -152,7 +156,7 @@ async def _get_or_create_session(
152156 if session_to_close :
153157 logger .info ("Closing stale session %s" , session_to_close .session_id )
154158 await session_to_close .close ()
155- self ._set_session ( new_session )
159+ self ._sessions [ new_session . _to_id ] = new_session
156160 return new_session
157161
158162 async def _send_handshake_response (
@@ -293,3 +297,11 @@ async def _establish_handshake(
293297 )
294298
295299 return handshake_request , handshake_response
300+
301+ def _get_all_sessions (self ) -> Mapping [str , Session ]:
302+ return self ._sessions
303+
304+ async def _delete_session (self , session : Session ) -> None :
305+ async with self ._session_lock :
306+ if session ._to_id in self ._sessions :
307+ del self ._sessions [session ._to_id ]
0 commit comments