Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import abc


class AsyncAbstractObjectStream(abc.ABC):
"""
Class for both ReadObjectStream as well as WriteObjectStream.

Attributes will include
1. bucket_name
2. object_name
3. generation_number (if given)


"""

def __init__(self, bucket_name, object_name, generation_number=None):
super().__init__()
self.bucket_name = bucket_name
self.object_name = object_name
self.generation_number = generation_number

@abc.abstractmethod
async def open(self):
raise NotImplementedError("Subclasses should implement this method.")

@abc.abstractmethod
async def close(self):
raise NotImplementedError("Subclasses should implement this method.")

@abc.abstractmethod
async def send(self):
raise NotImplementedError("Subclasses should implement this method.")

@abc.abstractmethod
async def recv(self):
raise NotImplementedError("Subclasses should implement this method.")
158 changes: 158 additions & 0 deletions google/cloud/storage/_experimental/async_multi_range_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""
Mrd_generic(bucket, obj,gen=None, read_handle=None)
mrd = Mrd(bucket, obj, gen)

mrd = Mrd(bucket, obj)
mrd = Mrd.create_from(client, bucket, obj)
Mrd_generic(bucket, obj,gen=None, read_handle=None)
* set attributes
* instantiate read_object_strea
* async stream.open
mrd = Mrd(read_handle)
mrd.download_ranges([(range_start, range_end, buf)])

mrr = await MultiRangeDownloader.create_mrd(client, bucket, obj)
await mrr.download_ranges([(range_start, range_end, buf)])


"""

from async_read_object_stream import AsyncReadObjectStream
from async_grpc_client import AsyncGrpcClient
from io import BytesIO
from google.cloud import _storage_v2
import sys
import asyncio
import uuid

_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100


class MultiRangeDownloader:

@classmethod
async def create_mrd(cls, client, bucket_name, object_name, generation_number=None):
# inti
# async mrd.open()
mrd = cls(client, bucket_name, object_name, generation_number)
await mrd.open()
return mrd

@classmethod
def create_mrd_from_read_handle(cls, client, read_handle):
raise NotImplementedError("TODO")

def __init__(
self,
client,
bucket_name=None,
object_name=None,
generation_number=None,
read_handle=None, # open with rea
):
self.client = client
self.bucket_name = bucket_name
self.object_name = object_name
self.generation_number = generation_number
self.read_handle = read_handle

async def open(self):
self.read_obj_str = AsyncReadObjectStream(
client=self.client,
bucket_name=self.bucket_name,
object_name=self.object_name,
generation_number=self.generation_number,
read_handle=self.read_handle,
)
await self.read_obj_str.open()
if self.generation_number is None:
self.generation_number = self.read_obj_str.generation_number
self.read_handle = self.read_obj_str.read_handle
return

async def download_ranges(self, read_ranges):
"""
1.user can provide any number of ranges upto 1000.
2.


"""
if len(read_ranges) > 1000:
raise Exception("Invalid Input - ranges cannot be more than 1000")

read_id_to_writable_buffer_dict = {}
for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
read_range_segment = read_ranges[
i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
]

read_ranges_for_bidi_req = []
for j, read_range in enumerate(read_range_segment):
# generate read_id
read_id = i + j
read_id_to_writable_buffer_dict[read_id] = read_range[2]
read_ranges_for_bidi_req.append(
_storage_v2.ReadRange(
read_offset=read_range[0],
read_length=read_range[1] - read_range[0], # end - start
read_id=read_id,
)
)
print(read_ranges_for_bidi_req)
await self.read_obj_str.send(
_storage_v2.BidiReadObjectRequest(read_ranges=read_ranges_for_bidi_req)
)
while len(read_id_to_writable_buffer_dict) > 0:
response = await self.read_obj_str.recv()
if response is None:
print("None response received, something went wrong.")
sys.exit(1)
for object_data_range in response.object_data_ranges:

if object_data_range.read_range is None:
raise Exception("Invalid response, read_range is None")

data = object_data_range.checksummed_data.content
# bytes_received_in_curr_res = object_data_range.read_range.read_length
read_id = object_data_range.read_range.read_id
buffer = read_id_to_writable_buffer_dict[read_id]
buffer.write(data)
print(
"for read_id ",
read_id,
data,
object_data_range.checksummed_data.crc32c,
)
if object_data_range.range_end:
del read_id_to_writable_buffer_dict[
object_data_range.read_range.read_id
]
# print("downloaded bytes", bytes_received)

# pass


async def test_mrd():
client = AsyncGrpcClient()._grpc_client
mrd = await MultiRangeDownloader.create_mrd(
client, bucket_name="chandrasiri-rs", object_name="test_open9"
)
my_buff1 = BytesIO()
my_buff2 = BytesIO()
my_buff3 = BytesIO()
my_buff4 = BytesIO()
buffers = [my_buff1, my_buff2, my_buff3, my_buff4]
await mrd.download_ranges(
[
(0, 100, my_buff1),
(100, 200, my_buff2),
(200, 300, my_buff3),
(300, 400, my_buff4),
]
)
for buff in buffers:
print("downloaded bytes", buff.getbuffer().nbytes)


if __name__ == "__main__":
asyncio.run(test_mrd())
185 changes: 185 additions & 0 deletions google/cloud/storage/_experimental/async_read_object_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from google.cloud.storage._experimental.async_abstract_object_stream import (
AsyncAbstractObjectStream,
)
from bidi_async import AsyncBidiRpc
import asyncio
import argparse
from async_grpc_client import AsyncGrpcClient
from google.cloud import _storage_v2 as storage_v2
import random
from typing import List


"""
Mrr_generic(bucket, obj,gen=None, read_handle=None)
mrr = Mrr(bucket, obj, gen)

