Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions src/backend/common/database/cosmosdb.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
from datetime import datetime, timezone
from typing import Dict, List, Optional
from uuid import UUID, uuid4

from azure.cosmos.aio import CosmosClient
from azure.cosmos.aio._database import DatabaseProxy
from azure.cosmos.exceptions import (
CosmosResourceExistsError
CosmosResourceExistsError,
CosmosResourceNotFoundError,
)

from common.database.database_base import DatabaseBase
Expand Down Expand Up @@ -85,9 +87,26 @@ async def create_batch(self, user_id: str, batch_id: UUID) -> BatchRecord:
await self.batch_container.create_item(body=batch.dict())
return batch
except CosmosResourceExistsError:
self.logger.info(f"Batch with ID {batch_id} already exists")
batchexists = await self.get_batch(user_id, str(batch_id))
return batchexists
self.logger.info("Batch already exists, reading existing record", batch_id=str(batch_id))
# Retry read with backoff to handle replication lag after 409 conflict
for attempt in range(3):
try:
batchexists = await self.batch_container.read_item(
item=str(batch_id), partition_key=str(batch_id)
)
if batchexists.get("user_id") != user_id:
self.logger.error("Batch belongs to a different user", batch_id=str(batch_id))
raise CosmosResourceNotFoundError(message="Batch not found")
self.logger.info("Returning existing batch record", batch_id=str(batch_id))
return BatchRecord.fromdb(batchexists)
except CosmosResourceNotFoundError:
if attempt < 2:
self.logger.info("Batch read returned 404 after conflict, retrying", batch_id=str(batch_id), attempt=attempt + 1)
await asyncio.sleep(0.5 * (attempt + 1))
else:
raise RuntimeError(
f"Batch {batch_id} already exists but could not be read after retries"
)

except Exception as e:
self.logger.error("Failed to create batch", error=str(e))
Expand Down Expand Up @@ -158,7 +177,7 @@ async def get_batch(self, user_id: str, batch_id: str) -> Optional[Dict]:
]
batch = None
async for item in self.batch_container.query_items(
query=query, parameters=params
query=query, parameters=params, partition_key=batch_id
):
batch = item

Expand All @@ -173,7 +192,7 @@ async def get_file(self, file_id: str) -> Optional[Dict]:
params = [{"name": "@file_id", "value": file_id}]
file_entry = None
async for item in self.file_container.query_items(
query=query, parameters=params
query=query, parameters=params, partition_key=file_id
):
file_entry = item
return file_entry
Expand Down Expand Up @@ -209,7 +228,7 @@ async def get_batch_from_id(self, batch_id: str) -> Dict:

batch = None # Store the batch
async for item in self.batch_container.query_items(
query=query, parameters=params
query=query, parameters=params, partition_key=batch_id
):
batch = item # Assign the batch to the variable

Expand Down Expand Up @@ -335,11 +354,14 @@ async def add_file_log(
raise

async def update_batch_entry(
self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int
self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int,
existing_batch: Optional[Dict] = None
):
"""Update batch status."""
"""Update batch status. If existing_batch is provided, skip the re-fetch."""
try:
batch = await self.get_batch(user_id, batch_id)
batch = existing_batch
if batch is None:
batch = await self.get_batch(user_id, batch_id)
if not batch:
raise ValueError("Batch not found")

Expand Down
37 changes: 26 additions & 11 deletions src/backend/common/database/database_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,40 @@

class DatabaseFactory:
_instance: Optional[DatabaseBase] = None
_lock: Optional[asyncio.Lock] = None
_logger = AppLogger("DatabaseFactory")

@staticmethod
def _get_lock() -> asyncio.Lock:
if DatabaseFactory._lock is None:
DatabaseFactory._lock = asyncio.Lock()
return DatabaseFactory._lock

@staticmethod
async def get_database():
if DatabaseFactory._instance is not None:
return DatabaseFactory._instance

async with DatabaseFactory._get_lock():
# Double-check after acquiring the lock
if DatabaseFactory._instance is not None:
return DatabaseFactory._instance
Comment thread
Pavan-Microsoft marked this conversation as resolved.

config = Config() # Create an instance of Config
config = Config() # Create an instance of Config

cosmos_db_client = CosmosDBClient(
endpoint=config.cosmosdb_endpoint,
credential=config.get_azure_credentials(),
database_name=config.cosmosdb_database,
batch_container=config.cosmosdb_batch_container,
file_container=config.cosmosdb_file_container,
log_container=config.cosmosdb_log_container,
)
cosmos_db_client = CosmosDBClient(
endpoint=config.cosmosdb_endpoint,
credential=config.get_azure_credentials(),
database_name=config.cosmosdb_database,
batch_container=config.cosmosdb_batch_container,
file_container=config.cosmosdb_file_container,
log_container=config.cosmosdb_log_container,
)

await cosmos_db_client.initialize_cosmos()
await cosmos_db_client.initialize_cosmos()

return cosmos_db_client
DatabaseFactory._instance = cosmos_db_client
return cosmos_db_client


# Local testing of config and code
Expand Down
9 changes: 6 additions & 3 deletions src/backend/common/services/batch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,9 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
)

