Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dev = [
"pytest>=9.0.2",
"pytest-cov>=7.0.0",
"repo-mapper-rs>=0.3.0",
"returns>=0.27.0",
"ruff>=0.14.9",
"sphinx>=9.0.4",
"ty>=0.0.9",
Expand Down
25 changes: 22 additions & 3 deletions src/io_adapters/_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from types import MappingProxyType

import attrs
from attrs.validators import deep_mapping, instance_of, is_callable, optional
from attrs.validators import deep_iterable, deep_mapping, instance_of, is_callable, optional

from io_adapters._clock import default_datetime, default_guid, fake_datetime, fake_guid
from io_adapters._registries import READ_FNS, WRITE_FNS, Data, ReadFn, WriteFn, standardise_key
Expand Down Expand Up @@ -203,10 +203,21 @@ class FakeAdapter(IoAdapter):
),
converter=_convert_file_mapping,
)
read_decs: tuple[Callable[..., ReadFn]] = attrs.field(
factory=tuple, validator=deep_iterable(is_callable(), instance_of(tuple)), converter=tuple
)
write_decs: tuple[Callable[..., WriteFn]] = attrs.field(
factory=tuple, validator=deep_iterable(is_callable(), instance_of(tuple)), converter=tuple
)

def __attrs_post_init__(self) -> None:
self.read_fns = MappingProxyType(dict.fromkeys(self.read_fns.keys(), self._read_fn))
self.write_fns = MappingProxyType(dict.fromkeys(self.write_fns.keys(), self._write_fn))
self.read_fns = MappingProxyType(
dict.fromkeys(self.read_fns.keys(), _apply_decs(self._read_fn, self.read_decs))
)
self.write_fns = MappingProxyType(
dict.fromkeys(self.write_fns.keys(), _apply_decs(self._write_fn, self.write_decs))
)

self.guid_fn = self.guid_fn or fake_guid
self.datetime_fn = self.datetime_fn or fake_datetime

Expand Down Expand Up @@ -244,3 +255,11 @@ def delete_file(self, path: str | Path, *, missing_ok: bool = True) -> None:

def exists(self, path: str | Path) -> bool:
return Path(path).resolve() in map(Path, self.files)


def _apply_decs(
fn: ReadFn | WriteFn, decs: tuple[Callable[..., ReadFn | WriteFn]]
) -> ReadFn | WriteFn:
for dec in reversed(decs):
fn = dec(fn)
return fn
71 changes: 71 additions & 0 deletions tests/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from pathlib import Path

import pytest
from returns.result import Success, safe

from src.io_adapters import FakeAdapter, RealAdapter
from src.io_adapters._adapters import _apply_decs

REPO_ROOT = Path(__file__).parents[1]
MOCK_DATA_PATH = f"{REPO_ROOT}/tests/mock_data/mock.json"
Expand Down Expand Up @@ -191,3 +193,72 @@ def test_write_then_list(adapter):
assert adapter.list_files(f"{TMP_ROOT}/pending") == [
Path(f"{TMP_ROOT}/pending/20260425_211300_000.json")
]


def test_apply_decs() -> None:
def append(lst: list[int], value: int):
def wrapper(fn):
lst.append(value)
return fn

return wrapper

dec_lst, fn_lst = [], []

@append(dec_lst, 1)
@append(dec_lst, 2)
def blah():
return None

_apply_decs(fn=lambda x: x, decs=(append(fn_lst, 1), append(fn_lst, 2)))

assert dec_lst == fn_lst


@safe
def safe_read_py_file(path: str) -> str:
return Path(path).read_text()


def read_py_file(path: str) -> str:
return Path(path).read_text()


SAFE_FNS = {"safe_py": safe_read_py_file, "py": read_py_file}


@pytest.mark.parametrize(
("adapter", "file_type", "expected_result"),
[
pytest.param(
RealAdapter(read_fns=SAFE_FNS),
"safe_py",
Success(""),
id="reads file with monad using a RealAdapter",
),
pytest.param(
RealAdapter(read_fns=SAFE_FNS), "py", "", id="reads file normally using a RealAdapter"
),
pytest.param(
FakeAdapter(
files=dict.fromkeys(map(str, INITIAL_FILES), ""),
read_fns=SAFE_FNS,
read_decs=[safe],
),
"safe_py",
Success(""),
id="reads file with monad using a FakeAdapter",
),
pytest.param(
FakeAdapter(
files=dict.fromkeys(map(str, INITIAL_FILES), ""),
read_fns=SAFE_FNS,
),
"py",
"",
id="reads file normally using a FakeAdapter",
),
],
)
def test_adapters_with_decs(adapter, file_type, expected_result):
assert adapter.read(INITIAL_FILES[0], file_type) == expected_result
14 changes: 14 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.