mrr = Mrr(bucket, obj)
mrr = Mrr.create_from(client, bucket, obj)
Mrr_generic(bucket, obj,gen=None, read_handle=None)
* set attributes
* instantiate read_object_strea
* async stream.open
mrr = Mrr(read_handle)
mrr.download_ranges([(range_start, range_end, buf)])

"""


class AsyncReadObjectStream(AsyncAbstractObjectStream):
def __init__(
self,
client,
bucket_name=None,
object_name=None,
generation_number=None,
# TODO: meta_generation
read_handle=None, # open with rea
):
super().__init__(
bucket_name=bucket_name,
object_name=object_name,
generation_number=generation_number,
)
self.client = client
self.bucket_name = bucket_name
self._full_bucket_name = f"projects/_/buckets/{bucket_name}"
self.object_name = object_name
self.generation_number = generation_number
self.read_handle = read_handle

# can this interface be changed tmrw ? (not accounting for that)
# self.rpc = self.client.get_bidi_rpc_str_str_mc() # expose this func in GAPIC
self.rpc = self.client._client._transport._wrapped_methods[
self.client._client._transport.bidi_read_object
]
first_bidi_read_req = storage_v2.BidiReadObjectRequest(
read_object_spec=storage_v2.BidiReadObjectSpec(
bucket=self._full_bucket_name, object=object_name
),
)
self.metadata = (("x-goog-request-params", f"bucket={self._full_bucket_name}"),)
self.socket_like_rpc = AsyncBidiRpc(
self.rpc, initial_request=first_bidi_read_req, metadata=self.metadata
)

async def open(self) -> None:
"""
1 send & 1 recv()

"""
await self.socket_like_rpc.open() # this is actually 1 send
response = await self.socket_like_rpc.recv()
print(response)
# object_metadata =
if self.generation_number is None:
self.generation_number = response.metadata.generation

self.read_handle = response.read_handle

return
# return await super().open()

async def close(self):
return await super().close()

async def send(self, bidi_read_object_request):
await self.socket_like_rpc.send(bidi_read_object_request)
"""
1. what if this fails ?
2. calculate checksum and send data
A: you don't have to calcuate checksum here. since it's read da! DF

"""

return

async def recv(self):
bidi_read_object_response = await self.socket_like_rpc.recv()
"""
P0 - get this working.
1. what if this fails ?
what kind of error ?
existing retry wrapper ? from gapic
2. data is already checksumm'ed ,
you calcuated the checksum , verify and return. If verification fails raise.

3. traces ?

4. what if decompressive transcoding ?


"""
return bidi_read_object_response


async def test(bucket_name, object_name):
client = AsyncGrpcClient()._grpc_client
async_read_obj_str = AsyncReadObjectStream(
client, bucket_name=bucket_name, object_name=object_name
)
await async_read_obj_str.open()

# create bidi proto 'n' requests
for i in range(3):
req_count = 10
read_range_count = 1

for req_num in range(req_count):
# create ranges
read_ranges: List[storage_v2.ReadRange] = []
read_ids_set = set()
for read_id in range(read_range_count):
read_ids_set.add(read_id)
# read_length = 32 * 1024 * 1024 # up to 32 MiB
read_length = 10
read_offset = random.randint(0, 210763776 - read_length)
# read_length = READ_LENGTH
# read_offset = random.randint(0, 10 * 1024 * 1024 - read_length)
read_range = storage_v2.ReadRange(
read_offset=read_offset, read_length=read_length, read_id=read_id
)
read_ranges.append(read_range)
# first bidi req is already sent, so in 2nd request onwards, we send only
# read_ranges.
await async_read_obj_str.send(
storage_v2.BidiReadObjectRequest(read_ranges=read_ranges)
)

for i in range(20):
print("i", i)
try:
# response2 = await asyncio.wait_for(async_read_obj_str.recv(), timeout=2)
async with asyncio.timeout(2):
# response2 = await asyncio.wait_for(
# async_read_obj_str.recv(), timeout=2
# )
response2 = await async_read_obj_str.recv()
print(response2)
except asyncio.TimeoutError:
print("await4ed for 2s no response")
# print("opening again")
# await async_read_obj_str.open()

break

# pass


if __name__ == "__main__":
"""
1. import argparse
2. create parser
3. add args

4. parse args
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--bucket_name", help="The name of the GCS bucket to upload to."
)
parser.add_argument("--object_name", help="Object name")
args = parser.parse_args()

asyncio.run(test(bucket_name=args.bucket_name, object_name=args.object_name))
# test()
Loading