diff --git a/docs/src/tutorials/advancedfitting/multi_contrast.ipynb b/docs/src/tutorials/advancedfitting/multi_contrast.ipynb index 39c6fba5..2c812091 100644 --- a/docs/src/tutorials/advancedfitting/multi_contrast.ipynb +++ b/docs/src/tutorials/advancedfitting/multi_contrast.ipynb @@ -358,8 +358,8 @@ "d83acmw.head_layer.area_per_molecule_parameter.enabled = True\n", "d83acmw.tail_layer.area_per_molecule_parameter.enabled = True\n", "\n", - "d70d2o.constain_multiple_contrast(d13d2o)\n", - "d83acmw.constain_multiple_contrast(d70d2o)" + "d70d2o.constrain_multiple_contrast(d13d2o)\n", + "d83acmw.constrain_multiple_contrast(d70d2o)" ] }, { diff --git a/docs/src/tutorials/simulation/resolution_functions.ipynb b/docs/src/tutorials/simulation/resolution_functions.ipynb index 237963e7..4fc42ad2 100644 --- a/docs/src/tutorials/simulation/resolution_functions.ipynb +++ b/docs/src/tutorials/simulation/resolution_functions.ipynb @@ -46,6 +46,7 @@ "from easyreflectometry.model import Model\n", "from easyreflectometry.model import LinearSpline\n", "from easyreflectometry.model import PercentageFwhm\n", + "from easyreflectometry.model import Pointwise\n", "from easyreflectometry.sample import Layer\n", "from easyreflectometry.sample import Material\n", "from easyreflectometry.sample import Multilayer\n", @@ -115,6 +116,16 @@ "dict_reference['10'] = load(file_path_10)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5f65ed7", + "metadata": {}, + "outputs": [], + "source": [ + "dict_reference['0']" + ] + }, { "cell_type": "markdown", "id": "1ab3a164-62c8-4bd3-b0d8-e6f22c83dc74", @@ -251,9 +262,15 @@ "id": "defd6dd5-c618-4af6-a5c7-17532207f0a0", "metadata": {}, "source": [ - "## Resolution functions\n", - "\n", - "We now define the different resoultion functions. " + "## Resolution functions " + ] + }, + { + "cell_type": "markdown", + "id": "c9d903db", + "metadata": {}, + "source": [ + "We can now define the different resoultion functions. " ] }, { @@ -376,11 +393,53 @@ "plt.yscale('log')\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43881642", + "metadata": {}, + "outputs": [], + "source": [ + "key = '1'\n", + "reference_coords = dict_reference[key]['coords']['Qz_0'].values\n", + "reference_variances = dict_reference[key]['coords']['Qz_0'].variances\n", + "reference_data = dict_reference[key]['data']['R_0'].values\n", + "model_coords = np.linspace(\n", + " start=min(reference_coords),\n", + " stop=max(reference_coords),\n", + " num=1000,\n", + ")\n", + "\n", + "model.resolution_function = resolution_function_dict[key]\n", + "model_data = model.interface().reflectity_profile(\n", + " model_coords,\n", + " model.unique_name,\n", + ")\n", + "plt.plot(model_coords, model_data, 'k-', label=f'Variable', linewidth=5)\n", + "data_points = []\n", + "data_points.append(reference_coords) # Qz\n", + "data_points.append(reference_data) # R\n", + "data_points.append(reference_variances) # sQz\n", + "model.resolution_function = Pointwise(q_data_points=data_points)\n", + "model_data = model.interface().reflectity_profile(\n", + " model_coords,\n", + " model.unique_name,\n", + ")\n", + "plt.plot(model_coords, model_data, 'r-', label=f'Pointwise')\n", + "\n", + "ax = plt.gca()\n", + "ax.set_xlim([-0.01, 0.45])\n", + "ax.set_ylim([1e-10, 2.5])\n", + "plt.legend()\n", + "plt.yscale('log')\n", + "plt.show()" + ] } ], "metadata": { "kernelspec": { - "display_name": "easyref", + "display_name": "erl", "language": "python", "name": "python3" }, @@ -394,7 +453,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.9" + "version": "3.12.10" } }, "nbformat": 4, diff --git a/src/easyreflectometry/data/__init__.py b/src/easyreflectometry/data/__init__.py index 26c3f270..91aab465 100644 --- a/src/easyreflectometry/data/__init__.py +++ b/src/easyreflectometry/data/__init__.py @@ -2,10 +2,12 @@ from .data_store import ProjectData from .measurement import load from .measurement import load_as_dataset +from .measurement import merge_datagroups __all__ = [ "load", "load_as_dataset", + "merge_datagroups", "ProjectData", "DataSet1D", ] diff --git a/src/easyreflectometry/data/measurement.py b/src/easyreflectometry/data/measurement.py index 5aee6a3c..554e6e4f 100644 --- a/src/easyreflectometry/data/measurement.py +++ b/src/easyreflectometry/data/measurement.py @@ -1,5 +1,6 @@ __author__ = 'github.com/arm61' +import os from typing import TextIO from typing import Union @@ -25,11 +26,16 @@ def load(fname: Union[TextIO, str]) -> sc.DataGroup: def load_as_dataset(fname: Union[TextIO, str]) -> DataSet1D: """Load data from an ORSO .ort file as a DataSet1D.""" data_group = load(fname) + basename = os.path.splitext(os.path.basename(fname))[0] + data_name = 'R_' + basename + coords_name = 'Qz_' + basename + coords_name = list(data_group['coords'].keys())[0] if coords_name not in data_group['coords'] else coords_name + data_name = list(data_group['data'].keys())[0] if data_name not in data_group['data'] else data_name return DataSet1D( - x=data_group['coords']['Qz_0'].values, - y=data_group['data']['R_0'].values, - ye=data_group['data']['R_0'].variances, - xe=data_group['coords']['Qz_0'].variances, + x=data_group['coords'][coords_name].values, + y=data_group['data'][data_name].values, + ye=data_group['data'][data_name].variances, + xe=data_group['coords'][coords_name].variances, ) @@ -86,6 +92,8 @@ def _load_txt(fname: Union[TextIO, str]) -> sc.DataGroup: if ',' in first_line: delimiter = ',' + basename = os.path.splitext(os.path.basename(fname))[0] + try: # First load only the data to check column count data = np.loadtxt(fname, delimiter=delimiter, comments='#') @@ -110,13 +118,44 @@ def _load_txt(fname: Union[TextIO, str]) -> sc.DataGroup: # Re-raise with more descriptive message raise ValueError(f"Failed to load data from {fname}: {str(error)}") from error - data = {'R_0': sc.array(dims=['Qz_0'], values=y, variances=np.square(e))} + data_name = 'R_' + basename + coords_name = 'Qz_' + basename + data = {data_name: sc.array(dims=[coords_name], values=y, variances=np.square(e))} coords = { - data['R_0'].dims[0]: sc.array( - dims=['Qz_0'], + data[data_name].dims[0]: sc.array( + dims=[coords_name], values=x, variances=np.square(xe), unit=sc.Unit('1/angstrom'), ) } return sc.DataGroup(data=data, coords=coords) + +def merge_datagroups(*data_groups: sc.DataGroup) -> sc.DataGroup: + """Merge multiple DataGroups into a single DataGroup.""" + merged_data = {} + merged_coords = {} + merged_attrs = {} + + for group in data_groups: + for key, value in group['data'].items(): + if key not in merged_data: + merged_data[key] = value + else: + merged_data[key] = sc.concatenate([merged_data[key], value]) + + for key, value in group['coords'].items(): + if key not in merged_coords: + merged_coords[key] = value + else: + merged_coords[key] = sc.concatenate([merged_coords[key], value]) + + if 'attrs' not in group: + continue + for key, value in group['attrs'].items(): + if key not in merged_attrs: + merged_attrs[key] = value + else: + merged_attrs[key] = {**merged_attrs[key], **value} + + return sc.DataGroup(data=merged_data, coords=merged_coords, attrs=merged_attrs) diff --git a/src/easyreflectometry/fitting.py b/src/easyreflectometry/fitting.py index c7bb75c5..41cf72ba 100644 --- a/src/easyreflectometry/fitting.py +++ b/src/easyreflectometry/fitting.py @@ -50,7 +50,8 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup: ) sld_profile = self.easy_science_multi_fitter._fit_objects[i].interface.sld_profile(self._models[i].unique_name) new_data[f'SLD_{id}'] = sc.array(dims=[f'z_{id}'], values=sld_profile[1] * 1e-6, unit=sc.Unit('1/angstrom') ** 2) - new_data['attrs'][f'R_{id}_model'] = {'model': sc.scalar(self._models[i].as_dict())} + if 'attrs' in new_data: + new_data['attrs'][f'R_{id}_model'] = {'model': sc.scalar(self._models[i].as_dict())} new_data['coords'][f'z_{id}'] = sc.array( dims=[f'z_{id}'], values=sld_profile[0], unit=(1 / new_data['coords'][f'Qz_{id}'].unit).unit ) diff --git a/src/easyreflectometry/model/__init__.py b/src/easyreflectometry/model/__init__.py index baa1aec4..b12b504e 100644 --- a/src/easyreflectometry/model/__init__.py +++ b/src/easyreflectometry/model/__init__.py @@ -2,11 +2,13 @@ from .model_collection import ModelCollection from .resolution_functions import LinearSpline from .resolution_functions import PercentageFwhm +from .resolution_functions import Pointwise from .resolution_functions import ResolutionFunction __all__ = ( "LinearSpline", "PercentageFwhm", + "Pointwise", "ResolutionFunction", "Model", "ModelCollection", diff --git a/src/easyreflectometry/model/model.py b/src/easyreflectometry/model/model.py index 969c955a..32242a3e 100644 --- a/src/easyreflectometry/model/model.py +++ b/src/easyreflectometry/model/model.py @@ -42,6 +42,7 @@ }, } +COLORS =["#0173B2", "#DE8F05", "#029E73", "#D55E00", "#CC78BC", "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9"] class Model(BaseObj): """Model is the class that represents the experiment. @@ -60,8 +61,8 @@ def __init__( scale: Union[Parameter, Number, None] = None, background: Union[Parameter, Number, None] = None, resolution_function: Union[ResolutionFunction, None] = None, - name: str = 'EasyModel', - color: str = 'black', + name: str = 'Model', + color: str = COLORS[0], unique_name: Optional[str] = None, interface=None, ): @@ -70,7 +71,7 @@ def __init__( :param sample: The sample being modelled. :param scale: Scaling factor of profile. :param background: Linear background magnitude. - :param name: Name of the model, defaults to 'EasyModel'. + :param name: Name of the model, defaults to 'Model'. :param resolution_function: Resolution function, defaults to PercentageFwhm. :param interface: Calculator interface, defaults to `None`. diff --git a/src/easyreflectometry/model/model_collection.py b/src/easyreflectometry/model/model_collection.py index 7ae3b9f2..02c5e5db 100644 --- a/src/easyreflectometry/model/model_collection.py +++ b/src/easyreflectometry/model/model_collection.py @@ -4,6 +4,7 @@ from typing import Optional from typing import Tuple +from easyreflectometry.model.model import COLORS from easyreflectometry.sample.collections.base_collection import BaseCollection from .model import Model @@ -18,7 +19,7 @@ class ModelCollection(BaseCollection): def __init__( self, *models: Tuple[Model], - name: str = 'EasyModels', + name: str = 'Models', interface=None, unique_name: Optional[str] = None, populate_if_none: bool = True, @@ -41,7 +42,8 @@ def add_model(self, model: Optional[Model] = None): :param model: Model to add. """ if model is None: - model = Model(name='EasyModel added', interface=self.interface) + color = COLORS[len(self) % len(COLORS)] + model = Model(name='Model', interface=self.interface, color=color) self.append(model) def duplicate_model(self, index: int): diff --git a/src/easyreflectometry/model/resolution_functions.py b/src/easyreflectometry/model/resolution_functions.py index 3a352fc5..736aec21 100644 --- a/src/easyreflectometry/model/resolution_functions.py +++ b/src/easyreflectometry/model/resolution_functions.py @@ -30,6 +30,8 @@ def from_dict(cls, data: dict) -> ResolutionFunction: return PercentageFwhm(data['constant']) if data['smearing'] == 'LinearSpline': return LinearSpline(data['q_data_points'], data['fwhm_values']) + if data['smearing'] == 'Pointwise': + return Pointwise([data['q_data_points'], data['R_data_points'], data['sQz_data_points']]) raise ValueError('Unknown resolution function type') @@ -60,3 +62,60 @@ def as_dict( self, skip: Optional[List[str]] = None ) -> dict[str, str]: # skip is kept for consistency of the as_dict signature return {'smearing': 'LinearSpline', 'q_data_points': list(self.q_data_points), 'fwhm_values': list(self.fwhm_values)} + +# add pointwise smearing funtion +class Pointwise(ResolutionFunction): + def __init__(self, q_data_points: list[np.ndarray]): + self.q_data_points = q_data_points + self.q = None + + def smearing(self, q: Union[np.ndarray, float] = None) -> np.ndarray: + + Qz = self.q_data_points[0] + R = self.q_data_points[1] + sQz = self.q_data_points[2] + if q is None: + q = self.q_data_points[0] + self.q = q + sQzs = np.sqrt(sQz) + if isinstance(Qz, float): + Qz = np.array(Qz) + + smeared = self.apply_smooth_smearing(Qz, R, sQzs) + return smeared + + def as_dict( + self, skip: Optional[List[str]] = None + ) -> dict[str, str]: # skip is kept for consistency of the as_dict signature + return {'smearing': 'Pointwise', + 'q_data_points': list(self.q_data_points[0]), + 'R_data_points': list(self.q_data_points[1]), + 'sQz_data_points': list(self.q_data_points[2])} + + def gaussian_smearing(self, qt, Qz, R, sQz): + weights = np.exp(-0.5 * ((qt - Qz) / sQz) ** 2) + if np.sum(weights) == 0 or not np.isfinite(np.sum(weights)): + return np.sum(R) + weights /= (sQz * np.sqrt(2 * np.pi)) + return np.sum(R * weights) / np.sum(weights) + + + def apply_smooth_smearing(self, Qz, R, sQzs): + """ + Apply smooth resolution smearing using convolution with Gaussian kernel. + """ + if self.q is None: + R_smeared = np.zeros_like(Qz) + else: + R_smeared = np.zeros_like(self.q) + + if not isinstance(Qz, np.ndarray): + Qz = np.array(Qz) + if not isinstance(R, np.ndarray): + R = np.array(R) + R_smeared = np.zeros_like(self.q) + + for i, qt in enumerate(self.q): + R_smeared[i] = self.gaussian_smearing(qt, Qz, R, sQzs) + + return R_smeared diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index de769ab8..e4846cef 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -18,10 +18,10 @@ from easyreflectometry.data import DataSet1D from easyreflectometry.data import load_as_dataset from easyreflectometry.fitting import MultiFitter -from easyreflectometry.model import LinearSpline from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm +from easyreflectometry.model import Pointwise from easyreflectometry.sample import Layer from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection @@ -56,6 +56,7 @@ def __init__(self): self._current_assembly_index = 0 self._current_layer_index = 0 self._fitter_model_index = None + self._current_experiment_index = 0 # Project flags self._created = False @@ -78,6 +79,12 @@ def parameters(self) -> List[Parameter]: parameters.append(vertice_obj) return parameters + @property + def enabled_parameters(self) -> List[Parameter]: + parameters = self.parameters + # Only include enabled parameters + return [param for param in parameters if param.enabled] + @property def q_min(self): if self._q_min is None: @@ -155,6 +162,19 @@ def current_layer_index(self, value: int) -> None: if self._current_layer_index != value: self._current_layer_index = value + @property + def current_experiment_index(self) -> Optional[int]: + return self._current_experiment_index + + @current_experiment_index.setter + def current_experiment_index(self, value: int) -> None: + if value < 0 or value >= len(self._experiments): + raise ValueError(f'Index {value} out of range') + if self._current_experiment_index != value: + self._current_experiment_index = value + # Resetting the model index to 0 when changing the experiment + #self.current_model_index = 0 + @property def created(self) -> bool: return self._created @@ -240,19 +260,35 @@ def get_index_d2o(self) -> int: self._materials.add_material(Material(name='D2O', sld=6.36, isld=0.0)) return [material.name for material in self._materials].index('D2O') + def load_new_experiment(self, path: Union[Path, str]) -> None: + new_experiment = load_as_dataset(str(path)) + new_index = len(self._experiments) + new_experiment.name = f'Experiment {new_index}' + model_index = 0 + if new_index < len(self.models): + model_index = new_index + new_experiment.model = self.models[model_index] + self._experiments[new_index] = new_experiment + # self._current_model_index = new_index + def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Optional[int] = 0) -> None: self._experiments[index] = load_as_dataset(str(path)) - self._experiments[index].name = f'Experiment for Model {index}' + self._experiments[index].name = f'Experiment {index}' self._experiments[index].model = self.models[index] self._with_experiments = True # Set the resolution function if variance data is present if sum(self._experiments[index].ye) != 0: - resolution_function = LinearSpline( - q_data_points=self._experiments[index].y, - fwhm_values=np.sqrt(self._experiments[index].ye), - ) + q = self._experiments[index].x + reflectivity = self._experiments[index].y + q_error = self._experiments[index].xe + resolution_function = Pointwise( + q_data_points=[q, reflectivity, q_error]) + # resolution_function = LinearSpline( + # q_data_points=self._experiments[index].y, + # fwhm_values=np.sqrt(self._experiments[index].ye), + # ) self._models[index].resolution_function = resolution_function def sld_data_for_model_at_index(self, index: int = 0) -> DataSet1D: diff --git a/src/easyreflectometry/sample/assemblies/surfactant_layer.py b/src/easyreflectometry/sample/assemblies/surfactant_layer.py index a7dbe1e3..bc9df230 100644 --- a/src/easyreflectometry/sample/assemblies/surfactant_layer.py +++ b/src/easyreflectometry/sample/assemblies/surfactant_layer.py @@ -180,7 +180,7 @@ def constrain_solvent_roughness(self, solvent_roughness: Parameter): rough = ObjConstraint(solvent_roughness, '', self.tail_layer.roughness) self.tail_layer.roughness.user_constraints['solvent_roughness'] = rough - def constain_multiple_contrast( + def constrain_multiple_contrast( self, another_contrast: SurfactantLayer, head_layer_thickness: bool = True, diff --git a/src/easyreflectometry/summary/summary.py b/src/easyreflectometry/summary/summary.py index 4ac259e0..23c3261c 100644 --- a/src/easyreflectometry/summary/summary.py +++ b/src/easyreflectometry/summary/summary.py @@ -143,9 +143,9 @@ def _experiments_section(self) -> str: for idx, experiment in self._project.experiments.items(): experiment_name = experiment.name num_data_points = len(experiment.x) - resolution_function = self._project.models[idx].resolution_function.as_dict()['smearing'] + resolution_function = experiment.model.resolution_function.as_dict()['smearing'] if resolution_function == 'PercentageFwhm': - precentage = self._project.models[idx].resolution_function.as_dict()['constant'] + precentage = experiment.model.resolution_function.as_dict()['constant'] resolution_function = f'{resolution_function} {precentage}%' range_min = min(experiment.y) range_max = max(experiment.y) diff --git a/tests/model/test_model.py b/tests/model/test_model.py index 45755e25..4c8171a0 100644 --- a/tests/model/test_model.py +++ b/tests/model/test_model.py @@ -30,7 +30,7 @@ class TestModel(unittest.TestCase): def test_default(self): p = Model() - assert_equal(p.name, 'EasyModel') + assert_equal(p.name, 'Model') assert_equal(p.interface, None) assert_equal(p.sample.name, 'EasySample') assert_equal(p.scale.display_name, 'scale') @@ -389,7 +389,7 @@ def test_repr(self): assert ( model.__repr__() - == 'EasyModel:\n scale: 1.0\n background: 1.0e-08\n resolution: 5.0 %\n color: black\n sample:\n EasySample:\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n' # noqa: E501 + == "Model:\n scale: 1.0\n background: 1.0e-08\n resolution: 5.0 %\n color: '#0173B2'\n sample:\n EasySample:\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n" # noqa: E501 ) def test_repr_resolution_function(self): @@ -398,7 +398,7 @@ def test_repr_resolution_function(self): model.resolution_function = resolution_function assert ( model.__repr__() - == 'EasyModel:\n scale: 1.0\n background: 1.0e-08\n resolution: function of Q\n color: black\n sample:\n EasySample:\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n' # noqa: E501 + == "Model:\n scale: 1.0\n background: 1.0e-08\n resolution: function of Q\n color: '#0173B2'\n sample:\n EasySample:\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n - EasyMultilayer:\n EasyLayerCollection:\n - EasyLayer:\n material:\n EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n thickness: 10.000 Å\n roughness: 3.300 Å\n" # noqa: E501 ) diff --git a/tests/model/test_model_collection.py b/tests/model/test_model_collection.py index 9ab03475..7b22f12b 100644 --- a/tests/model/test_model_collection.py +++ b/tests/model/test_model_collection.py @@ -10,14 +10,14 @@ def test_default(self): collection = ModelCollection() # Expect - assert collection.name == 'EasyModels' + assert collection.name == 'Models' assert collection.interface is None assert len(collection) == 1 - assert collection[0].name == 'EasyModel' + assert collection[0].name == 'Model' def test_dont_populate(self): p = ModelCollection(populate_if_none=False) - assert p.name == 'EasyModels' + assert p.name == 'Models' assert p.interface is None assert len(p) == 0 @@ -31,7 +31,7 @@ def test_from_pars(self): collection = ModelCollection(model_1, model_2, model_3) # Expect - assert collection.name == 'EasyModels' + assert collection.name == 'Models' assert collection.interface is None assert len(collection) == 3 assert collection[0].name == 'Model1' diff --git a/tests/model/test_resolution_functions.py b/tests/model/test_resolution_functions.py index f2963c48..b5b1c6d7 100644 --- a/tests/model/test_resolution_functions.py +++ b/tests/model/test_resolution_functions.py @@ -5,6 +5,7 @@ from easyreflectometry.model.resolution_functions import DEFAULT_RESOLUTION_FWHM_PERCENTAGE from easyreflectometry.model.resolution_functions import LinearSpline from easyreflectometry.model.resolution_functions import PercentageFwhm +from easyreflectometry.model.resolution_functions import Pointwise from easyreflectometry.model.resolution_functions import ResolutionFunction @@ -75,3 +76,36 @@ def test_dict_round_trip(self): # Expect assert all(resolution_function.smearing([0, 2.5]) == expected_resolution_function.smearing([0, 2.5])) + +class TestPointwise(unittest.TestCase): + + data_points = [] + data_points.append([0.1, 0.2, 0.3, 0.4, 0.5]) # Qz + data_points.append([1.1, 2.2, 3.3, 4.4, 5.5]) # R + data_points.append([0.03, 0.04, 0.05, 0.06, 0.07]) # sQz + def test_constructor(self): + + # When + resolution_function = Pointwise(q_data_points=self.data_points) + + # Then Expect + assert np.allclose(np.array(resolution_function.smearing()), + np.array([2.51664683, 2.84038734, 3.2460762 , 3.6796519 , 4.07869271])) + + def test_as_dict(self): + # When + resolution_function = Pointwise(q_data_points=self.data_points) + + # Then Expect + assert resolution_function.as_dict(), {'smearing': 'Pointwise', 'q_data_points': [0, 10]} + + def test_dict_round_trip(self): + # When + expected_resolution_function = Pointwise(q_data_points=self.data_points) + res_dict = expected_resolution_function.as_dict() + + # Then + resolution_function = ResolutionFunction.from_dict(res_dict) + + # Expect + assert all(resolution_function.smearing() == expected_resolution_function.smearing()) diff --git a/tests/sample/assemblies/test_surfactant_layer.py b/tests/sample/assemblies/test_surfactant_layer.py index b6b25243..8b47c34e 100644 --- a/tests/sample/assemblies/test_surfactant_layer.py +++ b/tests/sample/assemblies/test_surfactant_layer.py @@ -81,7 +81,7 @@ def test_conformal_roughness(self): assert p.tail_layer.roughness.value == 4 assert p.head_layer.roughness.value == 4 - def test_constain_solvent_roughness(self): + def test_constrain_solvent_roughness(self): p = SurfactantLayer() layer = Layer() p.tail_layer.roughness.value = 2 diff --git a/tests/summary/test_summary.py b/tests/summary/test_summary.py index ef8ad80e..bcd898e1 100644 --- a/tests/summary/test_summary.py +++ b/tests/summary/test_summary.py @@ -129,11 +129,11 @@ def test_experiments_section(self, project: Project) -> None: html = summary._experiments_section() # Expect - assert 'Experiment for Model 0' in html + assert 'Experiment 0' in html assert 'No. of data points' in html assert '408' in html assert 'Resolution function' in html - assert 'LinearSpline' in html + assert 'Pointwise' in html def test_experiments_section_percentage_fhwm(self, project: Project) -> None: # When diff --git a/tests/test_data.py b/tests/test_data.py index 9fb8e6e4..a974a75a 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -32,18 +32,22 @@ def test_load_with_txt(self): fpath = os.path.join(PATH_STATIC, 'test_example1.txt') er_data = load(fpath) n_data = np.loadtxt(fpath) - assert_almost_equal(er_data['data']['R_0'].values, n_data[:, 1]) - assert_almost_equal(er_data['coords']['Qz_0'].values, n_data[:, 0]) - assert_almost_equal(er_data['data']['R_0'].variances, np.square(n_data[:, 2])) - assert_almost_equal(er_data['coords']['Qz_0'].variances, np.square(n_data[:, 3])) + data_name = 'R_test_example1' + coords_name = 'Qz_test_example1' + assert_almost_equal(er_data['data'][data_name].values, n_data[:, 1]) + assert_almost_equal(er_data['coords'][coords_name].values, n_data[:, 0]) + assert_almost_equal(er_data['data'][data_name].variances, np.square(n_data[:, 2])) + assert_almost_equal(er_data['coords'][coords_name].variances, np.square(n_data[:, 3])) def test_load_with_txt_commas(self): fpath = os.path.join(PATH_STATIC, 'ref_concat_1.txt') er_data = load(fpath) x, y, e = np.loadtxt(fpath, delimiter=',', comments='#', unpack=True) - assert_almost_equal(er_data['data']['R_0'].values, y) - assert_almost_equal(er_data['coords']['Qz_0'].values, x) - assert_almost_equal(er_data['data']['R_0'].variances, np.square(e)) + data_name = 'R_ref_concat_1' + coords_name = 'Qz_ref_concat_1' + assert_almost_equal(er_data['data'][data_name].values, y) + assert_almost_equal(er_data['coords'][coords_name].values, x) + assert_almost_equal(er_data['data'][data_name].variances, np.square(e)) def test_orso1(self): fpath = os.path.join(PATH_STATIC, 'test_example1.ort') @@ -93,7 +97,9 @@ def test_txt(self): fpath = os.path.join(PATH_STATIC, 'test_example1.txt') er_data = _load_txt(fpath) n_data = np.loadtxt(fpath) - assert_almost_equal(er_data['data']['R_0'].values, n_data[:, 1]) - assert_almost_equal(er_data['coords']['Qz_0'].values, n_data[:, 0]) - assert_almost_equal(er_data['data']['R_0'].variances, np.square(n_data[:, 2])) - assert_almost_equal(er_data['coords']['Qz_0'].variances, np.square(n_data[:, 3])) + data_name = 'R_test_example1' + coords_name = 'Qz_test_example1' + assert_almost_equal(er_data['data'][data_name].values, n_data[:, 1]) + assert_almost_equal(er_data['coords'][coords_name].values, n_data[:, 0]) + assert_almost_equal(er_data['data'][data_name].variances, np.square(n_data[:, 2])) + assert_almost_equal(er_data['coords'][coords_name].variances, np.square(n_data[:, 3])) diff --git a/tests/test_project.py b/tests/test_project.py index 523138c3..6e0172a8 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -12,10 +12,10 @@ import easyreflectometry from easyreflectometry.data import DataSet1D from easyreflectometry.fitting import MultiFitter -from easyreflectometry.model import LinearSpline from easyreflectometry.model import Model from easyreflectometry.model import ModelCollection from easyreflectometry.model import PercentageFwhm +from easyreflectometry.model import Pointwise from easyreflectometry.project import Project from easyreflectometry.sample import Material from easyreflectometry.sample import MaterialCollection @@ -559,9 +559,9 @@ def test_load_experiment(self): # Expect assert list(project.experiments.keys()) == [5] assert isinstance(project.experiments[5], DataSet1D) - assert project.experiments[5].name == 'Experiment for Model 5' + assert project.experiments[5].name == 'Experiment 5' assert project.experiments[5].model == model_5 - assert isinstance(project.models[5].resolution_function, LinearSpline) + assert isinstance(project.models[5].resolution_function, Pointwise) assert isinstance(project.models[4].resolution_function, PercentageFwhm) def test_experimental_data_at_index(self): @@ -575,7 +575,7 @@ def test_experimental_data_at_index(self): data = project.experimental_data_for_model_at_index() # Expect - assert data.name == 'Experiment for Model 0' + assert data.name == 'Experiment 0' assert data.is_experiment assert isinstance(data, DataSet1D) assert len(data.x) == 408 @@ -634,3 +634,39 @@ def test_parameters(self): # Expect assert len(parameters) == 14 assert isinstance(parameters[0], Parameter) + + def test_current_experiment_index_getter_and_setter(self): + project = Project() + # Default value should be 0 + assert project.current_experiment_index == 0 + + # Add two experiments to allow setting index 1 + project._experiments[0] = DataSet1D(name="exp0", x=[], y=[], ye=[], xe=[], model=None) + project._experiments[1] = DataSet1D(name="exp1", x=[], y=[], ye=[], xe=[], model=None) + + # Set to 1 (valid) + project.current_experiment_index = 1 + assert project.current_experiment_index == 1 + + # Set to 0 (valid) + project.current_experiment_index = 0 + assert project.current_experiment_index == 0 + + def test_current_experiment_index_setter_out_of_range(self): + project = Project() + # Add one experiment + project._experiments[0] = DataSet1D(name="exp0", x=[], y=[], ye=[], xe=[], model=None) + + # Negative index should raise + try: + project.current_experiment_index = -1 + assert False, "Expected ValueError for negative index" + except ValueError: + pass + + # Index >= len(_experiments) should raise + try: + project.current_experiment_index = 1 + assert False, "Expected ValueError for out-of-range index" + except ValueError: + pass