Skip to content

Commit ce578ff

Browse files
polar3130xuanyang15
authored andcommitted
feat(apigee): allow injecting credentials into ApigeeLlm
Merge #4722 Close #4721 Co-authored-by: Xuan Yang <xygoogle@google.com> COPYBARA_INTEGRATE_REVIEW=#4722 from polar3130:feat/apigee-llm-userinfo-email-scope f746117 PiperOrigin-RevId: 910144731
1 parent 69fa777 commit ce578ff

2 files changed

Lines changed: 75 additions & 0 deletions

File tree

src/google/adk/models/apigee_llm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .llm_response import LlmResponse
4141

4242
if TYPE_CHECKING:
43+
from google.auth.credentials import Credentials
4344
from google.genai import Client
4445

4546
from .llm_request import LlmRequest
@@ -92,6 +93,7 @@ def __init__(
9293
custom_headers: dict[str, str] | None = None,
9394
retry_options: Optional[types.HttpRetryOptions] = None,
9495
api_type: ApiType | str = ApiType.UNKNOWN,
96+
credentials: Credentials | None = None,
9597
):
9698
"""Initializes the Apigee LLM backend.
9799
@@ -123,6 +125,11 @@ def __init__(
123125
authorization headers in Vertex AI and Gemini API calls.
124126
retry_options: Allow google-genai to retry failed responses.
125127
api_type: The type of API to use. One of `ApiType` or string.
128+
credentials: Optional google-auth credentials passed through to the
129+
underlying `genai.Client`. Use this when the Apigee proxy requires
130+
additional OAuth scopes (e.g., `userinfo.email` for tokeninfo-based
131+
caller identification). When omitted, the default `genai.Client`
132+
authentication flow is used.
126133
""" # fmt: skip
127134

128135
super().__init__(model=model, retry_options=retry_options)
@@ -165,6 +172,7 @@ def __init__(
165172
)
166173
self._custom_headers = custom_headers or {}
167174
self._user_agent = f'google-adk/{adk_version.__version__}'
175+
self._credentials = credentials
168176

169177
@classmethod
170178
@override
@@ -239,6 +247,8 @@ def api_client(self) -> Client:
239247
if self._isvertexai:
240248
kwargs_for_client['project'] = self._project
241249
kwargs_for_client['location'] = self._location
250+
if self._credentials is not None:
251+
kwargs_for_client['credentials'] = self._credentials
242252

243253
return Client(
244254
http_options=http_options,

tests/unittests/models/test_apigee_llm.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,71 @@ def test_parse_response_usage_metadata():
651651
assert llm_response.usage_metadata.thoughts_token_count == 4
652652

653653

654+
@pytest.mark.asyncio
655+
@mock.patch('google.genai.Client')
656+
async def test_api_client_passes_credentials_when_provided(
657+
mock_client_constructor, llm_request
658+
):
659+
"""Tests that credentials passed to __init__ are forwarded to genai.Client."""
660+
mock_credentials = mock.Mock()
661+
662+
mock_client_instance = mock.Mock()
663+
mock_client_instance.aio.models.generate_content = AsyncMock(
664+
return_value=types.GenerateContentResponse(
665+
candidates=[
666+
types.Candidate(
667+
content=Content(
668+
parts=[Part.from_text(text='Test response')],
669+
role='model',
670+
)
671+
)
672+
]
673+
)
674+
)
675+
mock_client_constructor.return_value = mock_client_instance
676+
677+
apigee_llm = ApigeeLlm(
678+
model=APIGEE_GEMINI_MODEL_ID,
679+
proxy_url=PROXY_URL,
680+
credentials=mock_credentials,
681+
)
682+
_ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]
683+
684+
_, kwargs = mock_client_constructor.call_args
685+
assert kwargs['credentials'] is mock_credentials
686+
687+
688+
@pytest.mark.asyncio
689+
@mock.patch('google.genai.Client')
690+
async def test_api_client_omits_credentials_when_not_provided(
691+
mock_client_constructor, llm_request
692+
):
693+
"""Tests that credentials kwarg is not forwarded when not supplied."""
694+
mock_client_instance = mock.Mock()
695+
mock_client_instance.aio.models.generate_content = AsyncMock(
696+
return_value=types.GenerateContentResponse(
697+
candidates=[
698+
types.Candidate(
699+
content=Content(
700+
parts=[Part.from_text(text='Test response')],
701+
role='model',
702+
)
703+
)
704+
]
705+
)
706+
)
707+
mock_client_constructor.return_value = mock_client_instance
708+
709+
apigee_llm = ApigeeLlm(
710+
model=APIGEE_GEMINI_MODEL_ID,
711+
proxy_url=PROXY_URL,
712+
)
713+
_ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]
714+
715+
_, kwargs = mock_client_constructor.call_args
716+
assert 'credentials' not in kwargs
717+
718+
654719
def test_parse_response_with_refusal():
655720
"""Tests that CompletionsHTTPClient parses refusal correctly."""
656721
client = CompletionsHTTPClient(base_url='http://test')

0 commit comments

Comments
 (0)