diff --git a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py
index bc60262ed..a458a5e43 100644
--- a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py
+++ b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 from google.cloud.storage._experimental.asyncio.async_read_object_stream import (
     _AsyncReadObjectStream,
@@ -24,6 +24,38 @@
 )
 
 from io import BytesIO
+from google.cloud import _storage_v2
+
+
+_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100
+
+
+class Result:
+    """An instance of this class will be populated and returned for each
+    `read_range` provided to ``download_ranges`` method.
+
+    """
+
+    def __init__(self, bytes_requested: int):
+        # only while instantiation, should not be edited later.
+        # hence there's no setter, only getter is provided.
+        self._bytes_requested: int = bytes_requested
+        self._bytes_written: int = 0
+
+    @property
+    def bytes_requested(self) -> int:
+        return self._bytes_requested
+
+    @property
+    def bytes_written(self) -> int:
+        return self._bytes_written
+
+    @bytes_written.setter
+    def bytes_written(self, value: int):
+        self._bytes_written = value
+
+    def __repr__(self):
+        return f"bytes_requested: {self._bytes_requested}, bytes_written: {self._bytes_written}"
 
 
 class AsyncMultiRangeDownloader:
@@ -38,21 +70,23 @@ class AsyncMultiRangeDownloader:
         mrd = await AsyncMultiRangeDownloader.create_mrd(
             client, bucket_name="chandrasiri-rs", object_name="test_open9"
         )
-        my_buff1 = BytesIO()
+        my_buff1 = open('my_fav_file.txt', 'wb')
         my_buff2 = BytesIO()
         my_buff3 = BytesIO()
