Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions src/fromager/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import os
import re
import threading
import typing
from collections.abc import Iterable
from operator import attrgetter
Expand Down Expand Up @@ -422,7 +423,21 @@ def get_project_from_pypi(


class BaseProvider(ExtrasProvider):
"""Base class for Fromager's dependency resolver (resolvelib + extras).

Subclasses implement ``find_candidates``, ``cache_key``, and
``provider_description`` to list versions from PyPI, a version map, etc.

Candidate lists are cached per package in one global dict, with a lock per
package so parallel work on different packages does not clash.

``find_matches`` keeps only versions that fit the requirements and
constraints, then picks newest first.
"""

resolver_cache: typing.ClassVar[ResolverCache] = {}
_cache_locks: typing.ClassVar[dict[str, threading.Lock]] = {}
_meta_lock: typing.ClassVar[threading.Lock] = threading.Lock()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_cache_locks grows unbounded as new identifiers are resolved. Locks are never cleaned up.

We will potentially see long-running processes or large dependency graphs soon when multiple version bootstrap is enabled. This will accumulate locks for every package ever resolved.

Can we add lock cleanup for clear_cache()? Something like:

def clear_cache(cls, identifier: str | None = None) -> None:
      """Clear global resolver cache and associated locks."""
      with cls._meta_lock:
          if identifier is None:
              cls.resolver_cache.clear()
              cls._cache_locks.clear()
          else:
              canon_name = canonicalize_name(identifier)
cls.resolver_cache.pop(canon_name, None)  # Entry may not exist
              cls._cache_locks.pop(canon_name, None)  # Lock may not exist

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point; we can add the lock cleanup. I had also considered it, but the existing code did not create a situation where this becomes an issue. I will take a deeper look.

provider_description: typing.ClassVar[str]

def __init__(
Expand Down Expand Up @@ -552,46 +567,81 @@ def get_dependencies(self, candidate: Candidate) -> list[Requirement]:
# return candidate.dependencies
return []

def _get_cached_candidates(self, identifier: str) -> list[Candidate]:
"""Get list of cached candidates for identifier and provider
def _get_identifier_lock(self, identifier: str) -> threading.Lock:
    """Return the lock guarding cache access for *identifier*, creating it on demand.

    A short-lived meta-lock protects the lock registry itself. The
    per-identifier lock it hands back lets threads resolving different
    packages proceed concurrently, while threads resolving the same
    package wait for the first one to populate the cache.
    """
    with self._meta_lock:
        try:
            # Fast path: the lock already exists for this identifier.
            return self._cache_locks[identifier]
        except KeyError:
            fresh_lock = threading.Lock()
            self._cache_locks[identifier] = fresh_lock
            return fresh_lock

def _get_cached_candidates(self, identifier: str) -> list[Candidate] | None:
    """Get a copy of cached candidates for identifier and provider.

    Returns None if no entry exists in the cache, or a copy of the cached
    list (which may be empty). A copy is returned so callers cannot
    accidentally corrupt the cache.

    Must be called under the per-identifier lock from _get_identifier_lock.
    """
    cls = type(self)
    # A pure read must not mutate the cache: use .get() instead of
    # .setdefault(), which would insert an empty dict for every identifier
    # that is merely checked, growing the class-level cache unbounded.
    provider_cache = cls.resolver_cache.get(identifier)
    if provider_cache is None:
        return None
    candidate_cache = provider_cache.get((cls, self.cache_key))
    if candidate_cache is None:
        return None
    # Defensive copy: never hand callers a reference to the internal list.
    return list(candidate_cache)

def _set_cached_candidates(
    self, identifier: str, candidates: list[Candidate]
) -> None:
    """Store a snapshot of candidates for identifier and provider.

    Must be called under the per-identifier lock from _get_identifier_lock.
    """
    provider_cls = type(self)
    per_identifier = provider_cls.resolver_cache.setdefault(identifier, {})
    # Store a copy so later mutation of the caller's list cannot leak
    # into the cache.
    per_identifier[(provider_cls, self.cache_key)] = list(candidates)

def _find_cached_candidates(self, identifier: str) -> Candidates:
    """Find candidates, consulting the per-provider cache when enabled.

    Uses a per-identifier lock so threads resolving different packages
    proceed concurrently, while threads resolving the same package
    wait for the first to populate the cache.
    """
    if not self.use_cache_candidates:
        # Caching disabled: go straight to the provider every time.
        uncached = list(self.find_candidates(identifier))
        logger.debug(
            "%s: got %i unfiltered candidates, ignoring cache",
            identifier,
            len(uncached),
        )
        return uncached

    with self._get_identifier_lock(identifier):
        hit = self._get_cached_candidates(identifier)
        if hit is not None:
            logger.debug(
                "%s: use %i cached candidates",
                identifier,
                len(hit),
            )
            return hit

        # Cache miss: compute once under the lock and publish the result.
        fresh = list(self.find_candidates(identifier))
        self._set_cached_candidates(identifier, fresh)
        logger.debug(
            "%s: cache %i unfiltered candidates",
            identifier,
            len(fresh),
        )
        return fresh

def _get_no_match_error_message(
self, identifier: str, requirements: RequirementsMap
Expand Down
103 changes: 99 additions & 4 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime
import re
import threading
import time
import typing

import pytest
Expand All @@ -11,6 +13,7 @@

from fromager import constraints, resolver
from fromager.__main__ import main as fromager
from fromager.candidate import Candidate

_hydra_core_simple_response = """
<!DOCTYPE html>
Expand Down Expand Up @@ -144,7 +147,7 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None:
provider = pypi_hydra_resolver.provider
assert provider.cache_key == "https://pypi.org/simple/"
req_cache = provider._get_cached_candidates(req.name)
assert req_cache == []
assert req_cache is None

result = pypi_hydra_resolver.resolve([req])
candidate = result.mapping[req.name]
Expand All @@ -153,10 +156,8 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None:
resolver_cache = resolver.BaseProvider.resolver_cache
assert req.name in resolver_cache
assert (resolver.PyPIProvider, provider.cache_key) in resolver_cache[req.name]
# mutated in place
assert provider._get_cached_candidates(req.name) is req_cache
# _get_cached_candidates returns a defensive copy, not the same object
assert len(provider._get_cached_candidates(req.name)) == 7
assert len(req_cache) == 7


def test_provider_cache_key_gitlab(gitlab_decile_resolver: typing.Any) -> None:
Expand Down Expand Up @@ -1278,3 +1279,97 @@ def test_cli_package_resolver(
assert "- PyPI versions: 1.2.2, 1.3.1+local, 1.3.2, 2.0.0a1" in result.stdout
assert "- only wheels on PyPI: 1.3.1+local, 2.0.0a1" in result.stdout
assert "- missing from Fromager: 1.3.1+local, 2.0.0a1" in result.stdout


def _make_candidate(name: str, version: str) -> Candidate:
    """Build the smallest wheel Candidate usable in cache tests."""
    return Candidate(
        name=name,
        version=Version(version),
        url="https://example.com",
        is_sdist=False,
    )


class _StubProvider(resolver.BaseProvider):
    """Smallest possible BaseProvider implementation for cache tests."""

    provider_description = "stub"

    @property
    def cache_key(self) -> str:
        # Constant key: all stub instances share one cache slot per class.
        return "stub-key"

    def find_candidates(self, identifier: str) -> list[Candidate]:
        # No index behind this provider; tests seed the cache directly.
        return []


def test_get_cached_candidates_returns_defensive_copy() -> None:
    """Mutating the list returned by _get_cached_candidates must not corrupt the cache."""
    provider = _StubProvider()
    identifier = "test-pkg"

    # Seed the cache directly so the test doesn't depend on the aliasing bug
    seeded = [_make_candidate("test-pkg", "1.0.0")]
    resolver.BaseProvider.resolver_cache[identifier] = {
        (type(provider), provider.cache_key): seeded
    }

    # Read the cache and deliberately mutate the returned list.
    mutated = provider._get_cached_candidates(identifier)
    assert mutated is not None
    mutated.append(_make_candidate("test-pkg", "2.0.0"))

    # A fresh read must still see only the originally seeded entry.
    fresh = provider._get_cached_candidates(identifier)
    assert fresh is not None
    assert len(fresh) == 1, (
        "_get_cached_candidates should return a defensive copy, "
        "not a direct reference to the internal cache"
    )
    assert fresh[0].version == Version("1.0.0")


def test_find_cached_candidates_thread_safe() -> None:
    """Concurrent threads must not bypass the cache and call find_candidates multiple times."""
    call_count = 0
    call_count_lock = threading.Lock()

    class _SlowProvider(resolver.BaseProvider):
        """Provider with a slow find_candidates to expose thread races."""

        provider_description = "slow"

        @property
        def cache_key(self) -> str:
            return "slow-key"

        def find_candidates(self, identifier: str) -> list[Candidate]:
            nonlocal call_count
            with call_count_lock:
                call_count += 1
            # Sleep long enough that all threads overlap inside this call
            # if the per-identifier lock were missing.
            time.sleep(0.2)
            return [_make_candidate(identifier, "1.0.0")]

    # Start from a clean slate: a stale "shared-pkg" entry left behind by
    # another test would make every thread hit the cache, drive call_count
    # to 0, and render the assertion below meaningless.
    resolver.BaseProvider.resolver_cache.pop("shared-pkg", None)

    barrier = threading.Barrier(4)

    def resolve_in_thread(provider: _SlowProvider, ident: str) -> None:
        # All threads release together to maximize contention.
        barrier.wait(timeout=5)
        list(provider._find_cached_candidates(ident))

    providers = [_SlowProvider() for _ in range(4)]
    threads = [
        threading.Thread(
            target=resolve_in_thread,
            args=(thread_provider, "shared-pkg"),
            name=f"resolver-{i}",
        )
        for i, thread_provider in enumerate(providers)
    ]

    try:
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        assert call_count == 1, (
            f"find_candidates() was called {call_count} times; expected 1. "
            "Without thread-safe caching, multiple threads bypass the cache "
            "and redundantly call find_candidates()."
        )
    finally:
        # Don't leak the shared entry into tests that run after this one.
        resolver.BaseProvider.resolver_cache.pop("shared-pkg", None)
Loading