Skip to content

Commit 03b4d17

Browse files
fix(resolver): make resolver cache thread-safe with per-identifier locking
Add per-identifier locks to _find_cached_candidates() and return defensive copies from _get_cached_candidates() to prevent concurrent threads from corrupting cached candidate lists during parallel builds. A single global lock would serialize all resolution work, so a per-identifier scheme is used instead — threads resolving different packages proceed concurrently while threads resolving the same package wait for the first to populate the cache. Closes: #1024 Co-Authored-By: Claude <claude@anthropic.com> Signed-off-by: Lalatendu Mohanty <lmohanty@redhat.com>
1 parent b5df8e2 commit 03b4d17

2 files changed

Lines changed: 150 additions & 21 deletions

File tree

src/fromager/resolver.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
import os
1212
import re
13+
import threading
1314
import typing
1415
from collections.abc import Iterable
1516
from operator import attrgetter
@@ -423,6 +424,8 @@ def get_project_from_pypi(
423424

424425
class BaseProvider(ExtrasProvider):
425426
resolver_cache: typing.ClassVar[ResolverCache] = {}
427+
_cache_locks: typing.ClassVar[dict[str, threading.Lock]] = {}
428+
_meta_lock: typing.ClassVar[threading.Lock] = threading.Lock()
426429
provider_description: typing.ClassVar[str]
427430

428431
def __init__(
@@ -552,22 +555,62 @@ def get_dependencies(self, candidate: Candidate) -> list[Requirement]:
552555
# return candidate.dependencies
553556
return []
554557

558+
def _get_identifier_lock(self, identifier: str) -> threading.Lock:
559+
"""Get or create a per-identifier lock for thread-safe cache access.
560+
561+
Uses a short-lived meta-lock to protect the lock dict itself.
562+
The per-identifier lock ensures threads resolving different packages
563+
proceed concurrently, while threads resolving the same package
564+
wait for the first to populate the cache.
565+
"""
566+
with self._meta_lock:
567+
if identifier not in self._cache_locks:
568+
self._cache_locks[identifier] = threading.Lock()
569+
return self._cache_locks[identifier]
570+
555571
def _get_cached_candidates(self, identifier: str) -> list[Candidate]:
556-
"""Get list of cached candidates for identifier and provider
572+
"""Get a copy of cached candidates for identifier and provider.
557573
558574
The method always returns a list. If the cache did not have an entry
559-
before, a new empty list is stored in the cache and returned to the
560-
caller. The caller can mutate the list in place to update the cache.
575+
before, a new empty list is stored in the cache. A copy is returned
576+
so callers cannot accidentally corrupt the cache.
577+
578+
Must be called under the per-identifier lock from _get_identifier_lock.
561579
"""
562580
cls = type(self)
563581
provider_cache = cls.resolver_cache.setdefault(identifier, {})
564582
candidate_cache = provider_cache.setdefault((cls, self.cache_key), [])
565-
return candidate_cache
583+
return list(candidate_cache)
584+
585+
def _set_cached_candidates(
586+
self, identifier: str, candidates: list[Candidate]
587+
) -> None:
588+
"""Store candidates in the cache for identifier and provider.
589+
590+
Must be called under the per-identifier lock from _get_identifier_lock.
591+
"""
592+
cls = type(self)
593+
provider_cache = cls.resolver_cache.setdefault(identifier, {})
594+
provider_cache[(cls, self.cache_key)] = list(candidates)
566595

567596
def _find_cached_candidates(self, identifier: str) -> Candidates:
568-
"""Find candidates with caching"""
569-
cached_candidates: list[Candidate] = []
570-
if self.use_cache_candidates:
597+
"""Find candidates with caching.
598+
599+
Uses a per-identifier lock so threads resolving different packages
600+
proceed concurrently, while threads resolving the same package
601+
wait for the first to populate the cache.
602+
"""
603+
if not self.use_cache_candidates:
604+
candidates = list(self.find_candidates(identifier))
605+
logger.debug(
606+
"%s: got %i unfiltered candidates, ignoring cache",
607+
identifier,
608+
len(candidates),
609+
)
610+
return candidates
611+
612+
lock = self._get_identifier_lock(identifier)
613+
with lock:
571614
cached_candidates = self._get_cached_candidates(identifier)
572615
if cached_candidates:
573616
logger.debug(
@@ -576,22 +619,15 @@ def _find_cached_candidates(self, identifier: str) -> Candidates:
576619
len(cached_candidates),
577620
)
578621
return cached_candidates
579-
candidates = list(self.find_candidates(identifier))
580-
if self.use_cache_candidates:
581-
# mutate list object in-place
582-
cached_candidates[:] = candidates
622+
623+
candidates = list(self.find_candidates(identifier))
624+
self._set_cached_candidates(identifier, candidates)
583625
logger.debug(
584626
"%s: cache %i unfiltered candidates",
585627
identifier,
586628
len(candidates),
587629
)
588-
else:
589-
logger.debug(
590-
"%s: got %i unfiltered candidates, ignoring cache",
591-
identifier,
592-
len(candidates),
593-
)
594-
return candidates
630+
return candidates
595631

596632
def _get_no_match_error_message(
597633
self, identifier: str, requirements: RequirementsMap

tests/test_resolver.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import datetime
22
import re
3+
import threading
4+
import time
35
import typing
46

57
import pytest
@@ -11,6 +13,7 @@
1113

1214
from fromager import constraints, resolver
1315
from fromager.__main__ import main as fromager
16+
from fromager.candidate import Candidate
1417

1518
_hydra_core_simple_response = """
1619
<!DOCTYPE html>
@@ -153,10 +156,8 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None:
153156
resolver_cache = resolver.BaseProvider.resolver_cache
154157
assert req.name in resolver_cache
155158
assert (resolver.PyPIProvider, provider.cache_key) in resolver_cache[req.name]
156-
# mutated in place
157-
assert provider._get_cached_candidates(req.name) is req_cache
159+
# _get_cached_candidates returns a defensive copy, not the same object
158160
assert len(provider._get_cached_candidates(req.name)) == 7
159-
assert len(req_cache) == 7
160161

161162

162163
def test_provider_cache_key_gitlab(gitlab_decile_resolver: typing.Any) -> None:
@@ -1278,3 +1279,95 @@ def test_cli_package_resolver(
12781279
assert "- PyPI versions: 1.2.2, 1.3.1+local, 1.3.2, 2.0.0a1" in result.stdout
12791280
assert "- only wheels on PyPI: 1.3.1+local, 2.0.0a1" in result.stdout
12801281
assert "- missing from Fromager: 1.3.1+local, 2.0.0a1" in result.stdout
1282+
1283+
1284+
def _make_candidate(name: str, version: str) -> Candidate:
1285+
"""Create a minimal Candidate for testing."""
1286+
return Candidate(
1287+
name=name, version=Version(version), url="https://example.com", is_sdist=False
1288+
)
1289+
1290+
1291+
class _StubProvider(resolver.BaseProvider):
1292+
"""Minimal BaseProvider subclass for cache tests."""
1293+
1294+
provider_description = "stub"
1295+
1296+
@property
1297+
def cache_key(self) -> str:
1298+
return "stub-key"
1299+
1300+
def find_candidates(self, identifier: str) -> list[Candidate]:
1301+
return []
1302+
1303+
1304+
def test_get_cached_candidates_returns_defensive_copy() -> None:
1305+
"""Mutating the list returned by _get_cached_candidates must not corrupt the cache."""
1306+
provider = _StubProvider()
1307+
identifier = "test-pkg"
1308+
1309+
# Seed the cache directly so the test doesn't depend on the aliasing bug
1310+
resolver.BaseProvider.resolver_cache[identifier] = {
1311+
(type(provider), provider.cache_key): [_make_candidate("test-pkg", "1.0.0")]
1312+
}
1313+
1314+
# Get candidates again and mutate the returned list
1315+
first = provider._get_cached_candidates(identifier)
1316+
first.append(_make_candidate("test-pkg", "2.0.0"))
1317+
1318+
# The cache should not reflect the caller's mutation
1319+
second = provider._get_cached_candidates(identifier)
1320+
assert len(second) == 1, (
1321+
"_get_cached_candidates should return a defensive copy, "
1322+
"not a direct reference to the internal cache"
1323+
)
1324+
assert second[0].version == Version("1.0.0")
1325+
1326+
1327+
def test_find_cached_candidates_thread_safe() -> None:
1328+
"""Concurrent threads must not bypass the cache and call find_candidates multiple times."""
1329+
call_count = 0
1330+
call_count_lock = threading.Lock()
1331+
1332+
class _SlowProvider(resolver.BaseProvider):
1333+
"""Provider with a slow find_candidates to expose thread races."""
1334+
1335+
provider_description = "slow"
1336+
1337+
@property
1338+
def cache_key(self) -> str:
1339+
return "slow-key"
1340+
1341+
def find_candidates(self, identifier: str) -> list[Candidate]:
1342+
nonlocal call_count
1343+
with call_count_lock:
1344+
call_count += 1
1345+
time.sleep(0.2)
1346+
return [_make_candidate(identifier, "1.0.0")]
1347+
1348+
barrier = threading.Barrier(4)
1349+
1350+
def resolve_in_thread(provider: _SlowProvider, ident: str) -> None:
1351+
barrier.wait(timeout=5)
1352+
list(provider._find_cached_candidates(ident))
1353+
1354+
providers = [_SlowProvider() for _ in range(4)]
1355+
threads = [
1356+
threading.Thread(
1357+
target=resolve_in_thread,
1358+
args=(thread_provider, "shared-pkg"),
1359+
name=f"resolver-{i}",
1360+
)
1361+
for i, thread_provider in enumerate(providers)
1362+
]
1363+
1364+
for t in threads:
1365+
t.start()
1366+
for t in threads:
1367+
t.join(timeout=10)
1368+
1369+
assert call_count == 1, (
1370+
f"find_candidates() was called {call_count} times; expected 1. "
1371+
"Without thread-safe caching, multiple threads bypass the cache "
1372+
"and redundantly call find_candidates()."
1373+
)

0 commit comments

Comments
 (0)