-        my_buff4 = BytesIO()
-        buffers = [my_buff1, my_buff2, my_buff3, my_buff4]
-        await mrd.download_ranges(
+        my_buff4 = any_object_which_provides_BytesIO_like_interface()
+        results_arr = await mrd.download_ranges(
             [
+                # (start_byte, bytes_to_read, writeable_buffer)
                 (0, 100, my_buff1),
-                (100, 200, my_buff2),
-                (200, 300, my_buff3),
-                (300, 400, my_buff4),
+                (100, 20, my_buff2),
+                (200, 123, my_buff3),
+                (300, 789, my_buff4),
             ]
         )
-        for buff in buffers:
-            print("downloaded bytes", buff.getbuffer().nbytes)
+        # one ``Result`` per requested range, in the order of ``read_ranges``
+        for result in results_arr:
+            print("downloaded bytes", result)
+
 
     """
 
@@ -148,18 +182,73 @@ async def open(self) -> None:
         self.read_handle = self.read_obj_str.read_handle
         return
 
-    async def download_ranges(self, read_ranges: List[Tuple[int, int, BytesIO]]) -> Any:
+    async def download_ranges(
+        self, read_ranges: List[Tuple[int, int, BytesIO]]
+    ) -> List[Result]:
         """Downloads multiple byte ranges from the object into the buffers
         provided by user.
 
         :type read_ranges: List[Tuple[int, int, "BytesIO"]]
         :param read_ranges: A list of tuples, where each tuple represents a
-            byte range (start_byte, end_byte, buffer) to download. Buffer has to
-            be provided by the user, and user has to make sure appropriate
+            byte range (start_byte, bytes_to_read, writeable_buffer). Buffer has
+            to be provided by the user, and user has to make sure appropriate
             memory is available in the application to avoid out-of-memory crash.
 
+        :rtype: List[:class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.Result`]
+        :returns: A list of ``Result`` objects, where each object corresponds
+                  to a requested range.
+        :raises ValueError: If more than 1000 ``read_ranges`` are provided.
+        :raises Exception: If the stream yields ``None``, or a returned data
+            range has no ``read_range``.
 
-        Raises:
-            NotImplementedError: This method is not yet implemented.
         """
-        raise NotImplementedError("TODO")
+        if len(read_ranges) > 1000:
+            raise ValueError(
+                "Invalid input - length of read_ranges cannot be more than 1000"
+            )
+
+        read_id_to_writable_buffer_dict = {}
+        results = []
+        for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
+            read_ranges_segment = read_ranges[
+                i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
+            ]
+
+            read_ranges_for_bidi_req = []
+            for j, read_range in enumerate(read_ranges_segment):
+                read_id = i + j
+                read_id_to_writable_buffer_dict[read_id] = read_range[2]
+                bytes_requested = read_range[1]
+                results.append(Result(bytes_requested))
+                read_ranges_for_bidi_req.append(
+                    _storage_v2.ReadRange(
+                        read_offset=read_range[0],
+                        read_length=bytes_requested,
+                        read_id=read_id,
+                    )
+                )
+            await self.read_obj_str.send(
+                _storage_v2.BidiReadObjectRequest(read_ranges=read_ranges_for_bidi_req)
+            )
+
+        while len(read_id_to_writable_buffer_dict) > 0:
+            response = await self.read_obj_str.recv()
+
+            if response is None:
+                raise Exception("None response received, something went wrong.")
+
+            for object_data_range in response.object_data_ranges:
+                if object_data_range.read_range is None:
+                    raise Exception("Invalid response, read_range is None")
+
+                data = object_data_range.checksummed_data.content
+                read_id = object_data_range.read_range.read_id
+                buffer = read_id_to_writable_buffer_dict[read_id]
+                buffer.write(data)
+                results[read_id].bytes_written += len(data)
+
+                if object_data_range.range_end:
+                    del read_id_to_writable_buffer_dict[
+                        object_data_range.read_range.read_id
+                    ]
+        return results
diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py
index edcd3fcc4..b57bc92ca 100644
--- a/tests/unit/asyncio/test_async_multi_range_downloader.py
+++ b/tests/unit/asyncio/test_async_multi_range_downloader.py
@@ -15,6 +15,7 @@
 import pytest
 from unittest import mock
 from unittest.mock import AsyncMock
+from google.cloud import _storage_v2
 
 from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import (
     AsyncMultiRangeDownloader,
@@ -28,52 +29,131 @@
 _TEST_READ_HANDLE = b"test-handle"
 
 
-@mock.patch(
-    "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
-)
-@mock.patch(
-    "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
-)
-@pytest.mark.asyncio
-async def test_create_mrd(mock_async_grpc_client, async_read_object_stream):
-    # Arrange
-    mock_stream_instance = async_read_object_stream.return_value
-    mock_stream_instance.open = AsyncMock()
-    mock_stream_instance.generation_number = _TEST_GENERATION_NUMBER
-    mock_stream_instance.read_handle = _TEST_READ_HANDLE
-
-    # act
-    mrd = await AsyncMultiRangeDownloader.create_mrd(
-        mock_async_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
-    )
-
-    # Assert
-    async_read_object_stream.assert_called_once_with(
-        client=mock_async_grpc_client,
+class TestAsyncMultiRangeDownloader:
+    # helper coroutine awaited by the tests below; it is not itself a test
+    async def _make_mock_mrd(
+        self,
+        mock_grpc_client,
+        mock_cls_async_read_object_stream,
         bucket_name=_TEST_BUCKET_NAME,
         object_name=_TEST_OBJECT_NAME,
-        generation_number=None,
-        read_handle=None,
+        generation_number=_TEST_GENERATION_NUMBER,
+        read_handle=_TEST_READ_HANDLE,
+    ):
+        mock_stream = mock_cls_async_read_object_stream.return_value
+        mock_stream.open = AsyncMock()
+        mock_stream.generation_number = _TEST_GENERATION_NUMBER
+        mock_stream.read_handle = _TEST_READ_HANDLE
+
+        mrd = await AsyncMultiRangeDownloader.create_mrd(
+            mock_grpc_client, bucket_name, object_name, generation_number, read_handle
+        )
+
+        return mrd
+
+    @mock.patch(
+        "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
     )
-    mock_stream_instance.open.assert_called_once()
+    @mock.patch(
+        "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
+    )
+    @pytest.mark.asyncio
+    async def test_create_mrd(
+        self, mock_grpc_client, mock_cls_async_read_object_stream
+    ):
+        # Arrange & Act
+        mrd = await self._make_mock_mrd(
+            mock_grpc_client, mock_cls_async_read_object_stream
+        )
 
-    assert mrd.client == mock_async_grpc_client
-    assert mrd.bucket_name == _TEST_BUCKET_NAME
-    assert mrd.object_name == _TEST_OBJECT_NAME
-    assert mrd.generation_number == _TEST_GENERATION_NUMBER
-    assert mrd.read_handle == _TEST_READ_HANDLE
-    assert mrd.read_obj_str is mock_stream_instance
+        # Assert
+        mock_cls_async_read_object_stream.assert_called_once_with(
+            client=mock_grpc_client,
+            bucket_name=_TEST_BUCKET_NAME,
+            object_name=_TEST_OBJECT_NAME,
+            generation_number=_TEST_GENERATION_NUMBER,
+            read_handle=_TEST_READ_HANDLE,
+        )
+        mrd.read_obj_str.open.assert_called_once()
 
 
-@mock.patch(
-    "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
-)
-@pytest.mark.asyncio
-async def test_download_ranges(mock_async_grpc_client):
-    """Test that download_ranges() raises NotImplementedError."""
-    mrd = AsyncMultiRangeDownloader(
-        mock_async_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
+        assert mrd.client == mock_grpc_client
+        assert mrd.bucket_name == _TEST_BUCKET_NAME
+        assert mrd.object_name == _TEST_OBJECT_NAME
+        assert mrd.generation_number == _TEST_GENERATION_NUMBER
+        assert mrd.read_handle == _TEST_READ_HANDLE
+        assert mrd.read_obj_str is mock_cls_async_read_object_stream.return_value
+
+    @mock.patch(
+        "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
+    )
+    @mock.patch(
+        "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
+    )
+    @pytest.mark.asyncio
+    async def test_download_ranges(
+        self, mock_grpc_client, mock_cls_async_read_object_stream
+    ):
+        # Arrange
+        mock_mrd = await self._make_mock_mrd(
+            mock_grpc_client, mock_cls_async_read_object_stream
+        )
+        mock_mrd.read_obj_str.send = AsyncMock()
+        mock_mrd.read_obj_str.recv = AsyncMock()
+        mock_mrd.read_obj_str.recv.return_value = _storage_v2.BidiReadObjectResponse(
+            object_data_ranges=[
+                _storage_v2.ObjectRangeData(
+                    checksummed_data=_storage_v2.ChecksummedData(
+                        content=b"these_are_18_chars", crc32c=123
+                    ),
+                    range_end=True,
+                    read_range=_storage_v2.ReadRange(
+                        read_offset=0, read_length=18, read_id=0
+                    ),
+                )
+            ],
+        )
+
+        # Act
+        buffer = BytesIO()
+        results = await mock_mrd.download_ranges([(0, 18, buffer)])
+
+        # Assert
+        mock_mrd.read_obj_str.send.assert_called_once_with(
+            _storage_v2.BidiReadObjectRequest(
+                read_ranges=[
+                    _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=0)
+                ]
+            )
+        )
+        assert len(results) == 1
+        assert results[0].bytes_requested == 18
+        assert results[0].bytes_written == 18
+        assert buffer.getvalue() == b"these_are_18_chars"
+
+    def create_read_ranges(self, num_ranges):
+        # same (start_byte, bytes_to_read, writeable_buffer) tuples that
+        # ``download_ranges`` accepts
+        return [(i, 1, BytesIO()) for i in range(num_ranges)]
+
+    @mock.patch(
+        "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
     )
+    @pytest.mark.asyncio
+    async def test_downloading_ranges_with_more_than_1000_should_throw_error(
+        self, mock_grpc_client
+    ):
+        # Arrange
+        mrd = AsyncMultiRangeDownloader(
+            mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
+        )
+
+        # Act + Assert
+        with pytest.raises(ValueError) as exc:
+            await mrd.download_ranges(self.create_read_ranges(1001))
 
-    with pytest.raises(NotImplementedError):
-        await mrd.download_ranges([(0, 100, BytesIO())])
+        # Assert
+        assert (
+            str(exc.value)
+            == "Invalid input - length of read_ranges cannot be more than 1000"
+        )