From d5e3298bcb34d1fd943ebb5b8e91e03dc50ac2e4 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 20 May 2026 11:27:20 +0200 Subject: [PATCH 1/3] multiprocessing --- pyproject.toml | 2 +- src/easyreflectometry/calculators/factory.py | 32 +++++ src/easyreflectometry/fitting.py | 9 ++ tests/calculators/test_factory.py | 48 +++++++ tests/test_fitting.py | 141 +++++++++++++++++++ 5 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 tests/calculators/test_factory.py diff --git a/pyproject.toml b/pyproject.toml index 2713a73f..bab61791 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] requires-python = '>=3.11' dependencies = [ - 'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian', + 'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian_mp', # 'easyscience', 'scipp', 'refnx', diff --git a/src/easyreflectometry/calculators/factory.py b/src/easyreflectometry/calculators/factory.py index c3e1479c..571db1a6 100644 --- a/src/easyreflectometry/calculators/factory.py +++ b/src/easyreflectometry/calculators/factory.py @@ -14,6 +14,38 @@ def __init__(self): """Init function.""" super().__init__(interface_list=CalculatorBase._calculators) + def __reduce__(self): + """Serialize the active calculator state for worker processes.""" + wrapper = getattr(self(), '_wrapper', None) + wrapper_state = None + if wrapper is not None: + wrapper_state = { + 'storage': wrapper.storage, + 'resolution_function': wrapper._resolution_function, + 'magnetism': wrapper._magnetism, + } + return ( + self.__state_restore__, + ( + self.__class__, + self.current_interface_name, + wrapper_state, + ), + ) + + @staticmethod + def __state_restore__(cls, interface_str, wrapper_state): + """Restore a calculator factory with its active wrapper state.""" + obj = cls() + if interface_str in obj.available_interfaces: + obj.switch(interface_str) + wrapper = getattr(obj(), '_wrapper', None) + if wrapper is not None and wrapper_state is not None: + wrapper.storage = wrapper_state['storage'] + wrapper._resolution_function = wrapper_state['resolution_function'] + wrapper._magnetism = wrapper_state['magnetism'] + return obj + def reset_storage(self) -> None: """Reset storage.""" return self().reset_storage() diff --git a/src/easyreflectometry/fitting.py b/src/easyreflectometry/fitting.py index 3d50fe99..09498082 100644 --- a/src/easyreflectometry/fitting.py +++ b/src/easyreflectometry/fitting.py @@ -364,6 +364,7 @@ def sample( seed: int | None = None, objective: str | None = None, initializer: str | None = None, + n_workers: int | None = None, progress_callback=None, abort_test=None, ) -> dict: @@ -383,8 +384,15 @@ def sample( :param initializer: DREAM population initializer. One of ``'eps'``, ``'cov'``, ``'lhs'``, or ``'random'``. By default, None (BUMPS uses ``'eps'``). + :param n_workers: Number of worker processes for parallel DREAM + population evaluation. ``None`` (default) and ``1`` use + sequential evaluation. Values greater than ``1`` enable + multiprocessing; the effective pool size is capped at + ``min(n_workers, population)``. :param progress_callback: Optional callback for progress updates during sampling. Forwarded to the core MultiFitter. + :param abort_test: Optional callback that returns ``True`` to signal + that sampling should be aborted. :return: Dictionary with keys ``'draws'``, ``'param_names'``, ``'state'``, and ``'logp'``. :raises RuntimeError: If the current minimizer is not a BUMPS instance. @@ -428,6 +436,7 @@ def sample( population=population, seed=seed, sampler_kwargs=sampler_kwargs or None, + n_workers=n_workers, progress_callback=progress_callback, abort_test=abort_test, ) diff --git a/tests/calculators/test_factory.py b/tests/calculators/test_factory.py new file mode 100644 index 00000000..3af5e062 --- /dev/null +++ b/tests/calculators/test_factory.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for CalculatorFactory serialization.""" + +import pickle # noqa: S403 + +import numpy as np +from numpy.testing import assert_allclose + +from easyreflectometry.calculators import CalculatorFactory +from easyreflectometry.model import Model +from easyreflectometry.model import PercentageFwhm +from easyreflectometry.sample import Layer +from easyreflectometry.sample import Material +from easyreflectometry.sample import Multilayer +from easyreflectometry.sample import Sample + + +def test_calculator_factory_pickle_preserves_active_wrapper_storage(): + """Pickled calculator factories retain model storage for worker processes.""" + si = Material(sld=2.07, isld=0.0, name='Si') + film = Material(sld=2.0, isld=0.0, name='Film') + d2o = Material(sld=6.36, isld=0.0, name='D2O') + + sample = Sample( + Multilayer(Layer(material=si, thickness=0.0, roughness=3.0, name='Si')), + Multilayer(Layer(material=film, thickness=250.0, roughness=3.0, name='Film')), + Multilayer(Layer(material=d2o, thickness=0.0, roughness=3.0, name='D2O')), + ) + model = Model( + sample=sample, + scale=1.0, + background=1e-6, + resolution_function=PercentageFwhm(0.02), + ) + interface = CalculatorFactory() + interface.switch('refnx') + model.interface = interface + + restored = pickle.loads(pickle.dumps(interface)) # noqa: S301 + + assert model.unique_name in restored()._wrapper.storage['model'] + q = np.linspace(0.01, 0.3, 10) + assert_allclose( + restored.fit_func(q, model.unique_name), + interface.fit_func(q, model.unique_name), + ) diff --git a/tests/test_fitting.py b/tests/test_fitting.py index a860fdb4..efa547dc 100644 --- a/tests/test_fitting.py +++ b/tests/test_fitting.py @@ -1034,3 +1034,144 @@ def _fake_sample(*, x, y, weights, **kwargs): fitter.sample(data, samples=100, burn=20, thin=2, objective='hybrid') assert len(captured['x'][0]) == 10 # all points kept (Mighell-substituted) + + +class TestSampleWorkers: + """n_workers parameter forwarding in sample().""" + + def test_default_is_none(self): + """When n_workers is not passed, it defaults to None (sequential).""" + model = Model() + model.interface = CalculatorFactory() + fitter = MultiFitter(model) + + captured = {} + + def _fake_sample(*, n_workers, **kwargs): + captured['n_workers'] = n_workers + return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + + fitter.easy_science_multi_fitter = MagicMock() + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) + + data = sc.DataGroup({ + 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, + 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, + }) + + fitter.sample(data, samples=100, burn=20, thin=2) + assert captured['n_workers'] is None + + def test_explicit_none(self): + """Explicit n_workers=None is forwarded as None.""" + model = Model() + model.interface = CalculatorFactory() + fitter = MultiFitter(model) + + captured = {} + + def _fake_sample(*, n_workers, **kwargs): + captured['n_workers'] = n_workers + return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + + fitter.easy_science_multi_fitter = MagicMock() + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) + + data = sc.DataGroup({ + 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, + 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, + }) + + fitter.sample(data, samples=100, burn=20, thin=2, n_workers=None) + assert captured['n_workers'] is None + + def test_explicit_one(self): + """n_workers=1 is forwarded (sequential, same as None).""" + model = Model() + model.interface = CalculatorFactory() + fitter = MultiFitter(model) + + captured = {} + + def _fake_sample(*, n_workers, **kwargs): + captured['n_workers'] = n_workers + return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + + fitter.easy_science_multi_fitter = MagicMock() + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) + + data = sc.DataGroup({ + 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, + 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, + }) + + fitter.sample(data, samples=100, burn=20, thin=2, n_workers=1) + assert captured['n_workers'] == 1 + + @pytest.mark.parametrize('workers', [2, 4, 8]) + def test_multiple_workers_forwarded(self, workers): + """n_workers values greater than 1 are forwarded to core.""" + model = Model() + model.interface = CalculatorFactory() + fitter = MultiFitter(model) + + captured = {} + + def _fake_sample(*, n_workers, **kwargs): + captured['n_workers'] = n_workers + return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + + fitter.easy_science_multi_fitter = MagicMock() + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) + + data = sc.DataGroup({ + 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, + 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, + }) + + fitter.sample(data, samples=100, burn=20, thin=2, n_workers=workers) + assert captured['n_workers'] == workers + + def test_with_other_params_combined(self): + """n_workers can be combined with all other sample() parameters.""" + model = Model() + model.interface = CalculatorFactory() + fitter = MultiFitter(model) + + captured = {} + + def _fake_sample(*, samples, burn, thin, population, seed, n_workers, sampler_kwargs, **kwargs): + captured['samples'] = samples + captured['burn'] = burn + captured['thin'] = thin + captured['population'] = population + captured['seed'] = seed + captured['n_workers'] = n_workers + captured['sampler_kwargs'] = sampler_kwargs + return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + + fitter.easy_science_multi_fitter = MagicMock() + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) + + data = sc.DataGroup({ + 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, + 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, + }) + + fitter.sample( + data, + samples=500, + burn=100, + thin=5, + population=8, + seed=42, + initializer='cov', + n_workers=4, + ) + assert captured['samples'] == 500 + assert captured['burn'] == 100 + assert captured['thin'] == 5 + assert captured['population'] == 8 + assert captured['seed'] == 42 + assert captured['n_workers'] == 4 + assert captured['sampler_kwargs'] == {'init': 'cov'} From 72a51b42ea20910aed28033d0e8040f56e003899 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Thu, 21 May 2026 19:39:29 +0200 Subject: [PATCH 2/3] PR review addressed --- src/easyreflectometry/calculators/factory.py | 25 ++- .../calculators/wrapper_base.py | 12 ++ src/easyreflectometry/fitting.py | 35 ++-- tests/test_fitting.py | 173 ++++++------------ 4 files changed, 101 insertions(+), 144 deletions(-) diff --git a/src/easyreflectometry/calculators/factory.py b/src/easyreflectometry/calculators/factory.py index 571db1a6..74ce98de 100644 --- a/src/easyreflectometry/calculators/factory.py +++ b/src/easyreflectometry/calculators/factory.py @@ -17,33 +17,30 @@ def __init__(self): def __reduce__(self): """Serialize the active calculator state for worker processes.""" wrapper = getattr(self(), '_wrapper', None) - wrapper_state = None - if wrapper is not None: - wrapper_state = { - 'storage': wrapper.storage, - 'resolution_function': wrapper._resolution_function, - 'magnetism': wrapper._magnetism, - } + if wrapper is None and self.current_interface_name is not None: + raise RuntimeError( + f'Cannot pickle CalculatorFactory: active interface ' + f"{self.current_interface_name!r} exposes no '_wrapper' attribute. " + 'The InterfaceFactoryTemplate API may have changed.' + ) return ( - self.__state_restore__, + self._state_restore, ( self.__class__, self.current_interface_name, - wrapper_state, + wrapper.__getstate__() if wrapper is not None else None, ), ) @staticmethod - def __state_restore__(cls, interface_str, wrapper_state): + def _state_restore(cls, interface_str, wrapper_state): """Restore a calculator factory with its active wrapper state.""" obj = cls() - if interface_str in obj.available_interfaces: + if interface_str is not None and interface_str in obj.available_interfaces: obj.switch(interface_str) wrapper = getattr(obj(), '_wrapper', None) if wrapper is not None and wrapper_state is not None: - wrapper.storage = wrapper_state['storage'] - wrapper._resolution_function = wrapper_state['resolution_function'] - wrapper._magnetism = wrapper_state['magnetism'] + wrapper.__setstate__(wrapper_state) return obj def reset_storage(self) -> None: diff --git a/src/easyreflectometry/calculators/wrapper_base.py b/src/easyreflectometry/calculators/wrapper_base.py index dc53ceca..8b383dca 100644 --- a/src/easyreflectometry/calculators/wrapper_base.py +++ b/src/easyreflectometry/calculators/wrapper_base.py @@ -293,6 +293,18 @@ def get_item_value(self, name: str, key: str) -> float: item = getattr(item, key) return getattr(item, 'value') + def __getstate__(self) -> dict: + return { + 'storage': self.storage, + 'resolution_function': self._resolution_function, + 'magnetism': self._magnetism, + } + + def __setstate__(self, state: dict) -> None: + self.storage = state['storage'] + self._resolution_function = state['resolution_function'] + self._magnetism = state['magnetism'] + def set_resolution_function(self, resolution_function: ResolutionFunction) -> None: """Set the resolution function for the calculator. diff --git a/src/easyreflectometry/fitting.py b/src/easyreflectometry/fitting.py index 09498082..db0c81fe 100644 --- a/src/easyreflectometry/fitting.py +++ b/src/easyreflectometry/fitting.py @@ -396,7 +396,10 @@ def sample( :return: Dictionary with keys ``'draws'``, ``'param_names'``, ``'state'``, and ``'logp'``. :raises RuntimeError: If the current minimizer is not a BUMPS instance. + :raises ValueError: If ``n_workers`` is not None and less than 1. """ + if n_workers is not None and n_workers < 1: + raise ValueError(f'n_workers must be a positive integer or None, got {n_workers}') obj = _validate_objective(objective) if objective is not None else self._objective refl_nums = [k[3:] for k in data['coords'].keys() if 'Qz' == k[:2]] @@ -425,21 +428,23 @@ def sample( sampler_kwargs = {} if initializer is not None: sampler_kwargs['init'] = initializer - return self.easy_science_multi_fitter.sample( - x=x, - y=y, - weights=dy, - samples=samples, - burn=burn, - thin=thin, - chains=chains, - population=population, - seed=seed, - sampler_kwargs=sampler_kwargs or None, - n_workers=n_workers, - progress_callback=progress_callback, - abort_test=abort_test, - ) + core_sample_kwargs = { + 'x': x, + 'y': y, + 'weights': dy, + 'samples': samples, + 'burn': burn, + 'thin': thin, + 'chains': chains, + 'population': population, + 'seed': seed, + 'sampler_kwargs': sampler_kwargs or None, + 'progress_callback': progress_callback, + 'abort_test': abort_test, + } + if n_workers is not None: + core_sample_kwargs['n_workers'] = n_workers + return self.easy_science_multi_fitter.sample(**core_sample_kwargs) @property def chi2(self) -> float | None: diff --git a/tests/test_fitting.py b/tests/test_fitting.py index efa547dc..28ead1ab 100644 --- a/tests/test_fitting.py +++ b/tests/test_fitting.py @@ -1036,142 +1036,85 @@ def _fake_sample(*, x, y, weights, **kwargs): assert len(captured['x'][0]) == 10 # all points kept (Mighell-substituted) +_SENTINEL = object() + + class TestSampleWorkers: """n_workers parameter forwarding in sample().""" - def test_default_is_none(self): - """When n_workers is not passed, it defaults to None (sequential).""" + @pytest.fixture + def sample_fitter(self): model = Model() model.interface = CalculatorFactory() fitter = MultiFitter(model) - - captured = {} - - def _fake_sample(*, n_workers, **kwargs): - captured['n_workers'] = n_workers - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} - - fitter.easy_science_multi_fitter = MagicMock() - fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) - data = sc.DataGroup({ 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, }) + return fitter, data - fitter.sample(data, samples=100, burn=20, thin=2) - assert captured['n_workers'] is None - - def test_explicit_none(self): - """Explicit n_workers=None is forwarded as None.""" - model = Model() - model.interface = CalculatorFactory() - fitter = MultiFitter(model) + def _mock_sample(self, fitter, captured): + """Wire a fake sample() that records n_workers into captured.""" - captured = {} - - def _fake_sample(*, n_workers, **kwargs): + def _fake(*, n_workers=_SENTINEL, **kwargs): captured['n_workers'] = n_workers return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} fitter.easy_science_multi_fitter = MagicMock() - fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) - - data = sc.DataGroup({ - 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, - 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, - }) - - fitter.sample(data, samples=100, burn=20, thin=2, n_workers=None) - assert captured['n_workers'] is None - - def test_explicit_one(self): - """n_workers=1 is forwarded (sequential, same as None).""" - model = Model() - model.interface = CalculatorFactory() - fitter = MultiFitter(model) - - captured = {} - - def _fake_sample(*, n_workers, **kwargs): - captured['n_workers'] = n_workers - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} - - fitter.easy_science_multi_fitter = MagicMock() - fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) - - data = sc.DataGroup({ - 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, - 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, - }) - - fitter.sample(data, samples=100, burn=20, thin=2, n_workers=1) - assert captured['n_workers'] == 1 - - @pytest.mark.parametrize('workers', [2, 4, 8]) - def test_multiple_workers_forwarded(self, workers): - """n_workers values greater than 1 are forwarded to core.""" - model = Model() - model.interface = CalculatorFactory() - fitter = MultiFitter(model) - + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake) + + @pytest.mark.parametrize( + 'kwargs,expected', + [ + ({}, _SENTINEL), + ({'n_workers': None}, _SENTINEL), + ({'n_workers': 1}, 1), + ({'n_workers': 2}, 2), + ({'n_workers': 8}, 8), + ], + ) + def test_n_workers_forwarded(self, sample_fitter, kwargs, expected): + """n_workers is forwarded to core only when explicitly set to ≥1.""" + fitter, data = sample_fitter captured = {} + self._mock_sample(fitter, captured) + fitter.sample(data, samples=100, burn=20, thin=2, **kwargs) + assert captured['n_workers'] == expected - def _fake_sample(*, n_workers, **kwargs): - captured['n_workers'] = n_workers - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} - - fitter.easy_science_multi_fitter = MagicMock() - fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) - - data = sc.DataGroup({ - 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, - 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, - }) - - fitter.sample(data, samples=100, burn=20, thin=2, n_workers=workers) - assert captured['n_workers'] == workers - - def test_with_other_params_combined(self): + def test_with_other_params_combined(self, sample_fitter): """n_workers can be combined with all other sample() parameters.""" - model = Model() - model.interface = CalculatorFactory() - fitter = MultiFitter(model) - + fitter, data = sample_fitter captured = {} - def _fake_sample(*, samples, burn, thin, population, seed, n_workers, sampler_kwargs, **kwargs): - captured['samples'] = samples - captured['burn'] = burn - captured['thin'] = thin - captured['population'] = population - captured['seed'] = seed - captured['n_workers'] = n_workers - captured['sampler_kwargs'] = sampler_kwargs + def _fake(*, samples, burn, thin, population, seed, n_workers, sampler_kwargs, **kwargs): + captured.update( + samples=samples, + burn=burn, + thin=thin, + population=population, + seed=seed, + n_workers=n_workers, + sampler_kwargs=sampler_kwargs, + ) return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} fitter.easy_science_multi_fitter = MagicMock() - fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample) - - data = sc.DataGroup({ - 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, - 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, - }) - - fitter.sample( - data, - samples=500, - burn=100, - thin=5, - population=8, - seed=42, - initializer='cov', - n_workers=4, - ) - assert captured['samples'] == 500 - assert captured['burn'] == 100 - assert captured['thin'] == 5 - assert captured['population'] == 8 - assert captured['seed'] == 42 - assert captured['n_workers'] == 4 - assert captured['sampler_kwargs'] == {'init': 'cov'} + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake) + + fitter.sample(data, samples=500, burn=100, thin=5, population=8, seed=42, initializer='cov', n_workers=4) + assert captured == { + 'samples': 500, + 'burn': 100, + 'thin': 5, + 'population': 8, + 'seed': 42, + 'n_workers': 4, + 'sampler_kwargs': {'init': 'cov'}, + } + + @pytest.mark.parametrize('bad', [0, -1, -100]) + def test_invalid_n_workers_raises(self, sample_fitter, bad): + """n_workers < 1 raises ValueError before reaching the core.""" + fitter, data = sample_fitter + with pytest.raises(ValueError, match='n_workers'): + fitter.sample(data, samples=100, burn=20, thin=2, n_workers=bad) From 738117aae0a25ce5183182a77807da812afb502a Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Thu, 21 May 2026 20:44:46 +0200 Subject: [PATCH 3/3] move corner and arviz to dev --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bab61791..f1c1e8dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,10 +69,10 @@ dev = [ 'mkdocstrings-python', # MkDocs: Python docstring support 'pyyaml', # YAML parser 'spdx-headers', # SPDX license header validation + 'corner', # Bayesian analysis and plotting + 'arviz', # Bayesian analysis and plotting ] -bayesian = ["corner>=2.2", "arviz>=0.18"] - [project.urls] Documentation = 'https://easyscience.github.io/reflectometry-lib' 'Release Notes' = 'https://github.com/easyscience/reflectometry-lib/releases'