From bb7bcfd3fb79bf737f0e861fdaef282960d70d46 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Fri, 9 May 2025 08:34:01 +0200 Subject: [PATCH 01/16] initial commit --- src/easyreflectometry/model/__init__.py | 2 + src/easyreflectometry/model/model.py | 1 + .../model/resolution_functions.py | 49 +++++++++++++++++++ src/easyreflectometry/project.py | 15 ++++-- tests/model/test_resolution_functions.py | 35 +++++++++++++ tests/test_project.py | 3 +- 6 files changed, 100 insertions(+), 5 deletions(-) diff --git a/src/easyreflectometry/model/__init__.py b/src/easyreflectometry/model/__init__.py index baa1aec4..40cef321 100644 --- a/src/easyreflectometry/model/__init__.py +++ b/src/easyreflectometry/model/__init__.py @@ -3,10 +3,12 @@ from .resolution_functions import LinearSpline from .resolution_functions import PercentageFwhm from .resolution_functions import ResolutionFunction +from .resolution_functions import Pointwise __all__ = ( "LinearSpline", "PercentageFwhm", + "Pointwise", "ResolutionFunction", "Model", "ModelCollection", diff --git a/src/easyreflectometry/model/model.py b/src/easyreflectometry/model/model.py index 969c955a..a06bfc00 100644 --- a/src/easyreflectometry/model/model.py +++ b/src/easyreflectometry/model/model.py @@ -19,6 +19,7 @@ from .resolution_functions import PercentageFwhm from .resolution_functions import ResolutionFunction +from .resolution_functions import Pointwise DEFAULTS = { 'scale': { diff --git a/src/easyreflectometry/model/resolution_functions.py b/src/easyreflectometry/model/resolution_functions.py index 3a352fc5..6dd3ad7e 100644 --- a/src/easyreflectometry/model/resolution_functions.py +++ b/src/easyreflectometry/model/resolution_functions.py @@ -30,6 +30,8 @@ def from_dict(cls, data: dict) -> ResolutionFunction: return PercentageFwhm(data['constant']) if data['smearing'] == 'LinearSpline': return LinearSpline(data['q_data_points'], data['fwhm_values']) + if data['smearing'] == 'Pointwise': + return Pointwise(data['q_data_points']) raise ValueError('Unknown resolution function type') @@ -60,3 +62,50 @@ def as_dict( self, skip: Optional[List[str]] = None ) -> dict[str, str]: # skip is kept for consistency of the as_dict signature return {'smearing': 'LinearSpline', 'q_data_points': list(self.q_data_points), 'fwhm_values': list(self.fwhm_values)} + +# add pointwise smearing funtion +class Pointwise(ResolutionFunction): + def __init__(self, q_data_points: np.array): + self.q_data_points = q_data_points + + def smearing(self, q: Union[np.array, float] = 0.0) -> np.array: + + Qz= self.q_data_points[0] + R= self.q_data_points[1] + sR= self.q_data_points[2] + sQz= self.q_data_points[3] + smeared = self.apply_smooth_smearing(Qz, R, sR, sQz) + return smeared + + def as_dict( + self, skip: Optional[List[str]] = None + ) -> dict[str, str]: # skip is kept for consistency of the as_dict signature + return {'smearing': 'Pointwise', 'q_data_points': list(self.q_data_points)} + + def gaussian_kernel(self, x, sigma): + """Simple Gaussian kernel function""" + return np.exp(-x**2/(2*sigma**2)) + + def apply_smooth_smearing(self, Qz, R, sR, sQz, n_sigma=3): + """ + Apply smooth resolution smearing using convolution with Gaussian kernel. + """ + R_smeared = np.zeros_like(R) + if not isinstance(Qz, np.ndarray): + Qz = np.array(Qz) + if not isinstance(R, np.ndarray): + R = np.array(R) + for i, (q, r, sr, sq) in enumerate(zip(Qz, R, sR, sQz)): + weights = self.gaussian_kernel(Qz - q, sq) + mask = np.abs(Qz - q) <= n_sigma * sq + weights[~mask] = 0 + + if np.sum(weights) > 0: + weights = weights / np.sum(weights) + + R_smeared[i] = np.sum(R * weights) + # Potentially add the pointwise error from sR + # This can also be used as error bands + # R_smeared[i] += np.random.normal(0, sr * weights[i]) + + return R_smeared \ No newline at end of file diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index de769ab8..93a3f74f 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -22,6 +22,7 @@ from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm +from easyreflectometry.model import Pointwise from easyreflectometry.sample import Layer from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection @@ -249,10 +250,16 @@ def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Opti # Set the resolution function if variance data is present if sum(self._experiments[index].ye) != 0: - resolution_function = LinearSpline( - q_data_points=self._experiments[index].y, - fwhm_values=np.sqrt(self._experiments[index].ye), - ) + q = self._experiments[index].x + reflectivity = self._experiments[index].y + reflectivity_error = self._experiments[index].ye + q_error = self._experiments[index].xe + resolution_function = Pointwise( + q_data_points=[q, reflectivity, reflectivity_error, q_error]) + # resolution_function = LinearSpline( + # q_data_points=self._experiments[index].y, + # fwhm_values=np.sqrt(self._experiments[index].ye), + # ) self._models[index].resolution_function = resolution_function def sld_data_for_model_at_index(self, index: int = 0) -> DataSet1D: diff --git a/tests/model/test_resolution_functions.py b/tests/model/test_resolution_functions.py index f2963c48..bcf4e0ac 100644 --- a/tests/model/test_resolution_functions.py +++ b/tests/model/test_resolution_functions.py @@ -6,6 +6,7 @@ from easyreflectometry.model.resolution_functions import LinearSpline from easyreflectometry.model.resolution_functions import PercentageFwhm from easyreflectometry.model.resolution_functions import ResolutionFunction +from easyreflectometry.model.resolution_functions import Pointwise class TestPercentageFwhm(unittest.TestCase): @@ -75,3 +76,37 @@ def test_dict_round_trip(self): # Expect assert all(resolution_function.smearing([0, 2.5]) == expected_resolution_function.smearing([0, 2.5])) + +class TestPointwise(unittest.TestCase): + + data_points = [] + data_points.append([0.1, 0.2, 0.3, 0.4, 0.5]) # Qz + data_points.append([1.1, 2.2, 3.3, 4.4, 5.5]) # R + data_points.append([0.01, 0.02, 0.01, 0.05, 0.08]) # sR + data_points.append([0.03, 0.04, 0.05, 0.06, 0.07]) # sQz + def test_constructor(self): + + # When + resolution_function = Pointwise(q_data_points=self.data_points) + + # Then Expect + assert np.allclose(np.array(resolution_function.smearing()), np.array([1.1, 2.2, 3.3, 4.4, 5.18516692])) + + + def test_as_dict(self): + # When + resolution_function = Pointwise(q_data_points=self.data_points) + + # Then Expect + resolution_function.as_dict() == {'smearing': 'Pointwise', 'q_data_points': [0, 10]} + + def test_dict_round_trip(self): + # When + expected_resolution_function = Pointwise(q_data_points=self.data_points) + res_dict = expected_resolution_function.as_dict() + + # Then + resolution_function = ResolutionFunction.from_dict(res_dict) + + # Expect + assert all(resolution_function.smearing() == expected_resolution_function.smearing()) diff --git a/tests/test_project.py b/tests/test_project.py index 523138c3..f1ded171 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -16,6 +16,7 @@ from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm +from easyreflectometry.model import Pointwise from easyreflectometry.project import Project from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection @@ -561,7 +562,7 @@ def test_load_experiment(self): assert isinstance(project.experiments[5], DataSet1D) assert project.experiments[5].name == 'Experiment for Model 5' assert project.experiments[5].model == model_5 - assert isinstance(project.models[5].resolution_function, LinearSpline) + assert isinstance(project.models[5].resolution_function, Pointwise) assert isinstance(project.models[4].resolution_function, PercentageFwhm) def test_experimental_data_at_index(self): From c8f70d152b56982723057aa217b3d45a68461be3 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 9 May 2025 09:47:23 +0200 Subject: [PATCH 02/16] ruff --- src/easyreflectometry/model/__init__.py | 2 +- src/easyreflectometry/model/model.py | 1 - src/easyreflectometry/project.py | 1 - tests/model/test_resolution_functions.py | 2 +- tests/test_project.py | 1 - 5 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/easyreflectometry/model/__init__.py b/src/easyreflectometry/model/__init__.py index 40cef321..b12b504e 100644 --- a/src/easyreflectometry/model/__init__.py +++ b/src/easyreflectometry/model/__init__.py @@ -2,8 +2,8 @@ from .model_collection import ModelCollection from .resolution_functions import LinearSpline from .resolution_functions import PercentageFwhm -from .resolution_functions import ResolutionFunction from .resolution_functions import Pointwise +from .resolution_functions import ResolutionFunction __all__ = ( "LinearSpline", diff --git a/src/easyreflectometry/model/model.py b/src/easyreflectometry/model/model.py index a06bfc00..969c955a 100644 --- a/src/easyreflectometry/model/model.py +++ b/src/easyreflectometry/model/model.py @@ -19,7 +19,6 @@ from .resolution_functions import PercentageFwhm from .resolution_functions import ResolutionFunction -from .resolution_functions import Pointwise DEFAULTS = { 'scale': { diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index 93a3f74f..787877fa 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -18,7 +18,6 @@ from easyreflectometry.data import DataSet1D from easyreflectometry.data import load_as_dataset from easyreflectometry.fitting import MultiFitter -from easyreflectometry.model import LinearSpline from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm diff --git a/tests/model/test_resolution_functions.py b/tests/model/test_resolution_functions.py index bcf4e0ac..0a609826 100644 --- a/tests/model/test_resolution_functions.py +++ b/tests/model/test_resolution_functions.py @@ -5,8 +5,8 @@ from easyreflectometry.model.resolution_functions import DEFAULT_RESOLUTION_FWHM_PERCENTAGE from easyreflectometry.model.resolution_functions import LinearSpline from easyreflectometry.model.resolution_functions import PercentageFwhm -from easyreflectometry.model.resolution_functions import ResolutionFunction from easyreflectometry.model.resolution_functions import Pointwise +from easyreflectometry.model.resolution_functions import ResolutionFunction class TestPercentageFwhm(unittest.TestCase): diff --git a/tests/test_project.py b/tests/test_project.py index f1ded171..bfa6c233 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -12,7 +12,6 @@ import easyreflectometry from easyreflectometry.data import DataSet1D from easyreflectometry.fitting import MultiFitter -from easyreflectometry.model import LinearSpline from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm From ddc90a6a7c0e50822117558cdaf470d4704c0356 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 9 May 2025 10:48:56 +0200 Subject: [PATCH 03/16] fixed unit tests --- src/easyreflectometry/project.py | 22 +++++++++++----------- tests/test_project.py | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index 787877fa..6ba43b28 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -18,10 +18,10 @@ from easyreflectometry.data import DataSet1D from easyreflectometry.data import load_as_dataset from easyreflectometry.fitting import MultiFitter +from easyreflectometry.model import LinearSpline from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm -from easyreflectometry.model import Pointwise from easyreflectometry.sample import Layer from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection @@ -249,16 +249,16 @@ def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Opti # Set the resolution function if variance data is present if sum(self._experiments[index].ye) != 0: - q = self._experiments[index].x - reflectivity = self._experiments[index].y - reflectivity_error = self._experiments[index].ye - q_error = self._experiments[index].xe - resolution_function = Pointwise( - q_data_points=[q, reflectivity, reflectivity_error, q_error]) - # resolution_function = LinearSpline( - # q_data_points=self._experiments[index].y, - # fwhm_values=np.sqrt(self._experiments[index].ye), - # ) + # q = self._experiments[index].x + # reflectivity = self._experiments[index].y + # reflectivity_error = self._experiments[index].ye + # q_error = self._experiments[index].xe + # resolution_function = Pointwise( + # q_data_points=[q, reflectivity, reflectivity_error, q_error]) + resolution_function = LinearSpline( + q_data_points=self._experiments[index].y, + fwhm_values=np.sqrt(self._experiments[index].ye), + ) self._models[index].resolution_function = resolution_function def sld_data_for_model_at_index(self, index: int = 0) -> DataSet1D: diff --git a/tests/test_project.py b/tests/test_project.py index bfa6c233..523138c3 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -12,10 +12,10 @@ import easyreflectometry from easyreflectometry.data import DataSet1D from easyreflectometry.fitting import MultiFitter +from easyreflectometry.model import LinearSpline from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm -from easyreflectometry.model import Pointwise from easyreflectometry.project import Project from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection @@ -561,7 +561,7 @@ def test_load_experiment(self): assert isinstance(project.experiments[5], DataSet1D) assert project.experiments[5].name == 'Experiment for Model 5' assert project.experiments[5].model == model_5 - assert isinstance(project.models[5].resolution_function, Pointwise) + assert isinstance(project.models[5].resolution_function, LinearSpline) assert isinstance(project.models[4].resolution_function, PercentageFwhm) def test_experimental_data_at_index(self): From ade4d9a017d2d315a628ce4c0d4a3c6067d68d66 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Fri, 23 May 2025 13:06:47 +0200 Subject: [PATCH 04/16] updated smearing --- .../model/resolution_functions.py | 68 ++++++++++++------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/src/easyreflectometry/model/resolution_functions.py b/src/easyreflectometry/model/resolution_functions.py index 6dd3ad7e..7fd7f28c 100644 --- a/src/easyreflectometry/model/resolution_functions.py +++ b/src/easyreflectometry/model/resolution_functions.py @@ -11,6 +11,7 @@ from typing import List from typing import Optional from typing import Union +import scipp as sc import numpy as np @@ -65,16 +66,21 @@ def as_dict( # add pointwise smearing funtion class Pointwise(ResolutionFunction): - def __init__(self, q_data_points: np.array): + def __init__(self, q_data_points: sc.DataGroup): self.q_data_points = q_data_points + self.q = None def smearing(self, q: Union[np.array, float] = 0.0) -> np.array: - Qz= self.q_data_points[0] - R= self.q_data_points[1] - sR= self.q_data_points[2] - sQz= self.q_data_points[3] - smeared = self.apply_smooth_smearing(Qz, R, sR, sQz) + Qz = self.q_data_points[0] + R = self.q_data_points[1] + sQz = self.q_data_points[2] + self.q = q + sQzs = np.sqrt(sQz) + if not isinstance(q, np.ndarray): + q = np.ndarray(q) + + smeared = self.apply_smooth_smearing(Qz, R, sQzs) return smeared def as_dict( @@ -82,30 +88,42 @@ def as_dict( ) -> dict[str, str]: # skip is kept for consistency of the as_dict signature return {'smearing': 'Pointwise', 'q_data_points': list(self.q_data_points)} - def gaussian_kernel(self, x, sigma): - """Simple Gaussian kernel function""" - return np.exp(-x**2/(2*sigma**2)) + def gaussian_smearing(self, qt, Qz, R, sQz): + weights = np.exp(-0.5 * ((qt - Qz) / sQz) ** 2) + if np.sum(weights) == 0 or not np.isfinite(np.sum(weights)): + return R + weights /= (sQz * np.sqrt(2 * np.pi)) + return np.sum(R * weights) / np.sum(weights) + - def apply_smooth_smearing(self, Qz, R, sR, sQz, n_sigma=3): + def apply_smooth_smearing(self, Qz, R, sQzs): """ Apply smooth resolution smearing using convolution with Gaussian kernel. """ - R_smeared = np.zeros_like(R) + if self.q is None: + R_smeared = np.zeros_like(Qz) + else: + R_smeared = np.zeros_like(self.q) if not isinstance(Qz, np.ndarray): Qz = np.array(Qz) if not isinstance(R, np.ndarray): R = np.array(R) - for i, (q, r, sr, sq) in enumerate(zip(Qz, R, sR, sQz)): - weights = self.gaussian_kernel(Qz - q, sq) - mask = np.abs(Qz - q) <= n_sigma * sq - weights[~mask] = 0 - - if np.sum(weights) > 0: - weights = weights / np.sum(weights) - - R_smeared[i] = np.sum(R * weights) - # Potentially add the pointwise error from sR - # This can also be used as error bands - # R_smeared[i] += np.random.normal(0, sr * weights[i]) - - return R_smeared \ No newline at end of file + R_smeared = np.zeros_like(self.q) + + for i, qt in enumerate(self.q): + R_smeared[i] = self.gaussian_smearing(qt, Qz, R, sQzs) + + # TEST LOCALLY + # import matplotlib.pyplot as plt + # plt.figure(figsize=(10, 6)) + # plt.plot(Qz, R, label='Original R', marker='o', linestyle='none') + # plt.plot(self.q, R_smeared, label='Smeared R', linestyle='-') + # plt.yscale('log') + # plt.xlabel('Qz (1/angstrom)') + # plt.ylabel('R') + # plt.legend() + # plt.title('Original and Smeared R vs Qz (log scale)') + # plt.grid(True) + # plt.show() + + return R_smeared From e93b0f3b1c432c21802b6d9d9646bd5ea627aed9 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Mon, 26 May 2025 10:40:06 +0200 Subject: [PATCH 05/16] fix unit tests --- src/easyreflectometry/model/resolution_functions.py | 11 ++++++----- src/easyreflectometry/project.py | 2 +- tests/model/test_resolution_functions.py | 4 +--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/easyreflectometry/model/resolution_functions.py b/src/easyreflectometry/model/resolution_functions.py index 7fd7f28c..7fdf5e87 100644 --- a/src/easyreflectometry/model/resolution_functions.py +++ b/src/easyreflectometry/model/resolution_functions.py @@ -11,7 +11,6 @@ from typing import List from typing import Optional from typing import Union -import scipp as sc import numpy as np @@ -66,19 +65,21 @@ def as_dict( # add pointwise smearing funtion class Pointwise(ResolutionFunction): - def __init__(self, q_data_points: sc.DataGroup): + def __init__(self, q_data_points: list[np.ndarray]): self.q_data_points = q_data_points self.q = None - def smearing(self, q: Union[np.array, float] = 0.0) -> np.array: + def smearing(self, q: Union[np.ndarray, float] = None) -> np.ndarray: Qz = self.q_data_points[0] R = self.q_data_points[1] sQz = self.q_data_points[2] + if q is None: + q = self.q_data_points[0] self.q = q sQzs = np.sqrt(sQz) - if not isinstance(q, np.ndarray): - q = np.ndarray(q) + if isinstance(Qz, float): + Qz = np.array(Qz) smeared = self.apply_smooth_smearing(Qz, R, sQzs) return smeared diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index 6ba43b28..e183ea40 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -254,7 +254,7 @@ def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Opti # reflectivity_error = self._experiments[index].ye # q_error = self._experiments[index].xe # resolution_function = Pointwise( - # q_data_points=[q, reflectivity, reflectivity_error, q_error]) + # q_data_points=[q, reflectivity, q_error]) resolution_function = LinearSpline( q_data_points=self._experiments[index].y, fwhm_values=np.sqrt(self._experiments[index].ye), diff --git a/tests/model/test_resolution_functions.py b/tests/model/test_resolution_functions.py index 0a609826..ab6e0e4e 100644 --- a/tests/model/test_resolution_functions.py +++ b/tests/model/test_resolution_functions.py @@ -82,7 +82,6 @@ class TestPointwise(unittest.TestCase): data_points = [] data_points.append([0.1, 0.2, 0.3, 0.4, 0.5]) # Qz data_points.append([1.1, 2.2, 3.3, 4.4, 5.5]) # R - data_points.append([0.01, 0.02, 0.01, 0.05, 0.08]) # sR data_points.append([0.03, 0.04, 0.05, 0.06, 0.07]) # sQz def test_constructor(self): @@ -90,8 +89,7 @@ def test_constructor(self): resolution_function = Pointwise(q_data_points=self.data_points) # Then Expect - assert np.allclose(np.array(resolution_function.smearing()), np.array([1.1, 2.2, 3.3, 4.4, 5.18516692])) - + assert np.allclose(np.array(resolution_function.smearing()), np.array([2.51664683, 2.84038734, 3.2460762 , 3.6796519 , 4.07869271])) def test_as_dict(self): # When From 876a5e4c9ad265594cc77a13c4bdb4ce0ac5bf01 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Mon, 26 May 2025 11:11:47 +0200 Subject: [PATCH 06/16] ruff fix for tests --- tests/model/test_resolution_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/model/test_resolution_functions.py b/tests/model/test_resolution_functions.py index ab6e0e4e..8e5d1a03 100644 --- a/tests/model/test_resolution_functions.py +++ b/tests/model/test_resolution_functions.py @@ -89,7 +89,8 @@ def test_constructor(self): resolution_function = Pointwise(q_data_points=self.data_points) # Then Expect - assert np.allclose(np.array(resolution_function.smearing()), np.array([2.51664683, 2.84038734, 3.2460762 , 3.6796519 , 4.07869271])) + assert np.allclose(np.array(resolution_function.smearing()), + np.array([2.51664683, 2.84038734, 3.2460762 , 3.6796519 , 4.07869271])) def test_as_dict(self): # When From 5d110a04ee8213d8c129935b7f91fa5551b416d1 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Wed, 23 Jul 2025 15:14:10 +0200 Subject: [PATCH 07/16] updated/fixed pointwise impl --- .../simulation/resolution_functions.ipynb | 69 +++++++++++++++++-- .../model/resolution_functions.py | 23 ++----- src/easyreflectometry/project.py | 21 +++--- tests/model/test_resolution_functions.py | 2 +- tests/summary/test_summary.py | 2 +- tests/test_project.py | 4 +- 6 files changed, 85 insertions(+), 36 deletions(-) diff --git a/docs/src/tutorials/simulation/resolution_functions.ipynb b/docs/src/tutorials/simulation/resolution_functions.ipynb index 237963e7..1aca5fab 100644 --- a/docs/src/tutorials/simulation/resolution_functions.ipynb +++ b/docs/src/tutorials/simulation/resolution_functions.ipynb @@ -46,6 +46,7 @@ "from easyreflectometry.model import Model\n", "from easyreflectometry.model import LinearSpline\n", "from easyreflectometry.model import PercentageFwhm\n", + "from easyreflectometry.model import Pointwise\n", "from easyreflectometry.sample import Layer\n", "from easyreflectometry.sample import Material\n", "from easyreflectometry.sample import Multilayer\n", @@ -115,6 +116,16 @@ "dict_reference['10'] = load(file_path_10)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5f65ed7", + "metadata": {}, + "outputs": [], + "source": [ + "dict_reference['0']" + ] + }, { "cell_type": "markdown", "id": "1ab3a164-62c8-4bd3-b0d8-e6f22c83dc74", @@ -251,9 +262,15 @@ "id": "defd6dd5-c618-4af6-a5c7-17532207f0a0", "metadata": {}, "source": [ - "## Resolution functions\n", - "\n", - "We now define the different resoultion functions. " + "## Resolution functions " + ] + }, + { + "cell_type": "markdown", + "id": "c9d903db", + "metadata": {}, + "source": [ + "We can now define the different resoultion functions. " ] }, { @@ -376,11 +393,53 @@ "plt.yscale('log')\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43881642", + "metadata": {}, + "outputs": [], + "source": [ + "key = '1'\n", + "reference_coords = dict_reference[key]['coords']['Qz_0'].values\n", + "reference_variances = dict_reference[key]['coords']['Qz_0'].variances\n", + "reference_data = dict_reference[key]['data']['R_0'].values\n", + "model_coords = np.linspace(\n", + " start=min(reference_coords),\n", + " stop=max(reference_coords),\n", + " num=1000,\n", + ")\n", + "\n", + "model.resolution_function = resolution_function_dict[key]\n", + "model_data = model.interface().reflectity_profile(\n", + " model_coords,\n", + " model.unique_name,\n", + ")\n", + "plt.plot(model_coords, model_data, 'k-', label=f'Variable', linewidth=5)\n", + "data_points = []\n", + "data_points.append(reference_coords) # Qz\n", + "data_points.append(reference_data) # R\n", + "data_points.append(reference_variances) # sQz\n", + "model.resolution_function = Pointwise(q_data_points=data_points)\n", + "model_data = model.interface().reflectity_profile(\n", + " model_coords,\n", + " model.unique_name,\n", + ")\n", + "plt.plot(model_coords, model_data, 'r-', label=f'Pointwise')\n", + "\n", + "ax = plt.gca()\n", + "ax.set_xlim([-0.01, 0.45])\n", + "ax.set_ylim([1e-10, 2.5])\n", + "plt.legend()\n", + "plt.yscale('log')\n", + "plt.show()" + ] } ], "metadata": { "kernelspec": { - "display_name": "easyref", + "display_name": "era", "language": "python", "name": "python3" }, @@ -394,7 +453,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.9" + "version": "3.12.11" } }, "nbformat": 4, diff --git a/src/easyreflectometry/model/resolution_functions.py b/src/easyreflectometry/model/resolution_functions.py index 7fdf5e87..736aec21 100644 --- a/src/easyreflectometry/model/resolution_functions.py +++ b/src/easyreflectometry/model/resolution_functions.py @@ -31,7 +31,7 @@ def from_dict(cls, data: dict) -> ResolutionFunction: if data['smearing'] == 'LinearSpline': return LinearSpline(data['q_data_points'], data['fwhm_values']) if data['smearing'] == 'Pointwise': - return Pointwise(data['q_data_points']) + return Pointwise([data['q_data_points'], data['R_data_points'], data['sQz_data_points']]) raise ValueError('Unknown resolution function type') @@ -87,12 +87,15 @@ def smearing(self, q: Union[np.ndarray, float] = None) -> np.ndarray: def as_dict( self, skip: Optional[List[str]] = None ) -> dict[str, str]: # skip is kept for consistency of the as_dict signature - return {'smearing': 'Pointwise', 'q_data_points': list(self.q_data_points)} + return {'smearing': 'Pointwise', + 'q_data_points': list(self.q_data_points[0]), + 'R_data_points': list(self.q_data_points[1]), + 'sQz_data_points': list(self.q_data_points[2])} def gaussian_smearing(self, qt, Qz, R, sQz): weights = np.exp(-0.5 * ((qt - Qz) / sQz) ** 2) if np.sum(weights) == 0 or not np.isfinite(np.sum(weights)): - return R + return np.sum(R) weights /= (sQz * np.sqrt(2 * np.pi)) return np.sum(R * weights) / np.sum(weights) @@ -105,6 +108,7 @@ def apply_smooth_smearing(self, Qz, R, sQzs): R_smeared = np.zeros_like(Qz) else: R_smeared = np.zeros_like(self.q) + if not isinstance(Qz, np.ndarray): Qz = np.array(Qz) if not isinstance(R, np.ndarray): @@ -114,17 +118,4 @@ def apply_smooth_smearing(self, Qz, R, sQzs): for i, qt in enumerate(self.q): R_smeared[i] = self.gaussian_smearing(qt, Qz, R, sQzs) - # TEST LOCALLY - # import matplotlib.pyplot as plt - # plt.figure(figsize=(10, 6)) - # plt.plot(Qz, R, label='Original R', marker='o', linestyle='none') - # plt.plot(self.q, R_smeared, label='Smeared R', linestyle='-') - # plt.yscale('log') - # plt.xlabel('Qz (1/angstrom)') - # plt.ylabel('R') - # plt.legend() - # plt.title('Original and Smeared R vs Qz (log scale)') - # plt.grid(True) - # plt.show() - return R_smeared diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index e183ea40..66fe90c2 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -18,10 +18,10 @@ from easyreflectometry.data import DataSet1D from easyreflectometry.data import load_as_dataset from easyreflectometry.fitting import MultiFitter -from easyreflectometry.model import LinearSpline from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm +from easyreflectometry.model import Pointwise from easyreflectometry.sample import Layer from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection @@ -249,16 +249,15 @@ def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Opti # Set the resolution function if variance data is present if sum(self._experiments[index].ye) != 0: - # q = self._experiments[index].x - # reflectivity = self._experiments[index].y - # reflectivity_error = self._experiments[index].ye - # q_error = self._experiments[index].xe - # resolution_function = Pointwise( - # q_data_points=[q, reflectivity, q_error]) - resolution_function = LinearSpline( - q_data_points=self._experiments[index].y, - fwhm_values=np.sqrt(self._experiments[index].ye), - ) + q = self._experiments[index].x + reflectivity = self._experiments[index].y + q_error = self._experiments[index].xe + resolution_function = Pointwise( + q_data_points=[q, reflectivity, q_error]) + # resolution_function = LinearSpline( + # q_data_points=self._experiments[index].y, + # fwhm_values=np.sqrt(self._experiments[index].ye), + # ) self._models[index].resolution_function = resolution_function def sld_data_for_model_at_index(self, index: int = 0) -> DataSet1D: diff --git a/tests/model/test_resolution_functions.py b/tests/model/test_resolution_functions.py index 8e5d1a03..b5b1c6d7 100644 --- a/tests/model/test_resolution_functions.py +++ b/tests/model/test_resolution_functions.py @@ -97,7 +97,7 @@ def test_as_dict(self): resolution_function = Pointwise(q_data_points=self.data_points) # Then Expect - resolution_function.as_dict() == {'smearing': 'Pointwise', 'q_data_points': [0, 10]} + assert resolution_function.as_dict(), {'smearing': 'Pointwise', 'q_data_points': [0, 10]} def test_dict_round_trip(self): # When diff --git a/tests/summary/test_summary.py b/tests/summary/test_summary.py index ef8ad80e..fa018b02 100644 --- a/tests/summary/test_summary.py +++ b/tests/summary/test_summary.py @@ -133,7 +133,7 @@ def test_experiments_section(self, project: Project) -> None: assert 'No. of data points' in html assert '408' in html assert 'Resolution function' in html - assert 'LinearSpline' in html + assert 'Pointwise' in html def test_experiments_section_percentage_fhwm(self, project: Project) -> None: # When diff --git a/tests/test_project.py b/tests/test_project.py index 523138c3..d49852ce 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -12,7 +12,7 @@ import easyreflectometry from easyreflectometry.data import DataSet1D from easyreflectometry.fitting import MultiFitter -from easyreflectometry.model import LinearSpline +from easyreflectometry.model import Pointwise from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm @@ -561,7 +561,7 @@ def test_load_experiment(self): assert isinstance(project.experiments[5], DataSet1D) assert project.experiments[5].name == 'Experiment for Model 5' assert project.experiments[5].model == model_5 - assert isinstance(project.models[5].resolution_function, LinearSpline) + assert isinstance(project.models[5].resolution_function, Pointwise) assert isinstance(project.models[4].resolution_function, PercentageFwhm) def test_experimental_data_at_index(self): From c796258e62a986447759c6fa40e9a91e5ee0b98b Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Wed, 23 Jul 2025 15:19:18 +0200 Subject: [PATCH 08/16] ruff on tests --- tests/test_project.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_project.py b/tests/test_project.py index d49852ce..bfa6c233 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -12,10 +12,10 @@ import easyreflectometry from easyreflectometry.data import DataSet1D from easyreflectometry.fitting import MultiFitter -from easyreflectometry.model import Pointwise from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm +from easyreflectometry.model import Pointwise from easyreflectometry.project import Project from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection From c1d9682f50b3ce6bab6c00dbf65a626f27fa3715 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Tue, 29 Jul 2025 13:16:25 +0200 Subject: [PATCH 09/16] fix method name --- docs/src/tutorials/advancedfitting/multi_contrast.ipynb | 4 ++-- src/easyreflectometry/sample/assemblies/surfactant_layer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/tutorials/advancedfitting/multi_contrast.ipynb b/docs/src/tutorials/advancedfitting/multi_contrast.ipynb index 39c6fba5..2c812091 100644 --- a/docs/src/tutorials/advancedfitting/multi_contrast.ipynb +++ b/docs/src/tutorials/advancedfitting/multi_contrast.ipynb @@ -358,8 +358,8 @@ "d83acmw.head_layer.area_per_molecule_parameter.enabled = True\n", "d83acmw.tail_layer.area_per_molecule_parameter.enabled = True\n", "\n", - "d70d2o.constain_multiple_contrast(d13d2o)\n", - "d83acmw.constain_multiple_contrast(d70d2o)" + "d70d2o.constrain_multiple_contrast(d13d2o)\n", + "d83acmw.constrain_multiple_contrast(d70d2o)" ] }, { diff --git a/src/easyreflectometry/sample/assemblies/surfactant_layer.py b/src/easyreflectometry/sample/assemblies/surfactant_layer.py index a7dbe1e3..bc9df230 100644 --- a/src/easyreflectometry/sample/assemblies/surfactant_layer.py +++ b/src/easyreflectometry/sample/assemblies/surfactant_layer.py @@ -180,7 +180,7 @@ def constrain_solvent_roughness(self, solvent_roughness: Parameter): rough = ObjConstraint(solvent_roughness, '', self.tail_layer.roughness) self.tail_layer.roughness.user_constraints['solvent_roughness'] = rough - def constain_multiple_contrast( + def constrain_multiple_contrast( self, another_contrast: SurfactantLayer, head_layer_thickness: bool = True, From 13cb7c5154c2be1aeb256541fc40b516fe5d0e56 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 31 Jul 2025 13:04:18 +0200 Subject: [PATCH 10/16] datasets now have proper names, instead of `R_0` and `Qz_0`. --- src/easyreflectometry/data/__init__.py | 2 + src/easyreflectometry/data/measurement.py | 53 ++++++++++++++++++++--- tests/test_data.py | 28 +++++++----- 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/src/easyreflectometry/data/__init__.py b/src/easyreflectometry/data/__init__.py index 26c3f270..91aab465 100644 --- a/src/easyreflectometry/data/__init__.py +++ b/src/easyreflectometry/data/__init__.py @@ -2,10 +2,12 @@ from .data_store import ProjectData from .measurement import load from .measurement import load_as_dataset +from .measurement import merge_datagroups __all__ = [ "load", "load_as_dataset", + "merge_datagroups", "ProjectData", "DataSet1D", ] diff --git a/src/easyreflectometry/data/measurement.py b/src/easyreflectometry/data/measurement.py index 5aee6a3c..554e6e4f 100644 --- a/src/easyreflectometry/data/measurement.py +++ b/src/easyreflectometry/data/measurement.py @@ -1,5 +1,6 @@ __author__ = 'github.com/arm61' +import os from typing import TextIO from typing import Union @@ -25,11 +26,16 @@ def load(fname: Union[TextIO, str]) -> sc.DataGroup: def load_as_dataset(fname: Union[TextIO, str]) -> DataSet1D: """Load data from an ORSO .ort file as a DataSet1D.""" data_group = load(fname) + basename = os.path.splitext(os.path.basename(fname))[0] + data_name = 'R_' + basename + coords_name = 'Qz_' + basename + coords_name = list(data_group['coords'].keys())[0] if coords_name not in data_group['coords'] else coords_name + data_name = list(data_group['data'].keys())[0] if data_name not in data_group['data'] else data_name return DataSet1D( - x=data_group['coords']['Qz_0'].values, - y=data_group['data']['R_0'].values, - ye=data_group['data']['R_0'].variances, - xe=data_group['coords']['Qz_0'].variances, + x=data_group['coords'][coords_name].values, + y=data_group['data'][data_name].values, + ye=data_group['data'][data_name].variances, + xe=data_group['coords'][coords_name].variances, ) @@ -86,6 +92,8 @@ def _load_txt(fname: Union[TextIO, str]) -> sc.DataGroup: if ',' in first_line: delimiter = ',' + basename = os.path.splitext(os.path.basename(fname))[0] + try: # First load only the data to check column count data = np.loadtxt(fname, delimiter=delimiter, comments='#') @@ -110,13 +118,44 @@ def _load_txt(fname: Union[TextIO, str]) -> sc.DataGroup: # Re-raise with more descriptive message raise ValueError(f"Failed to load data from {fname}: {str(error)}") from error - data = {'R_0': sc.array(dims=['Qz_0'], values=y, variances=np.square(e))} + data_name = 'R_' + basename + coords_name = 'Qz_' + basename + data = {data_name: sc.array(dims=[coords_name], values=y, variances=np.square(e))} coords = { - data['R_0'].dims[0]: sc.array( - dims=['Qz_0'], + data[data_name].dims[0]: sc.array( + dims=[coords_name], values=x, variances=np.square(xe), unit=sc.Unit('1/angstrom'), ) } return sc.DataGroup(data=data, coords=coords) + +def merge_datagroups(*data_groups: sc.DataGroup) -> sc.DataGroup: + """Merge multiple DataGroups into a single DataGroup.""" + merged_data = {} + merged_coords = {} + merged_attrs = {} + + for group in data_groups: + for key, value in group['data'].items(): + if key not in merged_data: + merged_data[key] = value + else: + merged_data[key] = sc.concatenate([merged_data[key], value]) + + for key, value in group['coords'].items(): + if key not in merged_coords: + merged_coords[key] = value + else: + merged_coords[key] = sc.concatenate([merged_coords[key], value]) + + if 'attrs' not in group: + continue + for key, value in group['attrs'].items(): + if key not in merged_attrs: + merged_attrs[key] = value + else: + merged_attrs[key] = {**merged_attrs[key], **value} + + return sc.DataGroup(data=merged_data, coords=merged_coords, attrs=merged_attrs) diff --git a/tests/test_data.py b/tests/test_data.py index 9fb8e6e4..a974a75a 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -32,18 +32,22 @@ def test_load_with_txt(self): fpath = os.path.join(PATH_STATIC, 'test_example1.txt') er_data = load(fpath) n_data = np.loadtxt(fpath) - assert_almost_equal(er_data['data']['R_0'].values, n_data[:, 1]) - assert_almost_equal(er_data['coords']['Qz_0'].values, n_data[:, 0]) - assert_almost_equal(er_data['data']['R_0'].variances, np.square(n_data[:, 2])) - assert_almost_equal(er_data['coords']['Qz_0'].variances, np.square(n_data[:, 3])) + data_name = 'R_test_example1' + coords_name = 'Qz_test_example1' + assert_almost_equal(er_data['data'][data_name].values, n_data[:, 1]) + assert_almost_equal(er_data['coords'][coords_name].values, n_data[:, 0]) + assert_almost_equal(er_data['data'][data_name].variances, np.square(n_data[:, 2])) + assert_almost_equal(er_data['coords'][coords_name].variances, np.square(n_data[:, 3])) def test_load_with_txt_commas(self): fpath = os.path.join(PATH_STATIC, 'ref_concat_1.txt') er_data = load(fpath) x, y, e = np.loadtxt(fpath, delimiter=',', comments='#', unpack=True) - assert_almost_equal(er_data['data']['R_0'].values, y) - assert_almost_equal(er_data['coords']['Qz_0'].values, x) - assert_almost_equal(er_data['data']['R_0'].variances, np.square(e)) + data_name = 'R_ref_concat_1' + coords_name = 'Qz_ref_concat_1' + assert_almost_equal(er_data['data'][data_name].values, y) + assert_almost_equal(er_data['coords'][coords_name].values, x) + assert_almost_equal(er_data['data'][data_name].variances, np.square(e)) def test_orso1(self): fpath = os.path.join(PATH_STATIC, 'test_example1.ort') @@ -93,7 +97,9 @@ def test_txt(self): fpath = os.path.join(PATH_STATIC, 'test_example1.txt') er_data = _load_txt(fpath) n_data = np.loadtxt(fpath) - assert_almost_equal(er_data['data']['R_0'].values, n_data[:, 1]) - assert_almost_equal(er_data['coords']['Qz_0'].values, n_data[:, 0]) - assert_almost_equal(er_data['data']['R_0'].variances, np.square(n_data[:, 2])) - assert_almost_equal(er_data['coords']['Qz_0'].variances, np.square(n_data[:, 3])) + data_name = 'R_test_example1' + coords_name = 'Qz_test_example1' + assert_almost_equal(er_data['data'][data_name].values, n_data[:, 1]) + assert_almost_equal(er_data['coords'][coords_name].values, n_data[:, 0]) + assert_almost_equal(er_data['data'][data_name].variances, np.square(n_data[:, 2])) + assert_almost_equal(er_data['coords'][coords_name].variances, np.square(n_data[:, 3])) From c813ada60ec972f4e5150ca4faefc71a3e3457cc Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 7 Aug 2025 17:01:11 +0200 Subject: [PATCH 11/16] added color changer, renamed default model, minor fixes --- .../tutorials/simulation/resolution_functions.ipynb | 4 ++-- src/easyreflectometry/fitting.py | 3 ++- src/easyreflectometry/model/model.py | 7 ++++--- src/easyreflectometry/model/model_collection.py | 7 ++++--- src/easyreflectometry/project.py | 13 ++++++++++++- tests/model/test_model.py | 6 +++--- tests/model/test_model_collection.py | 8 ++++---- tests/summary/test_summary.py | 2 +- tests/test_project.py | 4 ++-- 9 files changed, 34 insertions(+), 20 deletions(-) diff --git a/docs/src/tutorials/simulation/resolution_functions.ipynb b/docs/src/tutorials/simulation/resolution_functions.ipynb index 1aca5fab..4fc42ad2 100644 --- a/docs/src/tutorials/simulation/resolution_functions.ipynb +++ b/docs/src/tutorials/simulation/resolution_functions.ipynb @@ -439,7 +439,7 @@ ], "metadata": { "kernelspec": { - "display_name": "era", + "display_name": "erl", "language": "python", "name": "python3" }, @@ -453,7 +453,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.11" + "version": "3.12.10" } }, "nbformat": 4, diff --git a/src/easyreflectometry/fitting.py b/src/easyreflectometry/fitting.py index c7bb75c5..41cf72ba 100644 --- a/src/easyreflectometry/fitting.py +++ b/src/easyreflectometry/fitting.py @@ -50,7 +50,8 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup: ) sld_profile = self.easy_science_multi_fitter._fit_objects[i].interface.sld_profile(self._models[i].unique_name) new_data[f'SLD_{id}'] = sc.array(dims=[f'z_{id}'], values=sld_profile[1] * 1e-6, unit=sc.Unit('1/angstrom') ** 2) - new_data['attrs'][f'R_{id}_model'] = {'model': sc.scalar(self._models[i].as_dict())} + if 'attrs' in new_data: + new_data['attrs'][f'R_{id}_model'] = {'model': sc.scalar(self._models[i].as_dict())} new_data['coords'][f'z_{id}'] = sc.array( dims=[f'z_{id}'], values=sld_profile[0], unit=(1 / new_data['coords'][f'Qz_{id}'].unit).unit ) diff --git a/src/easyreflectometry/model/model.py b/src/easyreflectometry/model/model.py index 969c955a..32242a3e 100644 --- a/src/easyreflectometry/model/model.py +++ b/src/easyreflectometry/model/model.py @@ -42,6 +42,7 @@ }, } +COLORS =["#0173B2", "#DE8F05", "#029E73", "#D55E00", "#CC78BC", "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9"] class Model(BaseObj): """Model is the class that represents the experiment. @@ -60,8 +61,8 @@ def __init__( scale: Union[Parameter, Number, None] = None, background: Union[Parameter, Number, None] = None, resolution_function: Union[ResolutionFunction, None] = None, - name: str = 'EasyModel', - color: str = 'black', + name: str = 'Model', + color: str = COLORS[0], unique_name: Optional[str] = None, interface=None, ): @@ -70,7 +71,7 @@ def __init__( :param sample: The sample being modelled. :param scale: Scaling factor of profile. :param background: Linear background magnitude. - :param name: Name of the model, defaults to 'EasyModel'. + :param name: Name of the model, defaults to 'Model'. :param resolution_function: Resolution function, defaults to PercentageFwhm. :param interface: Calculator interface, defaults to `None`. diff --git a/src/easyreflectometry/model/model_collection.py b/src/easyreflectometry/model/model_collection.py index 7ae3b9f2..7e9435ed 100644 --- a/src/easyreflectometry/model/model_collection.py +++ b/src/easyreflectometry/model/model_collection.py @@ -5,7 +5,7 @@ from typing import Tuple from easyreflectometry.sample.collections.base_collection import BaseCollection - +from easyreflectometry.model.model import COLORS from .model import Model @@ -18,7 +18,7 @@ class ModelCollection(BaseCollection): def __init__( self, *models: Tuple[Model], - name: str = 'EasyModels', + name: str = 'Models', interface=None, unique_name: Optional[str] = None, populate_if_none: bool = True, @@ -41,7 +41,8 @@ def add_model(self, model: Optional[Model] = None): :param model: Model to add. """ if model is None: - model = Model(name='EasyModel added', interface=self.interface) + color = COLORS[len(self) % len(COLORS)] + model = Model(name='Model', interface=self.interface, color=color) self.append(model) def duplicate_model(self, index: int): diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index 66fe90c2..a8d09c92 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -240,9 +240,20 @@ def get_index_d2o(self) -> int: self._materials.add_material(Material(name='D2O', sld=6.36, isld=0.0)) return [material.name for material in self._materials].index('D2O') + def load_new_experiment(self, path: Union[Path, str]) -> None: + new_experiment = load_as_dataset(str(path)) + new_index = len(self._experiments) + new_experiment.name = f'Experiment {new_index}' + model_index = 0 + if new_index < len(self.models): + model_index = new_index + new_experiment.model = self.models[model_index] + self._experiments[new_index] = new_experiment + # self._current_model_index = new_index + def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Optional[int] = 0) -> None: self._experiments[index] = load_as_dataset(str(path)) - self._experiments[index].name = f'Experiment for Model {index}' + self._experiments[index].name = f'Experiment {index}' self._experiments[index].model = self.models[index] self._with_experiments = True diff --git a/tests/model/test_model.py b/tests/model/test_model.py index 45755e25..4c8171a0 100644 --- a/tests/model/test_model.py +++ b/tests/model/test_model.py @@ -30,7 +30,7 @@ class TestModel(unittest.TestCase): def test_default(self): p = Model() - assert_equal(p.name, 'EasyModel') + assert_equal(p.name, 'Model') assert_equal(p.interface, None) assert_equal(p.sample.name, 'EasySample') assert_equal(p.scale.display_name, 'scale') @@ -389,7 +389,7 @@ def test_repr(self): assert ( model.__repr__() - == 'EasyModel:\n scale: 1.0\n background: 1.0e-08\n resolution: 5.0 %\n color: black\n sample:\n EasySample:\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n' # noqa: E501 + == "Model:\n scale: 1.0\n background: 1.0e-08\n resolution: 5.0 %\n color: '#0173B2'\n sample:\n EasySample:\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n" # noqa: E501 ) def test_repr_resolution_function(self): @@ -398,7 +398,7 @@ def test_repr_resolution_function(self): model.resolution_function = resolution_function assert ( model.__repr__() - == 'EasyModel:\n scale: 1.0\n background: 1.0e-08\n resolution: function of Q\n color: black\n sample:\n EasySample:\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n' # noqa: E501 + == "Model:\n scale: 1.0\n background: 1.0e-08\n resolution: function of Q\n color: '#0173B2'\n sample:\n EasySample:\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n" # noqa: E501 ) diff --git a/tests/model/test_model_collection.py b/tests/model/test_model_collection.py index 9ab03475..7b22f12b 100644 --- a/tests/model/test_model_collection.py +++ b/tests/model/test_model_collection.py @@ -10,14 +10,14 @@ def test_default(self): collection = ModelCollection() # Expect - assert collection.name == 'EasyModels' + assert collection.name == 'Models' assert collection.interface is None assert len(collection) == 1 - assert collection[0].name == 'EasyModel' + assert collection[0].name == 'Model' def test_dont_populate(self): p = ModelCollection(populate_if_none=False) - assert p.name == 'EasyModels' + assert p.name == 'Models' assert p.interface is None assert len(p) == 0 @@ -31,7 +31,7 @@ def test_from_pars(self): collection = ModelCollection(model_1, model_2, model_3) # Expect - assert collection.name == 'EasyModels' + assert collection.name == 'Models' assert collection.interface is None assert len(collection) == 3 assert collection[0].name == 'Model1' diff --git a/tests/summary/test_summary.py b/tests/summary/test_summary.py index fa018b02..bcd898e1 100644 --- a/tests/summary/test_summary.py +++ b/tests/summary/test_summary.py @@ -129,7 +129,7 @@ def test_experiments_section(self, project: Project) -> None: html = summary._experiments_section() # Expect - assert 'Experiment for Model 0' in html + assert 'Experiment 0' in html assert 'No. of data points' in html assert '408' in html assert 'Resolution function' in html diff --git a/tests/test_project.py b/tests/test_project.py index bfa6c233..d9202de8 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -559,7 +559,7 @@ def test_load_experiment(self): # Expect assert list(project.experiments.keys()) == [5] assert isinstance(project.experiments[5], DataSet1D) - assert project.experiments[5].name == 'Experiment for Model 5' + assert project.experiments[5].name == 'Experiment 5' assert project.experiments[5].model == model_5 assert isinstance(project.models[5].resolution_function, Pointwise) assert isinstance(project.models[4].resolution_function, PercentageFwhm) @@ -575,7 +575,7 @@ def test_experimental_data_at_index(self): data = project.experimental_data_for_model_at_index() # Expect - assert data.name == 'Experiment for Model 0' + assert data.name == 'Experiment 0' assert data.is_experiment assert isinstance(data, DataSet1D) assert len(data.x) == 408 From a4a1b9d9d32d59a61e968d433e7e32ed34460827 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 8 Aug 2025 08:53:02 +0200 Subject: [PATCH 12/16] fix ruff, fix test name --- src/easyreflectometry/model/model_collection.py | 3 ++- tests/sample/assemblies/test_surfactant_layer.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/easyreflectometry/model/model_collection.py b/src/easyreflectometry/model/model_collection.py index 7e9435ed..02c5e5db 100644 --- a/src/easyreflectometry/model/model_collection.py +++ b/src/easyreflectometry/model/model_collection.py @@ -4,8 +4,9 @@ from typing import Optional from typing import Tuple -from easyreflectometry.sample.collections.base_collection import BaseCollection from easyreflectometry.model.model import COLORS +from easyreflectometry.sample.collections.base_collection import BaseCollection + from .model import Model diff --git a/tests/sample/assemblies/test_surfactant_layer.py b/tests/sample/assemblies/test_surfactant_layer.py index b6b25243..8b47c34e 100644 --- a/tests/sample/assemblies/test_surfactant_layer.py +++ b/tests/sample/assemblies/test_surfactant_layer.py @@ -81,7 +81,7 @@ def test_conformal_roughness(self): assert p.tail_layer.roughness.value == 4 assert p.head_layer.roughness.value == 4 - def test_constain_solvent_roughness(self): + def test_constrain_solvent_roughness(self): p = SurfactantLayer() layer = Layer() p.tail_layer.roughness.value = 2 From 0dc62dcec5e6a7dc7ace2d6f9997cc922367cd3c Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 8 Aug 2025 14:32:52 +0200 Subject: [PATCH 13/16] added experiment index --- src/easyreflectometry/project.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index a8d09c92..1ec3a77c 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -56,6 +56,7 @@ def __init__(self): self._current_assembly_index = 0 self._current_layer_index = 0 self._fitter_model_index = None + self._current_experiment_index = 0 # Project flags self._created = False @@ -155,6 +156,19 @@ def current_layer_index(self, value: int) -> None: if self._current_layer_index != value: self._current_layer_index = value + @property + def current_experiment_index(self) -> Optional[int]: + return self._current_experiment_index + + @current_experiment_index.setter + def current_experiment_index(self, value: int) -> None: + if value < 0 or value >= len(self._experiments): + raise ValueError(f'Index {value} out of range') + if self._current_experiment_index != value: + self._current_experiment_index = value + # Resetting the model index to 0 when changing the experiment + #self.current_model_index = 0 + @property def created(self) -> bool: return self._created From 86219e6893960fd7327a55aba92d5c8e143bf52b Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 8 Aug 2025 14:39:16 +0200 Subject: [PATCH 14/16] added tests --- tests/test_project.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_project.py b/tests/test_project.py index d9202de8..6e0172a8 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -634,3 +634,39 @@ def test_parameters(self): # Expect assert len(parameters) == 14 assert isinstance(parameters[0], Parameter) + + def test_current_experiment_index_getter_and_setter(self): + project = Project() + # Default value should be 0 + assert project.current_experiment_index == 0 + + # Add two experiments to allow setting index 1 + project._experiments[0] = DataSet1D(name="exp0", x=[], y=[], ye=[], xe=[], model=None) + project._experiments[1] = DataSet1D(name="exp1", x=[], y=[], ye=[], xe=[], model=None) + + # Set to 1 (valid) + project.current_experiment_index = 1 + assert project.current_experiment_index == 1 + + # Set to 0 (valid) + project.current_experiment_index = 0 + assert project.current_experiment_index == 0 + + def test_current_experiment_index_setter_out_of_range(self): + project = Project() + # Add one experiment + project._experiments[0] = DataSet1D(name="exp0", x=[], y=[], ye=[], xe=[], model=None) + + # Negative index should raise + try: + project.current_experiment_index = -1 + assert False, "Expected ValueError for negative index" + except ValueError: + pass + + # Index >= len(_experiments) should raise + try: + project.current_experiment_index = 1 + assert False, "Expected ValueError for out-of-range index" + except ValueError: + pass From aed1629883ef4cb1e8ccc6151df36d6e40d6a288 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 15 Aug 2025 11:11:55 +0200 Subject: [PATCH 15/16] fix issue with multiple experiments and a single model --- src/easyreflectometry/summary/summary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/easyreflectometry/summary/summary.py b/src/easyreflectometry/summary/summary.py index 4ac259e0..23c3261c 100644 --- a/src/easyreflectometry/summary/summary.py +++ b/src/easyreflectometry/summary/summary.py @@ -143,9 +143,9 @@ def _experiments_section(self) -> str: for idx, experiment in self._project.experiments.items(): experiment_name = experiment.name num_data_points = len(experiment.x) - resolution_function = self._project.models[idx].resolution_function.as_dict()['smearing'] + resolution_function = experiment.model.resolution_function.as_dict()['smearing'] if resolution_function == 'PercentageFwhm': - precentage = self._project.models[idx].resolution_function.as_dict()['constant'] + precentage = experiment.model.resolution_function.as_dict()['constant'] resolution_function = f'{resolution_function} {precentage}%' range_min = min(experiment.y) range_max = max(experiment.y) From c73638bf96fe039796b799b63a06b3967e460fe9 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Tue, 19 Aug 2025 08:41:29 +0200 Subject: [PATCH 16/16] added a helper method --- src/easyreflectometry/project.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index 1ec3a77c..e4846cef 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -79,6 +79,12 @@ def parameters(self) -> List[Parameter]: parameters.append(vertice_obj) return parameters + @property + def enabled_parameters(self) -> List[Parameter]: + parameters = self.parameters + # Only include enabled parameters + return [param for param in parameters if param.enabled] + @property def q_min(self): if self._q_min is None: