Skip to content

Commit a5e9eb7

Browse files
committed
Small fixes and feedback
1 parent 079fbdb commit a5e9eb7

3 files changed

Lines changed: 69 additions & 10 deletions

File tree

msal/oauth2cli/oauth2.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,17 +176,23 @@ def __init__(
176176

177177
@staticmethod
178178
def _accepts_context(func):
179-
"""Check if a callable accepts at least one positional argument."""
179+
"""Check if a callable requires at least one positional argument.
180+
181+
Returns True only when the callable has a positional parameter
182+
**without** a default value. This ensures that legacy zero-arg
183+
callables — including ``lambda token=token: token`` patterns
184+
where every positional param has a default — are still invoked
185+
with no arguments.
186+
"""
180187
try:
181188
sig = inspect.signature(func)
182-
params = [
183-
p for p in sig.parameters.values()
189+
for p in sig.parameters.values():
184190
if p.kind in (
185191
inspect.Parameter.POSITIONAL_ONLY,
186192
inspect.Parameter.POSITIONAL_OR_KEYWORD,
187-
)
188-
]
189-
return len(params) >= 1
193+
) and p.default is inspect.Parameter.empty:
194+
return True
195+
return False
190196
except (ValueError, TypeError):
191197
return False # Signature not inspectable; treat as zero-arg
192198

msal/token_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def make_clean_copy(dictionary, sensitive_fields): # Masks sensitive info
306306
event,
307307
data=make_clean_copy(event.get("data", {}), (
308308
"password", "client_secret", "refresh_token", "assertion",
309+
"user_federated_identity_credential",
309310
)),
310311
response=make_clean_copy(event.get("response", {}), (
311312
"id_token_claims", # Provided by broker

tests/test_application.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
22
# so this test_application file contains only unit tests without dependency.
3+
import base64
34
import json
45
import logging
56
import sys
@@ -1120,9 +1121,6 @@ def _build_user_fic_response(uid="user_oid", utid="tenant_id", access_token="use
11201121
})
11211122

11221123

1123-
import base64
1124-
1125-
11261124
@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK)
11271125
class TestUserFicProtocol(unittest.TestCase):
11281126
"""Tests that acquire_token_by_user_federated_identity_credential sends correct POST body."""
@@ -1265,6 +1263,34 @@ def mock_post(url, headers=None, data=None, *args, **kwargs):
12651263
self.assertIn("access_token", silent_result)
12661264
self.assertEqual("cached_fic_at", silent_result["access_token"])
12671265

1266+
def test_oid_path_token_stored_and_retrievable_via_silent(self):
1267+
"""user_fic with user_object_id should cache and retrieve like username."""
1268+
app = self._make_app()
1269+
1270+
def mock_post(url, headers=None, data=None, *args, **kwargs):
1271+
return MinimalResponse(status_code=200, text=_build_user_fic_response(
1272+
uid="user_oid", utid="tenant_id", access_token="oid_fic_at"))
1273+
1274+
result = app.acquire_token_by_user_federated_identity_credential(
1275+
["https://graph.microsoft.com/.default"],
1276+
assertion="t2", user_object_id="user_oid", post=mock_post)
1277+
self.assertIn("access_token", result)
1278+
1279+
# Verify no ext_cache_key on cached token
1280+
at_entries = list(app.token_cache.search(
1281+
msal.TokenCache.CredentialType.ACCESS_TOKEN, query={}))
1282+
self.assertTrue(len(at_entries) > 0, "AT should be cached")
1283+
self.assertNotIn("ext_cache_key", at_entries[0],
1284+
"OID-path user_fic tokens should NOT have ext_cache_key")
1285+
1286+
# Verify account and silent retrieval
1287+
accounts = app.get_accounts()
1288+
self.assertTrue(len(accounts) > 0)
1289+
silent_result = app.acquire_token_silent(
1290+
["https://graph.microsoft.com/.default"], account=accounts[0])
1291+
self.assertIn("access_token", silent_result)
1292+
self.assertEqual("oid_fic_at", silent_result["access_token"])
1293+
12681294

12691295
@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK)
12701296
class TestUserFicInputValidation(unittest.TestCase):
@@ -1391,4 +1417,30 @@ def buggy_callback(context):
13911417
["scope"],
13921418
post=lambda url, **kwargs: MinimalResponse(
13931419
status_code=200, text=json.dumps({
1394-
"access_token": "an_at", "expires_in": 3600})))
1420+
"access_token": "an_at", "expires_in": 3600})))
1421+
1422+
def test_lambda_with_defaulted_param_treated_as_zero_arg(self):
1423+
"""A lambda like ``lambda token=token: token`` should be treated as
1424+
zero-arg because all its positional params have defaults."""
1425+
captured_value = "my_assertion_value"
1426+
assertion_callable = lambda token=captured_value: token # noqa: E731
1427+
1428+
app = ConfidentialClientApplication(
1429+
"client_id",
1430+
client_credential={"client_assertion": assertion_callable},
1431+
authority="https://login.microsoftonline.com/my_tenant")
1432+
1433+
captured_data = {}
1434+
def mock_post(url, headers=None, data=None, *args, **kwargs):
1435+
captured_data.update(data or {})
1436+
return MinimalResponse(
1437+
status_code=200, text=json.dumps({
1438+
"access_token": "an_at", "expires_in": 3600}))
1439+
1440+
result = app.acquire_token_for_client(["scope"], post=mock_post)
1441+
self.assertIn("access_token", result)
1442+
# The assertion should be the string value, not a dict context object
1443+
self.assertEqual(
1444+
captured_value, captured_data.get("client_assertion"),
1445+
"Lambda with defaulted params should return its default value, "
1446+
"not receive a context dict")

0 commit comments

Comments
 (0)