# Create file entry
await self.database.add_file(batch_id, file_id, file.filename, blob_path)
file_record = await self.database.get_file(file_id)
file_record_obj = await self.database.add_file(batch_id, file_id, file.filename, blob_path)
file_record_dict = getattr(file_record_obj, "dict", None)
file_record = file_record_dict() if callable(file_record_dict) else file_record_obj

await self.database.add_file_log(
UUID(file_id),
Expand All @@ -307,6 +308,7 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
user_id,
ProcessStatus.READY_TO_PROCESS,
batch["file_count"],
existing_batch=batch,
)
# Return response
return {"batch": batch, "file": file_record}
Expand All @@ -317,7 +319,8 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
batch.file_count = len(files)
batch.updated_at = datetime.utcnow().isoformat()
await self.database.update_batch_entry(
batch_id, user_id, ProcessStatus.READY_TO_PROCESS, batch.file_count
batch_id, user_id, ProcessStatus.READY_TO_PROCESS, batch.file_count,
existing_batch=batch.dict(),
)
# Return response
return {"batch": batch, "file": file_record}
Expand Down
36 changes: 20 additions & 16 deletions src/tests/backend/common/database/cosmosdb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,22 @@ async def test_create_batch_exists(cosmos_db_client, mocker):
user_id = "user_1"
batch_id = uuid4()

# Mock container creation and get_batch
# Mock container creation and read_item
mock_batch_container = mock.MagicMock()
mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)

# Mock the get_batch method
mock_get_batch = AsyncMock(return_value=BatchRecord(
batch_id=batch_id,
user_id=user_id,
file_count=0,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
status=ProcessStatus.READY_TO_PROCESS
))
mocker.patch.object(cosmos_db_client, 'get_batch', mock_get_batch)
# Mock read_item to return the existing batch record
existing_batch = {
"id": str(batch_id),
"batch_id": str(batch_id),
"user_id": user_id,
"file_count": 0,
"created_at": datetime.now(timezone.utc).isoformat(),
"updated_at": datetime.now(timezone.utc).isoformat(),
"status": ProcessStatus.READY_TO_PROCESS,
}
mock_batch_container.read_item = AsyncMock(return_value=existing_batch)

# Call the method
batch = await cosmos_db_client.create_batch(user_id, batch_id)
Expand All @@ -182,7 +183,9 @@ async def test_create_batch_exists(cosmos_db_client, mocker):
assert batch.user_id == user_id
assert batch.status == ProcessStatus.READY_TO_PROCESS

mock_get_batch.assert_called_once_with(user_id, str(batch_id))
mock_batch_container.read_item.assert_called_once_with(
item=str(batch_id), partition_key=str(batch_id)
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -404,7 +407,7 @@ async def test_get_batch(cosmos_db_client, mocker):
}

# We define the async generator function that will yield the expected batch
async def mock_query_items(query, parameters):
async def mock_query_items(query, parameters, **kwargs):
yield expected_batch

# Assign the async generator to query_items mock
Expand All @@ -422,6 +425,7 @@ async def mock_query_items(query, parameters):
{"name": "@batch_id", "value": batch_id},
{"name": "@user_id", "value": user_id},
],
partition_key=batch_id,
)


Expand Down Expand Up @@ -468,8 +472,8 @@ async def test_get_file(cosmos_db_client, mocker):
"blob_path": "/path/to/file"
}

# We define the async generator function that will yield the expected batch
async def mock_query_items(query, parameters):
# We define the async generator function that will yield the expected file
async def mock_query_items(query, parameters, **kwargs):
yield expected_file

# Assign the async generator to query_items mock
Expand Down Expand Up @@ -594,7 +598,7 @@ async def test_get_batch_from_id(cosmos_db_client, mocker):
}

# Define the async generator function that will yield the expected batch
async def mock_query_items(query, parameters):
async def mock_query_items(query, parameters, **kwargs):
yield expected_batch

# Assign the async generator to query_items mock
Expand Down
Loading