Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from google.cloud import bigquery
from packaging.version import parse as parse_version
from pydantic import BaseModel
from requests.adapters import HTTPAdapter, Retry
from sqlalchemy.engine import URL, Connection, create_engine, make_url
from sqlalchemy.exc import ResourceClosedError

Expand Down Expand Up @@ -263,13 +264,28 @@ class ExecuteSqlError(Exception):
)


def _generate_temporary_credentials(integration_id):
def _create_retry_session() -> requests.Session:
"""Create a requests session with retry on 5xx for POST requests."""
session = requests.Session()
retries = Retry(
total=3,
backoff_factor=0.5,
status_forcelist=[500, 502, 503, 504],
allowed_methods=["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "TRACE"],
)
session.mount("http://", HTTPAdapter(max_retries=retries))
session.mount("https://", HTTPAdapter(max_retries=retries))
return session


def _generate_temporary_credentials(integration_id) -> tuple[str, str]:
url = get_absolute_userpod_api_url(f"integrations/credentials/{integration_id}")

# Add project credentials in detached mode
headers = get_project_auth_headers()

response = requests.post(url, timeout=10, headers=headers)
session = _create_retry_session()
response = session.post(url, timeout=10, headers=headers)

response.raise_for_status()

Expand All @@ -291,7 +307,8 @@ def _get_federated_auth_credentials(
headers = get_project_auth_headers()
headers["UserPodAuthContextToken"] = user_pod_auth_context_token

response = requests.post(url, timeout=10, headers=headers)
session = _create_retry_session()
response = session.post(url, timeout=10, headers=headers)

response.raise_for_status()

Expand Down
248 changes: 244 additions & 4 deletions tests/unit/test_sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,9 @@ def test_all_dataframes_serialize_to_parquet(self, key, df):
class TestFederatedAuth(unittest.TestCase):
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
@mock.patch("deepnote_toolkit.sql.sql_execution.requests.post")
@mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session")
def test_get_federated_auth_credentials_returns_validated_response(
self, mock_post, mock_get_url, mock_get_headers
self, mock_create_session, mock_get_url, mock_get_headers
):
"""Test that _get_federated_auth_credentials properly validates and returns response data."""
from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials
Expand All @@ -603,12 +603,14 @@ def test_get_federated_auth_credentials_returns_validated_response(
mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-integration-id"
mock_get_headers.return_value = {"Authorization": "Bearer project-token"}

mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.json.return_value = {
"integrationType": "trino",
"accessToken": "test-access-token-123",
}
mock_post.return_value = mock_response
mock_session.post.return_value = mock_response
mock_create_session.return_value = mock_session

# Call the function
result = _get_federated_auth_credentials(
Expand All @@ -621,7 +623,7 @@ def test_get_federated_auth_credentials_returns_validated_response(
)

# Verify headers include both project auth and user pod auth context token
mock_post.assert_called_once_with(
mock_session.post.assert_called_once_with(
"https://api.example.com/integrations/federated-auth-token/test-integration-id",
timeout=10,
headers={
Expand Down Expand Up @@ -1019,3 +1021,241 @@ def test_databricks_connector_dialect_alias_is_registered(self):

self.assertEqual(url.drivername, "databricks+connector")
self.assertIsNotNone(dialect_cls)


class TestCreateRetrySession(unittest.TestCase):
"""Tests that exercise the real urllib3 retry loop by mocking at the
connection level (``HTTPConnectionPool._make_request``) rather than
replacing ``_create_retry_session``. This lets the ``Retry`` adapter
actually fire retries on 5xx responses.
"""

def test_create_retry_session_configuration(self):
"""Verify the retry session is wired with the expected parameters."""
from deepnote_toolkit.sql.sql_execution import _create_retry_session

session = _create_retry_session()

for prefix in ("http://", "https://"):
adapter = session.get_adapter(f"{prefix}example.com")
retry = adapter.max_retries

self.assertEqual(retry.total, 3)
self.assertEqual(retry.backoff_factor, 0.5)
self.assertEqual(set(retry.status_forcelist), {500, 502, 503, 504})
self.assertIn("POST", retry.allowed_methods)

# -- _generate_temporary_credentials ------------------------------------

@mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
@mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_generate_credentials_retries_on_5xx_then_succeeds(
self,
mock_get_url,
mock_get_headers,
mock_make_request,
mock_retry_sleep,
):
"""Two 5xx failures followed by a 200 - the retry loop must
transparently retry and ultimately return valid credentials."""
from urllib3 import HTTPResponse as Urllib3Response

from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials

mock_get_url.return_value = (
"https://api.example.com/integrations/credentials/test-id"
)
mock_get_headers.return_value = {"Authorization": "Bearer token"}

success_body = json.dumps({"username": "user", "password": "pass"}).encode()
mock_make_request.side_effect = [
Urllib3Response(
body=io.BytesIO(b"Internal Server Error"),
status=500,
headers={},
preload_content=False,
),
Urllib3Response(
body=io.BytesIO(b"Bad Gateway"),
status=502,
headers={},
preload_content=False,
),
Urllib3Response(
body=io.BytesIO(success_body),
status=200,
headers={"Content-Type": "application/json"},
preload_content=False,
),
]

result = _generate_temporary_credentials("test-id")

self.assertEqual(result, ("user", "pass"))
self.assertEqual(mock_make_request.call_count, 3)
self.assertEqual(mock_retry_sleep.call_count, 2)

@mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
@mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_generate_credentials_exhausts_retries_on_persistent_5xx(
self,
mock_get_url,
mock_get_headers,
mock_make_request,
mock_retry_sleep,
):
"""All 4 attempts (1 original + 3 retries) return 500 -
must raise ``RetryError``."""
import requests
from urllib3 import HTTPResponse as Urllib3Response

from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials

mock_get_url.return_value = (
"https://api.example.com/integrations/credentials/test-id"
)
mock_get_headers.return_value = {"Authorization": "Bearer token"}

mock_make_request.side_effect = [
Urllib3Response(
body=io.BytesIO(b"Server Error"),
status=500,
headers={},
preload_content=False,
)
for _ in range(4)
]

with self.assertRaises(requests.exceptions.RetryError):
_generate_temporary_credentials("test-id")

self.assertEqual(mock_make_request.call_count, 4)
self.assertEqual(mock_retry_sleep.call_count, 3)

@mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
@mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_generate_credentials_no_retry_on_4xx(
self,
mock_get_url,
mock_get_headers,
mock_make_request,
mock_retry_sleep,
):
"""A 400 is not in the retry status list - must fail immediately
without retrying."""
import requests
from urllib3 import HTTPResponse as Urllib3Response

from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials

mock_get_url.return_value = (
"https://api.example.com/integrations/credentials/test-id"
)
mock_get_headers.return_value = {"Authorization": "Bearer token"}

mock_make_request.side_effect = [
Urllib3Response(
body=io.BytesIO(b"Bad Request"),
status=400,
headers={},
preload_content=False,
),
]

with self.assertRaises(requests.exceptions.HTTPError):
_generate_temporary_credentials("test-id")

self.assertEqual(mock_make_request.call_count, 1)
mock_retry_sleep.assert_not_called()

# -- _get_federated_auth_credentials ------------------------------------

@mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
@mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_federated_auth_retries_on_5xx_then_succeeds(
self,
mock_get_url,
mock_get_headers,
mock_make_request,
mock_retry_sleep,
):
"""A 503 followed by a 200 - retry loop must recover and return
valid ``FederatedAuthResponseData``."""
from urllib3 import HTTPResponse as Urllib3Response

from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials

mock_get_url.return_value = (
"https://api.example.com/integrations/federated-auth-token/test-id"
)
mock_get_headers.return_value = {"Authorization": "Bearer token"}

success_body = json.dumps(
{"integrationType": "trino", "accessToken": "test-token"}
).encode()
mock_make_request.side_effect = [
Urllib3Response(
body=io.BytesIO(b"Service Unavailable"),
status=503,
headers={},
preload_content=False,
),
Urllib3Response(
body=io.BytesIO(success_body),
status=200,
headers={"Content-Type": "application/json"},
preload_content=False,
),
]

result = _get_federated_auth_credentials("test-id", "auth-context-token")

self.assertEqual(result.integrationType, "trino")
self.assertEqual(result.accessToken, "test-token")
self.assertEqual(mock_make_request.call_count, 2)
self.assertEqual(mock_retry_sleep.call_count, 1)

@mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
@mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_federated_auth_exhausts_retries_on_persistent_5xx(
self,
mock_get_url,
mock_get_headers,
mock_make_request,
mock_retry_sleep,
):
"""All 4 attempts return 504 - must raise ``RetryError``."""
import requests
from urllib3 import HTTPResponse as Urllib3Response

from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials

mock_get_url.return_value = (
"https://api.example.com/integrations/federated-auth-token/test-id"
)
mock_get_headers.return_value = {"Authorization": "Bearer token"}

mock_make_request.side_effect = [
Urllib3Response(
body=io.BytesIO(b"Gateway Timeout"),
status=504,
headers={},
preload_content=False,
)
for _ in range(4)
]

with self.assertRaises(requests.exceptions.RetryError):
_get_federated_auth_credentials("test-id", "auth-context-token")

self.assertEqual(mock_make_request.call_count, 4)
Comment on lines +1226 to +1261
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Missing assertion on mock_retry_sleep.

mock_retry_sleep is declared but unused. The parallel test for _generate_temporary_credentials asserts mock_retry_sleep.call_count == 3. Add the same assertion here.

🔧 Proposed fix
         with self.assertRaises(requests.exceptions.RetryError):
             _get_federated_auth_credentials("test-id", "auth-context-token")

         self.assertEqual(mock_make_request.call_count, 4)
+        self.assertEqual(mock_retry_sleep.call_count, 3)
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 1235-1235: Unused method argument: mock_retry_sleep

(ARG002)


[warning] 1258-1258: Use pytest.raises instead of unittest-style assertRaises

Replace assertRaises with pytest.raises

(PT027)


[warning] 1261-1261: Use a regular assert instead of unittest-style assertEqual

Replace assertEqual(...) with assert ...

(PT009)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/test_sql_execution.py` around lines 1226 - 1261, The test
test_federated_auth_exhausts_retries_on_persistent_5xx declares mock_retry_sleep
but never asserts it; update this test to assert that mock_retry_sleep was
invoked the expected number of times (3 sleeps for 4 failed attempts) after
calling _get_federated_auth_credentials("test-id", "auth-context-token") and
after the RetryError assertion; reference the mock object name mock_retry_sleep
and the test function test_federated_auth_exhausts_retries_on_persistent_5xx
when adding mock_retry_sleep.assertEqual(mock_retry_sleep.call_count, 3) (or
equivalent assert_called_count).

Loading