From 41c58c5467e4802fa1e7c3641b9507281851538a Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Sat, 23 May 2026 22:24:08 +0200 Subject: [PATCH 1/3] initial commit --- CHANGELOG.md | 33 ++ docs/docs/tutorials/simulation/bilayer.ipynb | 12 +- src/easyreflectometry/model/model.py | 126 ++++--- .../model/model_collection.py | 57 +-- src/easyreflectometry/project.py | 30 +- .../sample/assemblies/base_assembly.py | 26 +- .../sample/assemblies/bilayer.py | 31 +- .../sample/assemblies/gradient_layer.py | 43 ++- .../sample/assemblies/multilayer.py | 8 +- .../sample/assemblies/repeating_multilayer.py | 21 +- .../sample/assemblies/surfactant_layer.py | 28 +- src/easyreflectometry/sample/base_core.py | 251 +++++++++++-- .../sample/collections/base_collection.py | 351 ++++++++++++++---- .../sample/collections/layer_collection.py | 11 +- .../sample/collections/material_collection.py | 5 +- .../sample/collections/sample.py | 41 +- .../sample/elements/layers/layer.py | 49 ++- .../layers/layer_area_per_molecule.py | 193 +++++----- .../sample/elements/materials/material.py | 37 +- .../elements/materials/material_density.py | 107 ++++-- .../elements/materials/material_mixture.py | 210 ++++++----- .../elements/materials/material_solvated.py | 105 +++--- src/easyreflectometry/summary/summary.py | 4 +- tests/model/test_model_collection.py | 4 +- tests/sample/assemblies/test_bilayer.py | 14 +- .../assemblies/test_surfactant_layer.py | 20 +- .../collections/test_base_collection.py | 2 +- .../layers/test_layer_area_per_molecule.py | 75 +++- .../materials/test_material_density.py | 19 + .../materials/test_material_mixture.py | 60 ++- .../materials/test_material_solvated.py | 32 +- tests/test_project.py | 2 + 32 files changed, 1361 insertions(+), 646 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb888bf6..2537591f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,36 @@ +# Unreleased + +Migrated sample / model classes off the deprecated `easyscience.ObjBase` +and `easyscience.CollectionBase` pipeline. + +- `BaseCore` is now built on `ModelBase`; `BaseCollection` on + `EasyList`. `Model`, `Material`, `Layer`, `MaterialMixture`, + `MaterialSolvated`, `LayerAreaPerMolecule`, `Multilayer`, + `RepeatingMultilayer`, `GradientLayer`, `Bilayer`, `SurfactantLayer`, + `BaseAssembly`, `LayerCollection`, `MaterialCollection`, `Sample`, and + `ModelCollection` were all rewritten to use the new bases. +- Properties returning a `Parameter` (`Material.sld`-style) now expose + the `Parameter` object directly across all sample classes, replacing + the inconsistent legacy behaviour where `MaterialMixture.fraction`, + `MaterialSolvated.solvent_fraction`, + `LayerAreaPerMolecule.area_per_molecule`, and + `LayerAreaPerMolecule.solvent_fraction` returned `float`. Read the + value via `.value` (e.g. `material_mixture.fraction.value`). Setters + still accept a float. `MaterialMixture.sld` / `MaterialMixture.isld` + remain `float` — they are derived via constraints, not constructor + arguments. +- `BaseCollection.remove(index)` (the legacy index-based helper) renamed + to `remove_at(index)`. The standard `MutableSequence.remove(value)` is + now inherited unmodified. +- Project files saved by previous versions cannot be read. + `Project.as_dict` writes `file_format=2`; `Project.from_dict` raises a + clear `ValueError` on missing or unsupported markers. +- `model.get_parameters()` / `collection.get_parameters()` still work + (kept as compatibility shims) but new code should use + `get_all_parameters()`. +- No more `DeprecationWarning` from `easyscience.ObjBase` / + `CollectionBase` on construction of any sample / model object. + # Version 1.6.0 (1 May 2026) Add Mighell-based handling of non-positive-variance points in fitting diff --git a/docs/docs/tutorials/simulation/bilayer.ipynb b/docs/docs/tutorials/simulation/bilayer.ipynb index b16b8284..5f37b010 100644 --- a/docs/docs/tutorials/simulation/bilayer.ipynb +++ b/docs/docs/tutorials/simulation/bilayer.ipynb @@ -202,7 +202,7 @@ "# Access key structural parameters\n", "print(f'Head thickness: {bilayer.front_head_layer.thickness.value:.2f} Å')\n", "print(f'Tail thickness: {bilayer.front_tail_layer.thickness.value:.2f} Å')\n", - "print(f'Area per molecule: {bilayer.front_head_layer.area_per_molecule:.2f} Ų')" + "print(f'Area per molecule: {bilayer.front_head_layer.area_per_molecule.value:.2f} Ų')" ] }, { @@ -224,14 +224,14 @@ "source": [ "# Head layers share thickness and area per molecule via constrain_heads=True,\n", "# but solvent fraction is independent and can be set separately for each side.\n", - "print(f'Front head solvent fraction: {bilayer.front_head_layer.solvent_fraction:.2f}')\n", - "print(f'Back head solvent fraction: {bilayer.back_head_layer.solvent_fraction:.2f}')\n", + "print(f'Front head solvent fraction: {bilayer.front_head_layer.solvent_fraction.value:.2f}')\n", + "print(f'Back head solvent fraction: {bilayer.back_head_layer.solvent_fraction.value:.2f}')\n", "\n", "# We can set them independently\n", "bilayer.back_head_layer.solvent_fraction = 0.5\n", "print('\\nAfter setting back head solvent fraction to 0.5:')\n", - "print(f'Front head solvent fraction: {bilayer.front_head_layer.solvent_fraction:.2f}')\n", - "print(f'Back head solvent fraction: {bilayer.back_head_layer.solvent_fraction:.2f}')" + "print(f'Front head solvent fraction: {bilayer.front_head_layer.solvent_fraction.value:.2f}')\n", + "print(f'Back head solvent fraction: {bilayer.back_head_layer.solvent_fraction.value:.2f}')" ] }, { @@ -685,7 +685,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.12.11" } }, "nbformat": 4, diff --git a/src/easyreflectometry/model/model.py b/src/easyreflectometry/model/model.py index b1aedf85..da1396d0 100644 --- a/src/easyreflectometry/model/model.py +++ b/src/easyreflectometry/model/model.py @@ -9,15 +9,14 @@ from typing import Union import numpy as np -from easyscience import ObjBase as BaseObj from easyscience import global_object from easyscience.variable import Parameter from easyreflectometry.limits import apply_default_limits from easyreflectometry.sample import BaseAssembly from easyreflectometry.sample import Sample +from easyreflectometry.sample.base_core import BaseCore from easyreflectometry.utils import get_as_parameter -from easyreflectometry.utils import yaml_dump from .resolution_functions import PercentageFwhm from .resolution_functions import ResolutionFunction @@ -58,18 +57,12 @@ ] -class Model(BaseObj): +class Model(BaseCore): """Model is the class that represents the experiment. It is used to store the information about the experiment and to perform the calculations. """ - # Added in super().__init__ - name: str - sample: Sample - scale: Parameter - background: Parameter - def __init__( self, sample: Union[Sample, None] = None, @@ -115,18 +108,45 @@ def __init__( background = get_as_parameter('background', background, DEFAULTS) self.color = color self._is_default = False + self._resolution_function = resolution_function + + super().__init__(name=name, unique_name=unique_name) + self._sample = sample + self._scale = scale + self._background = background + + # Set interface last — propagates to children via BaseCore.generate_bindings + # and then sets the resolution function on the calculator (see setter). + if interface is not None: + self.interface = interface + + # ----- @property accessors for serialization round-trip ----- + + @property + def sample(self) -> Sample: + return self._sample + + @sample.setter + def sample(self, value: Sample) -> None: + self._sample = value - super().__init__( - name=name, - unique_name=unique_name, - sample=sample, - scale=scale, - background=background, - ) - self.resolution_function = resolution_function + @property + def scale(self) -> Parameter: + return self._scale + + @scale.setter + def scale(self, value: float) -> None: + self._scale.value = value + + @property + def background(self) -> Parameter: + return self._background - # Must be set after resolution function - self.interface = interface + @background.setter + def background(self, value: float) -> None: + self._background.value = value + + # ----- assembly management ----- def add_assemblies(self, *assemblies: list[BaseAssembly]) -> None: """Add assemblies to the model sample. @@ -183,15 +203,11 @@ def is_default(self) -> bool: @is_default.setter def is_default(self, value: bool) -> None: - """Set whether this model is a default placeholder. - - Parameters - ---------- - value : bool - True if the model is a default placeholder. - """ + """Set whether this model is a default placeholder.""" self._is_default = value + # ----- resolution function ----- + @property def resolution_function(self) -> ResolutionFunction: """Return the resolution function.""" @@ -204,21 +220,20 @@ def resolution_function(self, resolution_function: ResolutionFunction) -> None: if self.interface is not None: self.interface().set_resolution_function(self._resolution_function) - @property - def interface(self): - """Get the current interface of the object.""" - return self._interface + # ----- interface (override BaseCore's to add resolution-function side effect) ----- - @interface.setter + @BaseCore.interface.setter def interface(self, new_interface) -> None: - """Set the interface for the model.""" - # From super class - self._interface = new_interface + """Set the interface; runs `generate_bindings` and then refreshes the + calculator's resolution function. + """ + # Call BaseCore.interface.setter for the binding propagation. + BaseCore.interface.fset(self, new_interface) if new_interface is not None: - self.generate_bindings() - self._interface().set_resolution_function(self._resolution_function) + new_interface().set_resolution_function(self._resolution_function) + + # ----- representation ----- - # Representation @property def _dict_repr(self) -> dict[str, dict[str, str]]: """A simplified dict representation.""" @@ -238,24 +253,15 @@ def _dict_repr(self) -> dict[str, dict[str, str]]: } } - def __repr__(self) -> str: - """String representation of the layer.""" - return yaml_dump(self._dict_repr) - - def as_dict(self, skip: Optional[list[str]] = None) -> dict: - """Produces a cleaned dict using a custom as_dict method to skip necessary things. - - The resulting dict matches the parameters in __init__ + # ----- serialization (custom because resolution_function + interface need special handling) ----- - Parameters - ---------- - skip : Optional[list[str]], optional - List of keys to skip. By default, None. - """ + def to_dict(self, skip: Optional[list[str]] = None) -> dict: + """Serialize the model, encoding the resolution function and interface name.""" if skip is None: skip = [] - skip.extend(['sample', 'resolution_function', 'interface']) - this_dict = super().as_dict(skip=skip) + # Sample/resolution_function/interface get bespoke encoding below. + skip_for_super = list(skip) + ['sample', 'resolution_function', 'interface'] + this_dict = super().to_dict(skip=skip_for_super) this_dict['sample'] = self.sample.as_dict(skip=skip) this_dict['resolution_function'] = self.resolution_function.as_dict(skip=skip) if self.interface is None: @@ -264,23 +270,23 @@ def as_dict(self, skip: Optional[list[str]] = None) -> dict: this_dict['interface'] = self.interface().name return this_dict + def as_dict(self, skip: Optional[list[str]] = None) -> dict: + """Compatibility alias for :meth:`to_dict`.""" + return self.to_dict(skip=skip) + def as_orso(self) -> dict: """Convert the model to a dictionary suitable for ORSO.""" - this_dict = self.as_dict() - - return this_dict + return self.as_dict() @classmethod def from_dict(cls, passed_dict: dict) -> Model: """Create a Model from a dictionary.""" - # Causes circular import if imported at the top + # Circular import if hoisted to module-top. from easyreflectometry.calculators import CalculatorFactory this_dict = copy.deepcopy(passed_dict) - resolution_function = ResolutionFunction.from_dict(this_dict['resolution_function']) - del this_dict['resolution_function'] - interface_name = this_dict['interface'] - del this_dict['interface'] + resolution_function = ResolutionFunction.from_dict(this_dict.pop('resolution_function')) + interface_name = this_dict.pop('interface') if interface_name is not None: interface = CalculatorFactory() interface.switch(interface_name) diff --git a/src/easyreflectometry/model/model_collection.py b/src/easyreflectometry/model/model_collection.py index 1a1b7e88..817fc5f0 100644 --- a/src/easyreflectometry/model/model_collection.py +++ b/src/easyreflectometry/model/model_collection.py @@ -3,7 +3,6 @@ from __future__ import annotations -from typing import List from typing import Optional from typing import Tuple @@ -36,20 +35,33 @@ def __init__( models = DEFAULT_ELEMENTS(interface) else: models = [] - # Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict - # Else collisions might occur in global_object.map - self.populate_if_none = False + + # `_next_color_index` must exist before super().__init__ because each + # `append` during construction routes through `_append_internal` → + # `_advance_color_index`, which reads the attribute. self._next_color_index = next_color_index - super().__init__(name, interface, *models, unique_name=unique_name, **kwargs) + super().__init__( + name, + interface, + *models, + unique_name=unique_name, + populate_if_none=False, + **kwargs, + ) color_count = len(COLORS) if color_count == 0: self._next_color_index = 0 - elif self._next_color_index is None: + elif next_color_index is None: self._next_color_index = len(self) % color_count else: - self._next_color_index %= color_count + self._next_color_index = next_color_index % color_count + + @property + def next_color_index(self) -> Optional[int]: + """Index of the next colour to assign — kept around so it round-trips.""" + return self._next_color_index def add_model(self, model: Optional[Model] = None): """Add a model to the collection. @@ -76,28 +88,24 @@ def duplicate_model(self, index: int): duplicate.name = duplicate.name + ' duplicate' self.append(duplicate) - def as_dict(self, skip: List[str] | None = None) -> dict: - """As dict.""" - this_dict = super().as_dict(skip=skip) - this_dict['populate_if_none'] = self.populate_if_none - this_dict['next_color_index'] = self._next_color_index - return this_dict - @classmethod def from_dict(cls, this_dict: dict) -> ModelCollection: """Create an instance of a collection from a dictionary.""" - collection_dict = this_dict.copy() - # We need to call from_dict on the base class to get the models - dict_data = collection_dict.pop('data') + collection_dict = dict(this_dict) + dict_data = collection_dict.pop('data', []) next_color_index = collection_dict.pop('next_color_index', None) - collection = super().from_dict(collection_dict) # type: ModelCollection + # Reconstruct empty collection via EasyList.from_dict (handles + # protected_types and assigns name/unique_name/populate_if_none). + collection = super().from_dict(collection_dict) + # Append each model without advancing the colour index — the saved + # `next_color_index` below is the source of truth. for model_data in dict_data: collection._append_internal(Model.from_dict(model_data), advance=False) - if len(collection) != len(this_dict['data']): - raise ValueError(f'Expected {len(collection)} models, got {len(this_dict["data"])}') + if len(collection) != len(dict_data): + raise ValueError(f'Expected {len(dict_data)} models, got {len(collection)}') color_count = len(COLORS) if color_count == 0: @@ -115,7 +123,14 @@ def append(self, model: Model) -> None: # type: ignore[override] def _append_internal(self, model: Model, advance: bool) -> None: """Append internal.""" - super().append(model) + # Bypass our own `append` override and go straight to EasyList's + # `MutableSequence.append` → `insert` path. Calling `super().append` + # would dispatch back to `ModelCollection.append` because Python + # resolves `append` via MRO from MutableSequence which doesn't + # define it on a class higher than ModelCollection. + from collections.abc import MutableSequence + + MutableSequence.append(self, model) if advance: self._advance_color_index() diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py index 5127ad16..5ae272a6 100644 --- a/src/easyreflectometry/project.py +++ b/src/easyreflectometry/project.py @@ -90,7 +90,7 @@ def parameters(self) -> List[Parameter]: seen_ids: set[int] = set() if self._models is not None: for model in self._models: - for param in model.get_parameters(): + for param in model.get_all_parameters(): pid = id(param) if pid not in seen_ids: seen_ids.add(pid) @@ -847,9 +847,18 @@ def load_from_json(self, path: Optional[Union[Path, str]] = None): else: print(f'ERROR: File {path} does not exist') + #: Schema version embedded in every serialized project. Bumped from 1 → 2 + #: when the sample/model classes migrated from the legacy + #: ``easyscience.ObjBase``/``CollectionBase`` pipeline to + #: ``ModelBase``/``EasyList``. The on-disk shape of nested objects (Layer, + #: Material, MaterialMixture, MaterialSolvated, LayerAreaPerMolecule, etc.) + #: changed in a way that is not backward-compatible with v1 files. + FILE_FORMAT = 2 + def as_dict(self, include_materials_not_in_model=False): """As dict.""" project_dict = {} + project_dict['file_format'] = self.FILE_FORMAT project_dict['info'] = self._info project_dict['with_experiments'] = self._with_experiments if self._models is not None: @@ -898,6 +907,25 @@ def _as_dict_add_experiments(self, project_dict: dict): def from_dict(self, project_dict: dict): """From dict.""" keys = list(project_dict.keys()) + # Validate file format. v1 files were written by the legacy + # `ObjBase`/`CollectionBase` pipeline; their inner shapes (Layer, + # Material, MaterialMixture, …) are not compatible with the v2 + # `ModelBase`/`EasyList` deserializer. Older files must be re-created. + file_format = project_dict.get('file_format') + if file_format is None: + raise ValueError( + 'This project file predates file_format=2 and cannot be loaded by ' + 'this version of easyreflectometry. The serialization format changed ' + 'when the sample/model classes migrated from the legacy ObjBase / ' + 'CollectionBase pipeline. Please re-create the project from its ' + 'underlying data using the current API.' + ) + if file_format != self.FILE_FORMAT: + raise ValueError( + f'Unsupported project file_format={file_format!r}; this version of ' + f'easyreflectometry only reads file_format={self.FILE_FORMAT}. Please ' + 'either update easyreflectometry or re-create the project.' + ) self._info = project_dict['info'] self._with_experiments = project_dict['with_experiments'] if 'calculator' in keys: diff --git a/src/easyreflectometry/sample/assemblies/base_assembly.py b/src/easyreflectometry/sample/assemblies/base_assembly.py index 39cb2b18..fcb2f06d 100644 --- a/src/easyreflectometry/sample/assemblies/base_assembly.py +++ b/src/easyreflectometry/sample/assemblies/base_assembly.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -from typing import Any from typing import Optional from ..base_core import BaseCore @@ -17,28 +16,33 @@ class BaseAssembly(BaseCore): its index number depends on the number of finite layers in the system, but it might be accessed at index -1. """ - # Added in super().__init__ - #: Name of the assembly. - name: str - #: Layers in the assembly. - layers: LayerCollection - #: Interface to the calculator. - interface: Any - def __init__( self, name: str, type: str, interface, - **layers: LayerCollection, + layers: LayerCollection, + unique_name: Optional[str] = None, ): - super().__init__(name=name, interface=interface, **layers) + super().__init__(name=name, unique_name=unique_name) + self._layers = layers # Type is needed when fitting in easyscience self._type = type self._roughness_constraints_setup = False self._thickness_constraints_setup = False + if interface is not None: + self.interface = interface + + @property + def layers(self) -> LayerCollection: + return self._layers + + @layers.setter + def layers(self, value: LayerCollection) -> None: + self._layers = value + @property def type(self) -> str: """Get type of the assembly. diff --git a/src/easyreflectometry/sample/assemblies/bilayer.py b/src/easyreflectometry/sample/assemblies/bilayer.py index 5409d359..cf6edc75 100644 --- a/src/easyreflectometry/sample/assemblies/bilayer.py +++ b/src/easyreflectometry/sample/assemblies/bilayer.py @@ -143,7 +143,6 @@ def __init__( interface=interface, ) - self.interface = interface self._conformal_roughness = False self._constrain_heads = False self._tail_constraints_setup = False @@ -271,8 +270,8 @@ def _create_back_tail_layer( molecular_formula=front_tail_layer.molecular_formula, thickness=front_tail_layer.thickness.value, solvent=solvent, - solvent_fraction=front_tail_layer.solvent_fraction, - area_per_molecule=front_tail_layer.area_per_molecule, + solvent_fraction=front_tail_layer.solvent_fraction.value, + area_per_molecule=front_tail_layer.area_per_molecule.value, roughness=front_tail_layer.roughness.value, name=front_tail_layer.name + ' Back', unique_name=unique_name + '_LayerAreaPerMoleculeBackTail', @@ -534,21 +533,17 @@ def _dict_repr(self) -> dict: } } - def as_dict(self, skip: list[str] | None = None) -> dict: - """Produce a cleaned dict using a custom as_dict method. - - The resulting dict matches the parameters in __init__ + def to_dict(self, skip: list[str] | None = None) -> dict: + """Serialize, dropping derived fields. - Parameters - ---------- - skip : list[str] | None, optional - List of keys to skip. By default, None. + The `back_tail_layer` and the underlying `layers` collection are + derived in ``__init__`` from the front head / front tail / back head + constructor arguments, so they are not part of the persisted state. """ - this_dict = super().as_dict(skip=skip) - this_dict['front_head_layer'] = self.front_head_layer.as_dict(skip=skip) - this_dict['front_tail_layer'] = self.front_tail_layer.as_dict(skip=skip) - this_dict['back_head_layer'] = self.back_head_layer.as_dict(skip=skip) - this_dict['constrain_heads'] = self.constrain_heads - this_dict['conformal_roughness'] = self.conformal_roughness - del this_dict['layers'] + this_dict = super().to_dict(skip=skip) + this_dict.pop('layers', None) return this_dict + + def as_dict(self, skip: list[str] | None = None) -> dict: + """Compatibility alias for :meth:`to_dict`.""" + return self.to_dict(skip=skip) diff --git a/src/easyreflectometry/sample/assemblies/gradient_layer.py b/src/easyreflectometry/sample/assemblies/gradient_layer.py index 1e08bb2e..fa38879c 100644 --- a/src/easyreflectometry/sample/assemblies/gradient_layer.py +++ b/src/easyreflectometry/sample/assemblies/gradient_layer.py @@ -53,15 +53,12 @@ def __init__( if front_material is None: front_material = Material(0.0, 0.0, 'Air') - self._front_material = front_material if back_material is None: back_material = Material(6.36, 0.0, 'D2O') - self._back_material = back_material if discretisation_elements < 2: raise ValueError('Discretisation elements must be greater than 2.') - self._discretisation_elements = discretisation_elements gradient_layers = _prepare_gradient_layers( front_material=front_material, @@ -74,9 +71,12 @@ def __init__( layers=gradient_layers, name=name, unique_name=unique_name, - interface=interface, + interface=None, type='Gradient-layer', ) + self._front_material = front_material + self._back_material = back_material + self._discretisation_elements = discretisation_elements self._setup_thickness_constraints() self._enable_thickness_constraints() @@ -87,6 +87,21 @@ def __init__( self.thickness = thickness self.roughness = roughness + if interface is not None: + self.interface = interface + + @property + def front_material(self) -> Material: + return self._front_material + + @property + def back_material(self) -> Material: + return self._back_material + + @property + def discretisation_elements(self) -> int: + return self._discretisation_elements + @property def thickness(self) -> float: """Get the thickness of the gradient layer in Angstrom.""" @@ -129,21 +144,31 @@ def _dict_repr(self) -> dict[str, str]: 'front_layer': self.front_layer._dict_repr, } - def as_dict(self, skip: Optional[list[str]] = None) -> dict: - """Produces a cleaned dict using a custom as_dict method to skip necessary things. + def to_dict(self, skip: Optional[list[str]] = None) -> dict: + """Produces a cleaned dict using a custom to_dict method to skip necessary things. - The resulting dict matches the parameters in __init__ + The resulting dict matches the parameters in __init__: layers are derived + in ``__init__`` from ``front_material``/``back_material``/``discretisation_elements`` + so they are excluded from the serialized representation. Parameters ---------- skip : Optional[list[str]], optional List of keys to skip. By default, None. """ - this_dict = super().as_dict(skip=skip) + this_dict = super().to_dict(skip=skip) # Determined in __init__ - del this_dict['layers'] + this_dict.pop('layers', None) + # `thickness` / `roughness` are read-only float views; the serialized + # constructor args are the floats themselves. + this_dict['thickness'] = float(self.thickness) + this_dict['roughness'] = float(self.roughness) return this_dict + def as_dict(self, skip: Optional[list[str]] = None) -> dict: + """Compatibility alias for :meth:`to_dict`.""" + return self.to_dict(skip=skip) + def _linear_gradient( front_value: float, diff --git a/src/easyreflectometry/sample/assemblies/multilayer.py b/src/easyreflectometry/sample/assemblies/multilayer.py index 9144b3a8..db02592d 100644 --- a/src/easyreflectometry/sample/assemblies/multilayer.py +++ b/src/easyreflectometry/sample/assemblies/multilayer.py @@ -62,7 +62,13 @@ def __init__( # Else collisions might occur in global_object.map self.populate_if_none = False - super().__init__(name, unique_name=unique_name, layers=layers, type=type, interface=interface) + super().__init__( + name=name, + type=type, + interface=interface, + layers=layers, + unique_name=unique_name, + ) def add_layer(self, *layers: tuple[Layer]) -> None: """Add a layer to the multi layer. diff --git a/src/easyreflectometry/sample/assemblies/repeating_multilayer.py b/src/easyreflectometry/sample/assemblies/repeating_multilayer.py index cb396836..79eab4fa 100644 --- a/src/easyreflectometry/sample/assemblies/repeating_multilayer.py +++ b/src/easyreflectometry/sample/assemblies/repeating_multilayer.py @@ -75,9 +75,6 @@ def __init__( layers = LayerCollection(layers, name=layers.name) elif isinstance(layers, list): layers = LayerCollection(*layers, name='/'.join([layer.name for layer in layers])) - # Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict - # Else collisions might occur in global_object.map - self.populate_if_none = False repetitions = get_as_parameter( name='repetitions', @@ -89,11 +86,23 @@ def __init__( super().__init__( layers=layers, name=name, - interface=interface, + unique_name=unique_name, + interface=None, type='Repeating Multi-layer', + populate_if_none=False, ) - self._add_component('repetitions', repetitions) - self.interface = interface + self._repetitions = repetitions + + if interface is not None: + self.interface = interface + + @property + def repetitions(self) -> Parameter: + return self._repetitions + + @repetitions.setter + def repetitions(self, value) -> None: + self._repetitions.value = value # Representation @property diff --git a/src/easyreflectometry/sample/assemblies/surfactant_layer.py b/src/easyreflectometry/sample/assemblies/surfactant_layer.py index 81c9c36c..8e81e097 100644 --- a/src/easyreflectometry/sample/assemblies/surfactant_layer.py +++ b/src/easyreflectometry/sample/assemblies/surfactant_layer.py @@ -57,7 +57,6 @@ def __init__( interface : Calculator interface. By default, None. """ - # We need to generate a unique name to create the nested objects if unique_name is None: unique_name = global_object.generate_unique_name(self.__class__.__name__) @@ -114,11 +113,13 @@ def __init__( interface=interface, ) - self.interface = interface self.conformal = False + if constrain_area_per_molecule: + self.constrain_area_per_molecule = True if conformal_roughness: self._enable_roughness_constraints() + self.conformal = True @property def tail_layer(self) -> Optional[LayerAreaPerMolecule]: @@ -276,20 +277,13 @@ def _dict_repr(self) -> dict: } } - def as_dict(self, skip: Optional[list[str]] = None) -> dict: - """Produces a cleaned dict using a custom as_dict method to skip necessary things. - - The resulting dict matches the parameters in __init__ - - Parameters - ---------- - skip : Optional[list[str]], optional - List of keys to skip. By default, None. + def to_dict(self, skip: Optional[list[str]] = None) -> dict: + """Serialize, dropping the derived ``layers`` field (it is rebuilt + from ``tail_layer`` and ``head_layer`` in ``__init__``). """ - this_dict = super().as_dict(skip=skip) - this_dict['tail_layer'] = self.tail_layer.as_dict(skip=skip) - this_dict['head_layer'] = self.head_layer.as_dict(skip=skip) - this_dict['constrain_area_per_molecule'] = self.constrain_area_per_molecule - this_dict['conformal_roughness'] = self.conformal_roughness - del this_dict['layers'] + this_dict = super().to_dict(skip=skip) + this_dict.pop('layers', None) return this_dict + + def as_dict(self, skip: Optional[list[str]] = None) -> dict: + return self.to_dict(skip=skip) diff --git a/src/easyreflectometry/sample/base_core.py b/src/easyreflectometry/sample/base_core.py index 4cf600a6..611665cd 100644 --- a/src/easyreflectometry/sample/base_core.py +++ b/src/easyreflectometry/sample/base_core.py @@ -1,49 +1,240 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from abc import abstractmethod +from typing import Any +from typing import Optional -from easyscience import ObjBase as BaseObj +from easyscience.base_classes import ModelBase from easyreflectometry.utils import yaml_dump -class BaseCore(BaseObj): +class BaseCore(ModelBase): + """Local base class for sample-tree objects (Material, Layer, assemblies). + + Built on top of `easyscience.base_classes.ModelBase` (the replacement for + the deprecated `ObjBase`). On top of `ModelBase` this class adds: + + - a `name` property + - an `interface` property whose setter propagates the calculator interface + to child objects and (re)generates bindings + - a yaml-formatted `__repr__` driven by an abstract `_dict_repr` + - an `_get_linkable_attributes` compatibility shim used by the calculator's + `InterfaceFactoryTemplate.generate_bindings` + - an `as_dict` alias for `to_dict` + + Subclass `__init__` convention: + 1. Build child Parameters / sub-objects. + 2. Call ``super().__init__(name=..., unique_name=...)``. + 3. Assign children to backing fields (``self._sld = sld`` etc.) or pass + them as ``**kwargs`` to this base class (transitional path; each + kwarg is stored as a plain instance attribute). + 4. Last: ``self.interface = interface`` (triggers ``generate_bindings``). + """ + def __init__( self, name: str, - interface, - **kwargs, + interface: Any = None, + unique_name: Optional[str] = None, + display_name: Optional[str] = None, + **kwargs: Any, ): - """Init function.""" - super().__init__(name=name, **kwargs) + super().__init__(unique_name=unique_name, display_name=display_name) + self._name = name + self._interface = None + # `user_data` is part of the legacy `BasedBase` API — a free-form dict + # callers stash arbitrary metadata in. Kept for back-compat with code + # like `Project.replace_models_from_orso` which stores the ORSO sample + # name on the model. + self.user_data: dict = {} - # Updates interface using property in base object - self.interface = interface + # Transitional path: subclasses still pass parameter / child objects via + # **kwargs (legacy `ObjBase` accepted them and stashed in `_kwargs`). + # Here we simply store each one as a plain instance attribute so + # `obj.` keeps working. Step 2 of the migration replaces this with + # explicit assignments in each subclass. + for key, value in kwargs.items(): + setattr(self, key, value) - @abstractmethod - def _dict_repr(self) -> dict[str, str]: ... + # Assign interface LAST so children exist when generate_bindings runs. + if interface is not None: + self.interface = interface - def __repr__(self) -> str: - """String representation of the layer. + # ----- name ----- + + @property + def name(self) -> str: + """Common (display-friendly) name.""" + return self._name + + @name.setter + def name(self, value: str) -> None: + self._name = value + + # ----- interface ----- + + @property + def interface(self) -> Any: + """The calculator interface attached to this object (may be None).""" + return self._interface - Returns - ------- - str - A string representation of the layer. + @interface.setter + def interface(self, new_interface: Any) -> None: + self._interface = new_interface + if new_interface is not None: + self.generate_bindings() + + def generate_bindings(self) -> None: + """Propagate the interface to child objects, then bind via the calculator. + + We propagate to any child whose class advertises an ``interface`` property + with a setter. That includes both the new `BaseCore`-based children and + legacy `BasedBase`-derived collections (which extend `SerializerComponent`, + not `NewBase`). """ - return yaml_dump(self._dict_repr) + if self._interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + for attr in self._iter_public_children(): + if self._has_interface_setter(type(attr)): + attr.interface = self._interface + self._interface.generate_bindings(self) + + def _iter_public_children(self): + """Yield public child objects from both class-level (dir) and instance-level (__dict__) attrs. + + `NewBase.__dir__` exposes only class attributes, which means plain instance + attributes (the transitional `setattr(self, key, value)` path in + `__init__`) are invisible to a pure `dir()` scan. To bridge both worlds — + legacy subclasses that still use plain attrs, and migrated subclasses that + expose children via `@property` accessors — this helper unions the two. + Once all subclasses migrate to `@property`-backed children with `_field` + backing storage, the `__dict__` branch becomes a no-op (private names are + skipped). + """ + seen_ids = {id(self)} + # Class-level (properties, methods named like sld/isld/material). + for attr_name in dir(self): + if attr_name.startswith('_') or attr_name in ('interface', 'name'): + continue + try: + attr = getattr(self, attr_name, None) + except AttributeError: + # A subclass @property may legitimately raise AttributeError + # mid-construction (the `_field` backing isn't set yet); skip + # those entries silently. Other exceptions should propagate. + continue + if attr is None or id(attr) in seen_ids: + continue + seen_ids.add(id(attr)) + yield attr + # Instance-level (plain attrs set via the transitional kwargs path). + for attr_name, attr in list(self.__dict__.items()): + if attr_name.startswith('_') or attr_name in ('interface', 'name'): + continue + if attr is None or id(attr) in seen_ids: + continue + seen_ids.add(id(attr)) + yield attr + + @staticmethod + def _has_interface_setter(obj_type: type) -> bool: + for klass in obj_type.__mro__: + prop = klass.__dict__.get('interface') + if isinstance(prop, property): + return prop.fset is not None + return False + + # ----- compatibility shims ----- - # For classes with special serialization needs one must adopt the dict produced by super - # def as_dict(self, skip: list = None) -> dict: - # """Should produce a cleaned dict that matches the parameters in __init__ - # - # :param skip: List of keys to skip, defaults to `None`. - # """ - # if skip is None: - # skip = [] - # this_dict = super().as_dict(skip=skip) - # ... - # Correct the dict here - # ... - # return this_dict + def _get_linkable_attributes(self): + """Used by `easyscience.fitting.calculators.interface_factory.generate_bindings`. + + Returns the same set as :meth:`get_all_variables` (the modern API on + :class:`ModelBase`). Kept under the legacy name because the calculator + in `easyscience` core has not yet been updated. + """ + return self.get_all_variables() + + def get_parameters(self): + """Compatibility shim for legacy callers; prefer :meth:`get_all_parameters`.""" + return self.get_all_parameters() + + def _add_component(self, key: str, component: Any) -> None: + """Compatibility shim for legacy `ObjBase._add_component`. + + Legacy callers (e.g. `LayerAreaPerMolecule`) used this to register an + additional child after `super().__init__`. In the new world the + equivalent is simply setting an attribute; we do that here so the + existing call sites keep working until Step 2 removes them. + """ + setattr(self, key, component) + + def get_all_variables(self): + """Discover Parameters/Descriptors across both class-level and instance-level attrs. + + `ModelBase.get_all_variables` walks `dir(self)`, which `NewBase` restricts + to class attributes only. During the transition some subclasses still + store child Parameters as plain instance attributes (see the kwargs path + in :meth:`__init__`); those are invisible to `dir()`. We therefore also + scan `self.__dict__` for `DescriptorBase` instances and for child + ModelBase objects whose own `get_all_variables` we recurse into. + """ + from easyscience.variable.descriptor_base import DescriptorBase + + out: list = [] + seen_param_ids: set[int] = set() + for attr in self._iter_public_children(): + if isinstance(attr, DescriptorBase): + if id(attr) not in seen_param_ids: + seen_param_ids.add(id(attr)) + out.append(attr) + elif hasattr(attr, 'get_all_variables'): + for v in attr.get_all_variables(): + if id(v) not in seen_param_ids: + seen_param_ids.add(id(v)) + out.append(v) + return out + + def to_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: + """Serialize, skipping the calculator interface and unique_name by default. + + The calculator (`CalculatorFactory`) is not serializable and is not part + of the model's persistent state — round-trip code that needs it back + reattaches it after `from_dict`. The legacy `ObjBase`-based pipeline + achieved the same by never including `interface` in its `_kwargs` + encoding; we replicate that here. + + `unique_name` is also stripped by default, matching the legacy + `BasedBase.as_dict` contract. The installed `SerializerBase` does *not* + propagate per-object `_default_unique_name` to nested children — if + we leave it in, child Parameters end up with explicit unique_names in + the dict (e.g. `Parameter_0`) that subsequently collide on reload when + the global counter restarts from 0. + + Pass a *copy* of `skip` to super since `NewBase.to_dict` mutates the + list (appends `unique_name` / `display_name` if those are default). + """ + skip = list(skip or []) + if 'interface' not in skip: + skip.append('interface') + if 'unique_name' not in skip: + skip.append('unique_name') + return super().to_dict(skip=list(skip)) + + def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: + """Compatibility alias for :meth:`to_dict`.""" + return self.to_dict(skip=skip) + + # ----- repr ----- + + @property + @abstractmethod + def _dict_repr(self) -> dict[str, Any]: ... + + def __repr__(self) -> str: + """Yaml-formatted multi-line string built from :attr:`_dict_repr`.""" + return yaml_dump(self._dict_repr) diff --git a/src/easyreflectometry/sample/collections/base_collection.py b/src/easyreflectometry/sample/collections/base_collection.py index 28d494a5..59f7038e 100644 --- a/src/easyreflectometry/sample/collections/base_collection.py +++ b/src/easyreflectometry/sample/collections/base_collection.py @@ -1,118 +1,331 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +from typing import Any from typing import List from typing import Optional -from easyscience import global_object -from easyscience.base_classes import CollectionBase as EasyBaseCollection +from easyscience.base_classes import EasyList +from easyscience.base_classes.new_base import NewBase +from easyscience.variable import Parameter from easyreflectometry.utils import yaml_dump -class BaseCollection(EasyBaseCollection): +class BaseCollection(EasyList): + """Local base for sample-tree collections (Material/Layer/Assembly/Model collections). + + Built on top of `easyscience.base_classes.EasyList` (the replacement for + the deprecated `CollectionBase`). On top of `EasyList` this class adds: + + - a `name` property + - an `interface` property whose setter propagates the calculator interface + to every contained item + - propagation of the current `interface` to newly inserted items + - a `populate_if_none` flag preserved for serialization round-trip + - convenience helpers `names`, `move_up`, `move_down`, `remove_at` + - a yaml-formatted `__repr__` driven by `_dict_repr` + - an `as_dict` alias for `to_dict` with `skip=` support and `interface` + excluded by default + + Subclasses (`LayerCollection`, `MaterialCollection`, `Sample`, + `ModelCollection`) keep their existing constructor shape — they pass + `name` and `interface` positionally to this class, items as `*args`, and + additional configuration as kwargs. + """ + def __init__( self, name: str, - interface, - *args, + interface: Any = None, + *args: Any, unique_name: Optional[str] = None, - **kwargs, + populate_if_none: bool = False, + **kwargs: Any, ): - """Init function.""" - if unique_name is None: - unique_name = global_object.generate_unique_name(self.__class__.__name__) + # `_interface` must exist before `super().__init__` because `EasyList` + # calls `self.append(item)` for each positional arg, which routes + # through our `insert` override and reads `self._interface`. + self._interface = None + self._name = name + # Legacy `CollectionBase` accepted items either positionally or as a + # list-valued keyword (e.g. ``LayerCollection(layers=[a, b])``). Pull + # any list-valued kwarg into the positional stream so callers using + # that older pattern keep working. + extra_items = [] + for key in list(kwargs.keys()): + if isinstance(kwargs[key], list) and kwargs[key] and key != 'data': + extra_items.extend(kwargs.pop(key)) + if extra_items: + args = tuple(args) + tuple(extra_items) + super().__init__(*args, unique_name=unique_name, **kwargs) + # `populate_if_none` is a control flag, not state. It is serialized so + # `from_dict` knows whether the original construction filled in + # defaults; the value should be `False` once the data is restored. + self.populate_if_none = populate_if_none - super().__init__(name, unique_name=unique_name, *args, **kwargs) - self.interface = interface + # Assign interface LAST — propagates to all contained items. + if interface is not None: + self.interface = interface - # Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict - # Else collisions might occur in global_object.map - self.populate_if_none = False + # ----- name ----- - def __repr__(self) -> str: - """String representation of the collection. + @property + def name(self) -> str: + return self._name - Returns - ------- - str - A string representation of the collection. - """ - return yaml_dump(self._dict_repr) + @name.setter + def name(self, value: str) -> None: + self._name = value + + # ----- interface ----- @property - def names(self) -> list: - """Names function. + def interface(self) -> Any: + return self._interface + + @interface.setter + def interface(self, new_interface: Any) -> None: + self._interface = new_interface + if new_interface is None: + return + # Propagate to existing items. + for item in self._data: + if self._has_interface_setter(type(item)): + item.interface = new_interface + # Tell the calculator to bind to self (matches the legacy CollectionBase + # behavior which called `interface.generate_bindings(self)` once per + # collection). + if hasattr(new_interface, 'generate_bindings'): + new_interface.generate_bindings(self) + + def _get_key(self, obj): + """Use the item's `name` for string-indexed lookups. - Returns - ------- - s : list - List of names for the elements in the collection. + Matches the legacy `CollectionBase.__getitem__` behaviour which + searched by `item.name`. `EasyList` defaults to `unique_name`; the + existing `Project` code (and callers) look items up by their pretty + name (e.g. `materials['Air']`). """ - return [i.name for i in self] + return getattr(obj, 'name', None) or obj.unique_name - def move_up(self, index: int): - """Move the element at the given index up in the collection. + @staticmethod + def _has_interface_setter(obj_type: type) -> bool: + for klass in obj_type.__mro__: + prop = klass.__dict__.get('interface') + if isinstance(prop, property): + return prop.fset is not None + return False - Parameters - ---------- - index : int - Index of the element to move up. + # ----- mutable-sequence overrides that propagate interface ----- + + def insert(self, index: int, value: Any) -> None: + """Insert and (if an interface is set) propagate it to the new item. + + Legacy `CollectionBase.insert` did `value.interface = self.interface` + after registering the item; we replicate the same behaviour here so + downstream calculator state stays in sync when items are appended + after the collection's interface was already set. + + The type check from `EasyList.insert` is bypassed: each subclass + accepts a single item type (Layer / Material / BaseAssembly / Model) + and enforces it elsewhere, while the EasyList check would require + every item to be a `NewBase` subclass — which was historically not + guaranteed and forces an extra coupling we don't need. """ - if index == 0: + if not isinstance(index, int): + raise TypeError('Index must be an integer') + # Skip the EasyList protected-types check; mimic the rest of its insert + # (duplicate-name warning + append-to-_data). + import warnings as _warnings + + if value in self: + _warnings.warn(f'Item with unique name "{self._get_key(value)}" already in collection, it will be ignored') return - self.insert(index - 1, self.pop(index)) + self._data.insert(index, value) + if self._interface is not None and self._has_interface_setter(type(value)): + value.interface = self._interface - def move_down(self, index: int): - """Move the element at the given index down in the collection. + # ----- helpers ----- - Parameters - ---------- - index : int - Index of the element to move down. + @property + def names(self) -> list: + """List of item names.""" + return [getattr(item, 'name', None) for item in self._data] + + @property + def data(self) -> list: + """Read-only view of the underlying item list. + + Provided for compatibility with code (and tests) written against the + legacy `CollectionBase` shape, which exposed items via `.data`. """ + return list(self._data) + + def move_up(self, index: int) -> None: + """Move the element at the given index up in the collection.""" + if index == 0: + return + self.insert(index - 1, self.pop(index)) + + def move_down(self, index: int) -> None: + """Move the element at the given index down in the collection.""" if index == len(self) - 1: return self.insert(index + 1, self.pop(index)) - def remove(self, index: int): - """Remove an element from the elements. + def remove_at(self, index: int) -> None: + """Remove the item at *index* from the collection. - Parameters - ---------- - index : int - Index of the element to remove. + Renamed from the legacy `BaseCollection.remove(index)` which shadowed + `MutableSequence.remove(value)` (remove-by-value, inherited from + `EasyList`). Callers that meant "remove by index" should use this; the + standard `remove(value)` is still available with its usual semantics. """ self.pop(index) + # ----- compatibility shims (kept until call sites migrate) ----- + + def get_parameters(self) -> List[Parameter]: + """Compatibility alias for legacy callers; prefer `get_all_parameters`.""" + return self.get_all_parameters() + + def get_all_variables(self) -> List: + """Flat list of every Parameter/Descriptor across all items. + + Walks each item in the collection and unions whatever each item + exposes via its own `get_all_variables` (for `ModelBase` / + `BaseCore` children) or, for objects that lack that hook, + unions any direct `DescriptorBase` attributes. + """ + from easyscience.variable.descriptor_base import DescriptorBase + + out: list = [] + seen: set[int] = set() + for item in self._data: + if hasattr(item, 'get_all_variables'): + for v in item.get_all_variables(): + if id(v) not in seen: + seen.add(id(v)) + out.append(v) + elif isinstance(item, DescriptorBase): + if id(item) not in seen: + seen.add(id(item)) + out.append(item) + return out + + def get_all_parameters(self) -> List[Parameter]: + return [v for v in self.get_all_variables() if isinstance(v, Parameter)] + + def get_free_parameters(self) -> List[Parameter]: + return [p for p in self.get_all_parameters() if p.independent and not p.fixed] + + def get_fit_parameters(self) -> List[Parameter]: + """Alias kept for the minimizer; matches `ModelBase.get_fit_parameters`.""" + return self.get_free_parameters() + + def _get_linkable_attributes(self) -> List[Parameter]: + """Bridge for `easyscience.fitting.calculators.interface_factory.generate_bindings`.""" + return self.get_all_variables() + + # ----- repr ----- + @property def _dict_repr(self) -> dict: - """A simplified dict representation. + """A simplified dict representation.""" + return {self.name: [getattr(i, '_dict_repr', repr(i)) for i in self._data]} - Returns - ------- - dict - Simple dictionary. - """ - return {self.name: [i._dict_repr for i in self]} + def __repr__(self) -> str: + try: + return yaml_dump(self._dict_repr) + except Exception: + return super().__repr__() - def as_dict(self, skip: Optional[List[str]] = None) -> dict: - """Create a dictionary representation of the collection. + # ----- serialization ----- - Returns - ------- - dict - A dictionary representation of the collection. + def _convert_to_dict(self, d: dict, encoder=None, skip: Optional[List[str]] = None, **kwargs) -> dict: + """Serializer hook used when this collection is encoded as a *child* + attribute (e.g. `Multilayer.layers`). + + `SerializerBase._convert_to_dict` iterates `_arg_spec` to populate the + base dict and then calls `obj._convert_to_dict(d, ...)` if defined. + Because `data` is supplied via `*args` (VAR_POSITIONAL — not part of + `_arg_spec`), without this hook the nested encoding would miss the + items entirely and round-trip would reconstruct an empty collection. """ if skip is None: skip = [] - this_dict = super().as_dict(skip=skip) - this_dict['data'] = [] - for collection_element in self: - this_dict['data'].append(collection_element.as_dict(skip=skip)) - this_dict['populate_if_none'] = self.populate_if_none - return this_dict + if self._protected_types != [NewBase] and 'protected_types' not in d: + d['protected_types'] = [{'@module': c.__module__, '@class': c.__name__} for c in self._protected_types] + # Encode each item. Defer to the encoder's recursive walk so nested + # ModelBase / NewBase items get properly serialized. + item_skip = list(skip) + items: list = [] + for item in self._data: + if encoder is not None and hasattr(encoder, '_recursive_encoder'): + items.append(encoder._recursive_encoder(item, skip=item_skip, encoder=encoder, full_encode=False)) + elif hasattr(item, 'to_dict'): + try: + items.append(item.to_dict(skip=list(item_skip))) + except TypeError: + items.append(item.to_dict()) + else: + items.append(item) + d['data'] = items + return d + + def to_dict(self, skip: Optional[List[str]] = None) -> dict: + """Serialize with `skip` support; `interface` excluded by default. + + `EasyList.to_dict` doesn't accept a `skip` argument and is hard-wired + to dump `data` plus the parent's `_arg_spec` view. We reimplement here + so existing callers (`Project`, `Model.as_dict`, etc.) can keep + passing `skip=['unique_name']` or similar. + + ``NewBase.to_dict`` mutates the ``skip`` list in-place (e.g. it + appends ``'unique_name'`` when the collection's own unique_name is + default-generated). We therefore pass a *copy* to it, otherwise the + mutation would leak into the per-item serialization below and force + every item Parameter dict to drop its ``unique_name`` — breaking the + from_dict round-trip. + """ + skip = list(skip or []) + if 'interface' not in skip: + skip.append('interface') + # Matches legacy `BasedBase.as_dict`: drop unique_name from the + # serialized form so nested Parameters don't get explicit names that + # would collide with the auto-generated names produced when the global + # counter restarts during reconstruction. + if 'unique_name' not in skip: + skip.append('unique_name') + dict_repr = NewBase.to_dict(self, skip=list(skip)) + if self._protected_types != [NewBase]: + dict_repr['protected_types'] = [{'@module': c.__module__, '@class': c.__name__} for c in self._protected_types] + dict_repr['data'] = [] + for item in self._data: + # Items that are ModelBase / BaseCore subclasses accept `skip`; + # other shapes use their no-arg `to_dict`/`as_dict`. + if hasattr(item, 'to_dict'): + try: + dict_repr['data'].append(item.to_dict(skip=list(skip))) + except TypeError: + dict_repr['data'].append(item.to_dict()) + else: + dict_repr['data'].append(item.as_dict(skip=list(skip))) + return dict_repr + + def as_dict(self, skip: Optional[List[str]] = None) -> dict: + """Compatibility alias for :meth:`to_dict`.""" + return self.to_dict(skip=skip) def __deepcopy__(self, memo): - """Deepcopy function.""" + """Round-trip via dict-skip-unique to get a fresh copy. + + `NewBase.__copy__` already does this; the override is kept (rather + than deleted) to mirror the legacy `BaseCollection.__deepcopy__` + semantics — callers that relied on `copy.deepcopy(collection)` still + get a clone built from `from_dict(as_dict(skip=['unique_name']))`. + """ return self.from_dict(self.as_dict(skip=['unique_name'])) diff --git a/src/easyreflectometry/sample/collections/layer_collection.py b/src/easyreflectometry/sample/collections/layer_collection.py index 99424f72..a3ae0f19 100644 --- a/src/easyreflectometry/sample/collections/layer_collection.py +++ b/src/easyreflectometry/sample/collections/layer_collection.py @@ -15,14 +15,21 @@ def __init__( name: str = 'EasyLayerCollection', interface=None, unique_name: Optional[str] = None, - populate_if_none: bool = True, # Needed to match as_dict signature from BaseCollection + populate_if_none: bool = True, **kwargs, ): """Init function.""" if not layers: layers = [] - super().__init__(name, interface, unique_name=unique_name, *layers, **kwargs) + super().__init__( + name, + interface, + *layers, + unique_name=unique_name, + populate_if_none=populate_if_none, + **kwargs, + ) def add_layer(self, layer: Optional[Layer] = None): """Add a layer to the collection. diff --git a/src/easyreflectometry/sample/collections/material_collection.py b/src/easyreflectometry/sample/collections/material_collection.py index d7b5af3b..726678b2 100644 --- a/src/easyreflectometry/sample/collections/material_collection.py +++ b/src/easyreflectometry/sample/collections/material_collection.py @@ -30,7 +30,7 @@ def __init__( **kwargs, ): """Init function.""" - if not materials: # Empty tuple if no materials are provided + if not materials: if populate_if_none: materials = DEFAULT_ELEMENTS(interface) else: @@ -39,8 +39,9 @@ def __init__( super().__init__( name, interface, - unique_name=unique_name, *materials, + unique_name=unique_name, + populate_if_none=False, **kwargs, ) diff --git a/src/easyreflectometry/sample/collections/sample.py b/src/easyreflectometry/sample/collections/sample.py index 1bbdb156..384626e8 100644 --- a/src/easyreflectometry/sample/collections/sample.py +++ b/src/easyreflectometry/sample/collections/sample.py @@ -52,6 +52,11 @@ def __init__( interface : Calculator interface. By default, None. """ + # `from_dict` (via `EasyList.from_dict`) passes the items as a single + # list-positional arg; unpack that so validation and super() agree. + if len(assemblies) == 1 and isinstance(assemblies[0], list): + assemblies = tuple(assemblies[0]) + if not assemblies: if populate_if_none: assemblies = DEFAULT_ELEMENTS(interface) @@ -61,7 +66,14 @@ def __init__( for assembly in assemblies: if not issubclass(type(assembly), BaseAssembly): raise ValueError('The elements must be an Assembly.') - super().__init__(name, interface, unique_name=unique_name, *assemblies, **kwargs) + super().__init__( + name, + interface, + *assemblies, + unique_name=unique_name, + populate_if_none=populate_if_none, + **kwargs, + ) def add_assembly(self, assembly: Optional[BaseAssembly] = None): """Add an assembly to the sample. @@ -87,13 +99,19 @@ def duplicate_assembly(self, index: int): assembly : Assembly to add. """ + # Order matters: RepeatingMultilayer and SurfactantLayer are subclasses of + # BaseAssembly but not Multilayer; however a RepeatingMultilayer IS a + # Multilayer, so the most-specific check must come first to avoid + # serialising it through the wrong `from_dict`. to_be_duplicated = self[index] - if isinstance(to_be_duplicated, Multilayer): - duplicate = Multilayer.from_dict(to_be_duplicated.as_dict(skip=['unique_name'])) - elif isinstance(to_be_duplicated, RepeatingMultilayer): + if isinstance(to_be_duplicated, RepeatingMultilayer): duplicate = RepeatingMultilayer.from_dict(to_be_duplicated.as_dict(skip=['unique_name'])) elif isinstance(to_be_duplicated, SurfactantLayer): duplicate = SurfactantLayer.from_dict(to_be_duplicated.as_dict(skip=['unique_name'])) + elif isinstance(to_be_duplicated, Multilayer): + duplicate = Multilayer.from_dict(to_be_duplicated.as_dict(skip=['unique_name'])) + else: + raise TypeError(f'Cannot duplicate assembly of type {type(to_be_duplicated).__name__}') duplicate.name = duplicate.name + ' duplicate' self.append(duplicate) @@ -140,18 +158,3 @@ def subphase(self) -> Layer: return self[-1].front_layer else: return self[-1].back_layer - - # Representation - def as_dict(self, skip: Optional[List[str]] = None) -> dict: - """Produces a cleaned dict using a custom as_dict method to skip necessary things. - - The resulting dict matches the parameters in __init__ - - Parameters - ---------- - skip : Optional[List[str]], optional - List of keys to skip. By default, None. - """ - this_dict = super().as_dict(skip=skip) - this_dict['populate_if_none'] = self.populate_if_none - return this_dict diff --git a/src/easyreflectometry/sample/elements/layers/layer.py b/src/easyreflectometry/sample/elements/layers/layer.py index fa58dc15..7eea9872 100644 --- a/src/easyreflectometry/sample/elements/layers/layer.py +++ b/src/easyreflectometry/sample/elements/layers/layer.py @@ -37,14 +37,6 @@ class Layer(BaseCore): - # Added in super().__init__ - #: Material that makes up the layer. - material: Material - #: Thickness of the layer in Angstrom. - thickness: Parameter - #: Roughness of the layer in Angstrom. - roughness: Parameter - def __init__( self, material: Union[Material, None] = None, @@ -95,14 +87,37 @@ def __init__( ) roughness.default_limits_pending = not isinstance(roughness_value, Parameter) - super().__init__( - name=name, - interface=interface, - material=material, - thickness=thickness, - roughness=roughness, - unique_name=unique_name, - ) + super().__init__(name=name, unique_name=unique_name) + self._material = material + self._thickness = thickness + self._roughness = roughness + + if interface is not None: + self.interface = interface + + @property + def material(self) -> Material: + return self._material + + @material.setter + def material(self, value: Material) -> None: + self._material = value + + @property + def thickness(self) -> Parameter: + return self._thickness + + @thickness.setter + def thickness(self, value: float) -> None: + self._thickness.value = value + + @property + def roughness(self) -> Parameter: + return self._roughness + + @roughness.setter + def roughness(self, value: float) -> None: + self._roughness.value = value def assign_material(self, material: Material) -> None: """Assign a material to the layer interface. @@ -112,7 +127,7 @@ def assign_material(self, material: Material) -> None: material : Material The material to assign to the layer. """ - self.material = material + self._material = material if self.interface is not None: self.interface().assign_material_to_layer(self.material.unique_name, self.unique_name) diff --git a/src/easyreflectometry/sample/elements/layers/layer_area_per_molecule.py b/src/easyreflectometry/sample/elements/layers/layer_area_per_molecule.py index 9053aa95..14b07b8c 100644 --- a/src/easyreflectometry/sample/elements/layers/layer_area_per_molecule.py +++ b/src/easyreflectometry/sample/elements/layers/layer_area_per_molecule.py @@ -55,17 +55,6 @@ class LayerAreaPerMolecule(Layer): molecular formula an area per molecule, and a solvent. """ - # Added in __init__ - #: Real part of the scattering length. - _scattering_length_real: Parameter - #: Imaginary part of the scattering length. - _scattering_length_imag: Parameter - #: Area per molecule in the layer in Anstrom^2. - _area_per_molecule: Parameter - - # Other typer than in __init__.super() - material: MaterialSolvated - def __init__( self, molecular_formula: Union[str, None] = None, @@ -113,7 +102,6 @@ def __init__( interface=interface, ) - # Create the solvated molecule and corresponding constraints if molecular_formula is None: molecular_formula = DEFAULTS['molecular_formula'] molecule_material = Material( @@ -130,39 +118,40 @@ def __init__( default_dict=DEFAULTS, unique_name_prefix=f'{unique_name}_Thickness', ) - _area_per_molecule = get_as_parameter( + area_per_molecule_param = get_as_parameter( name='area_per_molecule', value=area_per_molecule, default_dict=DEFAULTS, unique_name_prefix=f'{unique_name}_AreaPerMolecule', ) - _scattering_length_real = get_as_parameter( + scattering_length_real = get_as_parameter( name='scattering_length_real', value=0.0, default_dict=DEFAULTS['sl'], unique_name_prefix=f'{unique_name}_Sl', ) - _scattering_length_imag = get_as_parameter( + scattering_length_imag = get_as_parameter( name='scattering_length_imag', value=0.0, default_dict=DEFAULTS['isl'], unique_name_prefix=f'{unique_name}_Isl', ) - # Constrain the real part of the sld value for the molecule + + # Constrain molecule.sld via scattering length / (thickness * area_per_molecule) dependency_expression = 'scattering_length / (thickness * area_per_molecule) * 1e6' dependency_map = { - 'scattering_length': _scattering_length_real, + 'scattering_length': scattering_length_real, 'thickness': thickness, - 'area_per_molecule': _area_per_molecule, + 'area_per_molecule': area_per_molecule_param, } molecule_material.sld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) - # # Constrain the real part of the sld value for the molecule + # Same dependency under short variable names dependency_expression = 'a / (b*p) * 1e6' - dependency_map = {'a': _scattering_length_real, 'b': thickness, 'p': _area_per_molecule} + dependency_map = {'a': scattering_length_real, 'b': thickness, 'p': area_per_molecule_param} molecule_material.sld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) - dependency_map = {'a': _scattering_length_imag, 'b': thickness, 'p': _area_per_molecule} + dependency_map = {'a': scattering_length_imag, 'b': thickness, 'p': area_per_molecule_param} molecule_material.isld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) solvated_molecule_material = MaterialSolvated( @@ -178,17 +167,85 @@ def __init__( roughness=roughness, name=name, unique_name=unique_name, - interface=interface, + interface=None, ) - self._add_component('_scattering_length_real', _scattering_length_real) - self._add_component('_scattering_length_imag', _scattering_length_imag) - self._add_component('_area_per_molecule', _area_per_molecule) + self._area_per_molecule = area_per_molecule_param + self._scattering_length_real = scattering_length_real + self._scattering_length_imag = scattering_length_imag scattering_length = neutron_scattering_length(molecular_formula) self._scattering_length_real.value = scattering_length.real self._scattering_length_imag.value = scattering_length.imag self._molecular_formula = molecular_formula - self.interface = interface + + if interface is not None: + self.interface = interface + + # ----- constraint plumbing ----- + + def _setup_sld_constraints(self) -> None: + """Wire the inner molecule material's ``sld`` / ``isld`` to depend on + the current scattering-length, thickness, and area-per-molecule + parameters. + + Idempotent — called once from ``__init__`` and again from + ``from_dict`` after the saved Parameter objects replace the + constructor-time temporaries. + """ + molecule_material = self.material.material + for derived in (molecule_material.sld, molecule_material.isld): + if not derived.independent: + derived.make_independent() + + dependency_expression = 'a / (b*p) * 1e6' + molecule_material.sld.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map={ + 'a': self._scattering_length_real, + 'b': self._thickness, + 'p': self._area_per_molecule, + }, + ) + molecule_material.isld.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map={ + 'a': self._scattering_length_imag, + 'b': self._thickness, + 'p': self._area_per_molecule, + }, + ) + + # ----- deserialization ----- + + @classmethod + def from_dict(cls, obj_dict: dict) -> 'LayerAreaPerMolecule': + """Re-route the saved ``solvent_fraction`` Parameter and rebuild the + molecule-SLD constraint chain after :class:`ModelBase.from_dict` + swaps in the persisted Parameter objects. + + `ModelBase.from_dict` writes the deserialized ``solvent_fraction`` + Parameter to ``self._solvent_fraction`` (orphan — the live property + delegates to ``self.material.solvent_fraction``, which is + ``self.material._fraction``). It also reassigns ``self._thickness`` + and ``self._area_per_molecule``, but the constraint graph built in + ``__init__`` still references the temporary Parameters created from + the float kwargs. We fix both here. + """ + instance = super().from_dict(obj_dict) + + saved_solvent_fraction = instance.__dict__.pop('_solvent_fraction', None) + if saved_solvent_fraction is not None: + mixture = instance.material + old = mixture._fraction + mixture._fraction = saved_solvent_fraction + try: + instance._global_object.map.prune(old.unique_name) + except (AttributeError, KeyError): + pass + mixture._materials_constraints() + + instance._setup_sld_constraints() + return instance @property def area_per_molecule_parameter(self) -> Parameter: @@ -196,22 +253,15 @@ def area_per_molecule_parameter(self) -> Parameter: return self._area_per_molecule @property - def area_per_molecule(self) -> float: - """Get the area per molecule.""" - return self._area_per_molecule.value + def area_per_molecule(self) -> Parameter: + """The Parameter that controls area per molecule.""" + return self._area_per_molecule @area_per_molecule.setter - def area_per_molecule(self, new_area_per_molecule: float) -> None: - """Set the area per molecule. - - Parameters - ---------- - new_area_per_molecule : float - New area per molecule. - """ - if new_area_per_molecule < 0: - raise ValueError('new_area_per_molecule must be greater than 0.0.') - self._area_per_molecule.value = new_area_per_molecule + def area_per_molecule(self, value: float) -> None: + if value < 0: + raise ValueError('area_per_molecule must be greater than 0.0.') + self._area_per_molecule.value = value @property def molecule(self) -> Material: @@ -225,40 +275,21 @@ def solvent(self) -> Material: @solvent.setter def solvent(self, new_solvent: Material) -> None: - """Set the solvent material. - - Parameters - ---------- - new_solvent : Material - New solvent material. - """ self.material.solvent = new_solvent @property - def solvent_fraction_parameter(self) -> float: + def solvent_fraction_parameter(self) -> Parameter: """Get parameter for the fraction of the layer occupied by the solvent.""" return self.material.solvent_fraction_parameter @property - def solvent_fraction(self) -> float: - """Get the fraction of the layer occupied by the solvent. - - This could be a result of either water solvating the molecule, or incomplete surface coverage of the molecules. - """ + def solvent_fraction(self) -> Parameter: + """The Parameter for the fraction of the layer occupied by the solvent.""" return self.material.solvent_fraction @solvent_fraction.setter - def solvent_fraction(self, solvent_fraction: float) -> None: - """Set the fraction of the layer occupied by the solvent. - - This could be a result of either water solvating the molecule, or incomplete surface coverage of the molecules. - - Parameters - ---------- - solvent_fraction : float - Fraction of layer described by the solvent. - """ - self.material.solvent_fraction = solvent_fraction + def solvent_fraction(self, value: float) -> None: + self.material.solvent_fraction = value @property def molecular_formula(self) -> str: @@ -267,16 +298,8 @@ def molecular_formula(self) -> str: @molecular_formula.setter def molecular_formula(self, formula_string: str) -> None: - """Set the formula of the molecule in the material. - - Parameters - ---------- - formula_string : str - String that defines the molecular formula. - """ self._molecular_formula = formula_string scattering_length = neutron_scattering_length(formula_string) - # The molecule is also being updated through the constraints self._scattering_length_real.value = scattering_length.real self._scattering_length_imag.value = scattering_length.imag @@ -285,30 +308,8 @@ def molecular_formula(self, formula_string: str) -> None: @property def _dict_repr(self) -> dict[str, str]: - """Dictionary representation of the `area_per_molecule` object. - - Produces a simple dictionary. - """ + """Dictionary representation of the `area_per_molecule` object.""" dict_repr = super()._dict_repr dict_repr['molecular_formula'] = self._molecular_formula - dict_repr['area_per_molecule'] = f'{self.area_per_molecule:.2f} {self._area_per_molecule.unit}' + dict_repr['area_per_molecule'] = f'{self._area_per_molecule.value:.2f} {self._area_per_molecule.unit}' return dict_repr - - def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, str]: - """Produces a cleaned dict using a custom as_dict method to skip necessary things. - - The resulting dict matches the parameters in __init__ - - Parameters - ---------- - skip : Optional[list[str]], optional - List of keys to skip. By default, None. - """ - this_dict = super().as_dict(skip=skip) - this_dict['solvent_fraction'] = self.material._fraction.as_dict(skip=skip) - this_dict['area_per_molecule'] = self._area_per_molecule.as_dict(skip=skip) - this_dict['solvent'] = self.solvent.as_dict(skip=skip) - del this_dict['material'] - del this_dict['_scattering_length_real'] - del this_dict['_scattering_length_imag'] - return this_dict diff --git a/src/easyreflectometry/sample/elements/materials/material.py b/src/easyreflectometry/sample/elements/materials/material.py index 42091bd6..f072a639 100644 --- a/src/easyreflectometry/sample/elements/materials/material.py +++ b/src/easyreflectometry/sample/elements/materials/material.py @@ -37,10 +37,6 @@ class Material(BaseCore): - # Added in super().__init__ - sld: Parameter - isld: Parameter - def __init__( self, sld: Union[Parameter, float, None] = None, @@ -83,13 +79,28 @@ def __init__( ) apply_default_limits(isld, 'isld') - super().__init__( - name=name, - sld=sld, - isld=isld, - interface=interface, - unique_name=unique_name, - ) + super().__init__(name=name, unique_name=unique_name) + self._sld = sld + self._isld = isld + + if interface is not None: + self.interface = interface + + @property + def sld(self) -> Parameter: + return self._sld + + @sld.setter + def sld(self, value: float) -> None: + self._sld.value = value + + @property + def isld(self) -> Parameter: + return self._isld + + @isld.setter + def isld(self, value: float) -> None: + self._isld.value = value # Representation @property @@ -97,7 +108,7 @@ def _dict_repr(self) -> dict[str, str]: """A simplified dict representation.""" return { self.name: { - 'sld': f'{self.sld.value:.3f}e-6 {self.sld.unit}', - 'isld': f'{self.isld.value:.3f}e-6 {self.isld.unit}', + 'sld': f'{self._sld.value:.3f}e-6 {self._sld.unit}', + 'isld': f'{self._isld.value:.3f}e-6 {self._isld.unit}', } } diff --git a/src/easyreflectometry/sample/elements/materials/material_density.py b/src/easyreflectometry/sample/elements/materials/material_density.py index b85bda98..5be1cf9d 100644 --- a/src/easyreflectometry/sample/elements/materials/material_density.py +++ b/src/easyreflectometry/sample/elements/materials/material_density.py @@ -41,12 +41,6 @@ class MaterialDensity(Material): - # Added in __init__ - scattering_length_real: Parameter - scattering_length_imag: Parameter - molecular_weight: Parameter - density: Parameter - def __init__( self, chemical_structure: Union[str, None] = None, @@ -123,14 +117,59 @@ def __init__( dependency_map = {'d': density, 'sl': scattering_length_imag, 'mw': mw} isld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) - super().__init__(sld, isld, name=name, interface=interface) + super().__init__(sld=sld, isld=isld, name=name, unique_name=unique_name, interface=None) - self._add_component('scattering_length_real', scattering_length_real) - self._add_component('scattering_length_imag', scattering_length_imag) - self._add_component('molecular_weight', mw) - self._add_component('density', density) + self._scattering_length_real = scattering_length_real + self._scattering_length_imag = scattering_length_imag + self._molecular_weight = mw + self._density = density self._chemical_structure = chemical_structure - self.interface = interface + + if interface is not None: + self.interface = interface + + def _setup_sld_constraints(self) -> None: + """Wire the derived `sld` / `isld` to depend on the current density and + scattering-length Parameters. + + Idempotent — invoked once from `__init__` and again from `from_dict` + after :class:`ModelBase` has swapped in the saved Parameter objects. + """ + for derived in (self._sld, self._isld): + if not derived.independent: + derived.make_independent() + + dependency_expression = '1e-23*(0.602214076e6 * d * sl) / mw' + self._sld.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map={ + 'd': self._density, + 'sl': self._scattering_length_real, + 'mw': self._molecular_weight, + }, + ) + self._isld.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map={ + 'd': self._density, + 'sl': self._scattering_length_imag, + 'mw': self._molecular_weight, + }, + ) + + @classmethod + def from_dict(cls, obj_dict: dict) -> 'MaterialDensity': + """Re-attach sld/isld dependencies after deserialization. + + :class:`ModelBase.from_dict` re-points `self._density` at the + deserialized Parameter (because `density` is a constructor argument); + the constraint graph built in `__init__` still references the + temporary Parameter created from the float kwarg. Rebuild here so + `q.density = X` propagates to the derived SLDs. + """ + instance = super().from_dict(obj_dict) + instance._setup_sld_constraints() + return instance @property def chemical_structure(self) -> str: @@ -148,8 +187,28 @@ def chemical_structure(self, structure_string: str) -> None: """ self._chemical_structure = structure_string scattering_length = neutron_scattering_length(structure_string) - self.scattering_length_real.value = scattering_length.real - self.scattering_length_imag.value = scattering_length.imag + self._scattering_length_real.value = scattering_length.real + self._scattering_length_imag.value = scattering_length.imag + + @property + def density(self) -> Parameter: + return self._density + + @density.setter + def density(self, value: float) -> None: + self._density.value = value + + @property + def molecular_weight(self) -> Parameter: + return self._molecular_weight + + @property + def scattering_length_real(self) -> Parameter: + return self._scattering_length_real + + @property + def scattering_length_imag(self) -> Parameter: + return self._scattering_length_imag @property def _dict_repr(self) -> dict[str, str]: @@ -158,23 +217,3 @@ def _dict_repr(self) -> dict[str, str]: mat_dict['chemical_structure'] = self._chemical_structure mat_dict['density'] = f'{self.density.value:.2e} {self.density.unit}' return mat_dict - - def as_dict(self, skip: list = []) -> dict[str, str]: - """Produces a cleaned dict using a custom as_dict method to skip necessary things. - - The resulting dict matches the parameters in __init__ - - Parameters - ---------- - skip : list, optional - List of keys to skip. By default, []. - """ - this_dict = super().as_dict(skip=skip) - # From Material - del this_dict['sld'] - del this_dict['isld'] - # Determined in __init__ - del this_dict['scattering_length_real'] - del this_dict['scattering_length_imag'] - del this_dict['molecular_weight'] - return this_dict diff --git a/src/easyreflectometry/sample/elements/materials/material_mixture.py b/src/easyreflectometry/sample/elements/materials/material_mixture.py index c999093e..1ab76ddc 100644 --- a/src/easyreflectometry/sample/elements/materials/material_mixture.py +++ b/src/easyreflectometry/sample/elements/materials/material_mixture.py @@ -28,11 +28,6 @@ class MaterialMixture(BaseCore): - # Added in super().__init__ - _material_a: Material - _material_b: Material - _fraction: Parameter - def __init__( self, material_a: Union[Material, None] = None, @@ -74,108 +69,57 @@ def __init__( unique_name_prefix=f'{unique_name}_Fraction', ) - sld = weighted_average( + sld_value = weighted_average( a=material_a.sld.value, b=material_b.sld.value, p=fraction.value, ) - isld = weighted_average( + isld_value = weighted_average( a=material_a.isld.value, b=material_b.isld.value, p=fraction.value, ) - self._sld = get_as_parameter( + sld = get_as_parameter( name='sld', - value=sld, + value=sld_value, default_dict=DEFAULTS, unique_name_prefix=f'{unique_name}_Sld', ) - self._isld = get_as_parameter( + isld = get_as_parameter( name='isld', - value=isld, + value=isld_value, default_dict=DEFAULTS, unique_name_prefix=f'{unique_name}_Isld', ) - # To avoid problems when setting the interface - # self._sld and self._isld need to be declared before calling the super constructor - super().__init__( - name, - _material_a=material_a, - _material_b=material_b, - _fraction=fraction, - interface=interface, - ) + # `name` may be None to signal "derive from material names"; resolve + # before super().__init__ since BaseCore stores `_name` directly. if name is None: - self._update_name() - - self._materials_constraints() - self.interface = interface - - def _get_linkable_attributes(self): - """Get linkable attributes.""" - return [self._sld, self._isld] - - @property - def sld(self) -> float: - """Sld function.""" - return self._sld.value - - @property - def isld(self) -> float: - """Isld function.""" - return self._isld.value - - def _materials_constraints(self): - """Materials constraints.""" - dependency_expression = 'a * (1 - p) + b * p' - dependency_map = { - 'a': self._material_a.sld, - 'b': self._material_b.sld, - 'p': self._fraction, - } - self._sld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) + resolved_name = material_a.name + '/' + material_b.name + else: + resolved_name = name - dependency_map = { - 'a': self._material_a.isld, - 'b': self._material_b.isld, - 'p': self._fraction, - } - self._isld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) + super().__init__(name=resolved_name, unique_name=unique_name) + self._material_a = material_a + self._material_b = material_b + self._fraction = fraction + self._sld = sld + self._isld = isld - @property - def fraction(self) -> float: - """Get the fraction of material_b.""" - return self._fraction.value + self._materials_constraints() - @fraction.setter - def fraction(self, fraction: float) -> None: - """Setter for fraction of material_b. + if interface is not None: + self.interface = interface - Parameters - ---------- - fraction : float - The fraction of material_b in material_a. - """ - if not isinstance(fraction, float): - raise ValueError('fraction must be a float') - self._fraction.value = fraction + # ----- constructor-arg accessors ----- @property def material_a(self) -> Material: - """Getter for material_a.""" return self._material_a @material_a.setter def material_a(self, new_material_a: Material) -> None: - """Setter for material_a. - - Parameters - ---------- - new_material_a : Material - New Material for material_a. - """ self._material_a = new_material_a self._materials_constraints() if self.interface is not None: @@ -184,28 +128,108 @@ def material_a(self, new_material_a: Material) -> None: @property def material_b(self) -> Material: - """Getter for material_b.""" return self._material_b @material_b.setter def material_b(self, new_material_b: Material) -> None: - """Setter for material_b. - - Parameters - ---------- - new_material_b : Material - New Materialfor material_b. - """ self._material_b = new_material_b self._materials_constraints() if self.interface is not None: self.interface.generate_bindings(self) self._update_name() + @property + def fraction(self) -> Parameter: + """The Parameter that controls the mixing fraction of material_b in material_a.""" + return self._fraction + + @fraction.setter + def fraction(self, value: float) -> None: + if not isinstance(value, (int, float)): + raise ValueError('fraction must be a float') + self._fraction.value = value + + # ----- derived sld / isld parameters (shared shape with Material) ----- + # + # These are *derived* via the constraints set up in `_materials_constraints` + # (not constructor arguments) so we expose them as floats to match the + # legacy MaterialMixture API. The underlying Parameter objects remain + # available as `self._sld` / `self._isld`. + + @property + def sld(self) -> float: + return self._sld.value + + @property + def isld(self) -> float: + return self._isld.value + + # ----- calculator binding ----- + + def _get_linkable_attributes(self): + """Return the *mixed* sld / isld parameters for calculator binding. + + Override of the inherited `BaseCore._get_linkable_attributes`, which + walks `get_all_variables()` and would otherwise expose the **child** + materials' sld/isld (because our own `sld` / `isld` are floats, not + Parameters). The calculator's `InterfaceFactoryTemplate.generate_bindings` + matches by parameter `name`; without this override it binds to + `material_a.sld` and reflectivity is computed off the wrong SLD. + """ + return [self._sld, self._isld] + + # ----- internal helpers ----- + + def _materials_constraints(self): + """Wire the mixed `_sld` / `_isld` to depend on the current child + material parameters and the current `_fraction`. Idempotent: callers + invoke this once from ``__init__`` and again from ``from_dict`` after + the saved Parameters have been reattached (so the dependency graph + points at the right objects, not the temporary constructor params).""" + # Detach any existing dependency before rebuilding so make_dependent_on + # doesn't chain on top of stale references. + for derived in (self._sld, self._isld): + if not derived.independent: + derived.make_independent() + + dependency_expression = 'a * (1 - p) + b * p' + dependency_map = { + 'a': self._material_a.sld, + 'b': self._material_b.sld, + 'p': self._fraction, + } + self._sld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) + + dependency_map = { + 'a': self._material_a.isld, + 'b': self._material_b.isld, + 'p': self._fraction, + } + self._isld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) + def _update_name(self) -> None: """Update name.""" self.name = self._material_a.name + '/' + self._material_b.name + # ----- deserialization ----- + + @classmethod + def from_dict(cls, obj_dict: dict) -> 'MaterialMixture': + """Re-attach mixed-sld dependencies after :class:`ModelBase` swaps in + the saved ``_fraction`` Parameter. + + :class:`ModelBase.from_dict` runs ``__init__`` (which builds the + ``_sld`` / ``_isld`` constraints against the *temporary* ``_fraction`` + created from the float kwargs) and then re-points ``self._fraction`` + at the persisted Parameter. The constraint graph still references the + temporary object, so subsequent ``mm.fraction = X`` mutations don't + propagate to ``_sld`` / ``_isld``. Re-running ``_materials_constraints`` + here points the graph at the live objects. + """ + instance = super().from_dict(obj_dict) + instance._materials_constraints() + return instance + # Representation @property def _dict_repr(self) -> dict[str, str]: @@ -219,19 +243,3 @@ def _dict_repr(self) -> dict[str, str]: 'material_b': self._material_b._dict_repr, } } - - def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, str]: - """Produces a cleaned dict using a custom as_dict method to skip necessary things. - - The resulting dict matches the parameters in __init__ - - Parameters - ---------- - skip : Optional[list[str]], optional - List of keys to skip. By default, None. - """ - this_dict = super().as_dict(skip=skip) - this_dict['material_a'] = self._material_a.as_dict(skip=skip) - this_dict['material_b'] = self._material_b.as_dict(skip=skip) - this_dict['fraction'] = self._fraction.as_dict(skip=skip) - return this_dict diff --git a/src/easyreflectometry/sample/elements/materials/material_solvated.py b/src/easyreflectometry/sample/elements/materials/material_solvated.py index 454904ba..d6121d10 100644 --- a/src/easyreflectometry/sample/elements/materials/material_solvated.py +++ b/src/easyreflectometry/sample/elements/materials/material_solvated.py @@ -72,6 +72,7 @@ def __init__( material_b=solvent, fraction=solvent_fraction, name=name, + unique_name=unique_name, interface=interface, ) if name is None: @@ -84,13 +85,7 @@ def material(self) -> Material: @material.setter def material(self, new_material: Material) -> None: - """Set the material. - - Parameters - ---------- - new_material : Material - Matrerial to be useed. - """ + """Set the material.""" self.material_a = new_material @property @@ -100,13 +95,7 @@ def solvent(self) -> Material: @solvent.setter def solvent(self, new_solvent: Material) -> None: - """Set the solvent. - - Parameters - ---------- - new_solvent : Material - Solvent to be used. - """ + """Set the solvent.""" self.material_b = new_solvent @property @@ -115,39 +104,58 @@ def solvent_fraction_parameter(self) -> Parameter: return self._fraction @property - def solvent_fraction(self) -> float: - """Get the fraction of layer described by the solvent. + def solvent_fraction(self) -> Parameter: + """The Parameter for the fraction of the layer described by the solvent. - This might be fraction of: - Solvation where solvent is within the layer - Patches of solvent in the layer where no material is present. + This might be the fraction of: + - solvation where solvent is within the layer, or + - patches of solvent in the layer where no material is present. """ - return self.fraction + return self._fraction @solvent_fraction.setter def solvent_fraction(self, solvent_fraction: float) -> None: - """Set the fraction of layer covered by the material. - - This might be fraction of: - Solvation where solvent is within the layer - Patches of solvent in the layer where no material is present. - - Parameters - ---------- - solvent_fraction : float - Fraction of layer described by the solvent. - """ - try: - self.fraction = solvent_fraction - if solvent_fraction < 0 or solvent_fraction > 1: - raise ValueError('solvent_fraction must be between 0 and 1') - except ValueError: + """Set the fraction of layer covered by the material.""" + if not isinstance(solvent_fraction, (int, float)): raise ValueError('solvent_fraction must be a float between 0 and 1') + if solvent_fraction < 0 or solvent_fraction > 1: + raise ValueError('solvent_fraction must be between 0 and 1') + self._fraction.value = solvent_fraction def _update_name(self) -> None: """Update name.""" self.name = self._material_a.name + ' in ' + self._material_b.name + # ----- deserialization ----- + + @classmethod + def from_dict(cls, obj_dict: dict) -> 'MaterialSolvated': + """Re-route the saved ``solvent_fraction`` Parameter onto ``_fraction``. + + :class:`ModelBase.from_dict` writes the saved Parameter to + ``_solvent_fraction`` because that's the constructor-arg name, but + the live `solvent_fraction` property returns ``self._fraction`` + (the field MaterialMixture maintains). Without this override the + saved fit metadata (fixed/bounds/etc.) is stranded on the unused + ``_solvent_fraction`` attribute and the active parameter keeps the + defaults from `__init__`. + + Also re-runs `_materials_constraints` so the parent MaterialMixture's + mixed `_sld` / `_isld` depend on the live `_fraction`, not the + temporary Parameter created from the float kwarg. + """ + instance = super().from_dict(obj_dict) + saved = instance.__dict__.pop('_solvent_fraction', None) + if saved is not None: + old = instance._fraction + instance._fraction = saved + try: + instance._global_object.map.prune(old.unique_name) + except (AttributeError, KeyError): + pass + instance._materials_constraints() + return instance + # Representation @property def _dict_repr(self) -> dict[str, str]: @@ -161,28 +169,3 @@ def _dict_repr(self) -> dict[str, str]: 'solvent': self.solvent._dict_repr, } } - - def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, str]: - """Produces a cleaned dict using a custom as_dict method to skip necessary things. - - The resulting dict matches the parameters in __init__ - - Parameters - ---------- - skip : Optional[list[str]], optional - List of keys to skip. By default, None. - """ - this_dict = super().as_dict(skip=skip) - this_dict['material'] = self.material.as_dict(skip=skip) - this_dict['solvent'] = self.solvent.as_dict(skip=skip) - this_dict['solvent_fraction'] = self._fraction.as_dict(skip=skip) - # Property and protected varible from material_mixture - del this_dict['material_a'] - del this_dict['_material_a'] - # Property and protected varible from material_mixture - del this_dict['material_b'] - del this_dict['_material_b'] - # Property and protected varible from material_mixture - del this_dict['fraction'] - del this_dict['_fraction'] - return this_dict diff --git a/src/easyreflectometry/summary/summary.py b/src/easyreflectometry/summary/summary.py index c06751af..2fc2846f 100644 --- a/src/easyreflectometry/summary/summary.py +++ b/src/easyreflectometry/summary/summary.py @@ -191,7 +191,7 @@ def _sample_section(self) -> str: # Get parameters directly from the model instead of using project.parameters model = self._project._models[self._project.current_model_index] - parameters = model.get_parameters() + parameters = model.get_all_parameters() for parameter in parameters: path = global_object.map.find_path(model.unique_name, parameter.unique_name) @@ -248,7 +248,7 @@ def _refinement_section(self) -> str: # Get parameters directly from the model model = self._project._models[self._project.current_model_index] - parameters = model.get_parameters() + parameters = model.get_all_parameters() num_free_params = sum(1 for parameter in parameters if parameter.free) num_fixed_params = sum(1 for parameter in parameters if not parameter.free) diff --git a/tests/model/test_model_collection.py b/tests/model/test_model_collection.py index 4a6fd521..68c47e17 100644 --- a/tests/model/test_model_collection.py +++ b/tests/model/test_model_collection.py @@ -65,7 +65,7 @@ def test_add_model_color_cycle(self): collection.add_model() assert collection[1].color == COLORS[1] - collection.remove(0) + collection.remove_at(0) collection.add_model() assert collection[0].color == COLORS[1] @@ -101,7 +101,7 @@ def test_delete_model(self): # Then collection = ModelCollection(model_1, model_2) - collection.remove(0) + collection.remove_at(0) # Expect assert len(collection) == 1 diff --git a/tests/sample/assemblies/test_bilayer.py b/tests/sample/assemblies/test_bilayer.py index 03949ee7..431ad866 100644 --- a/tests/sample/assemblies/test_bilayer.py +++ b/tests/sample/assemblies/test_bilayer.py @@ -119,7 +119,7 @@ def test_tail_layers_linked(self): # Initial values should match assert p.front_tail_layer.thickness.value == p.back_tail_layer.thickness.value - assert p.front_tail_layer.area_per_molecule == p.back_tail_layer.area_per_molecule + assert p.front_tail_layer.area_per_molecule.value == p.back_tail_layer.area_per_molecule.value # Change front tail thickness - back tail should follow p.front_tail_layer.thickness.value = 20.0 @@ -128,8 +128,8 @@ def test_tail_layers_linked(self): # Change front tail area per molecule - back tail should follow p.front_tail_layer.area_per_molecule = 55.0 - assert p.front_tail_layer.area_per_molecule == 55.0 - assert p.back_tail_layer.area_per_molecule == 55.0 + assert p.front_tail_layer.area_per_molecule.value == 55.0 + assert p.back_tail_layer.area_per_molecule.value == 55.0 def test_constrain_heads_enabled(self): """Test head thickness/area constraint when enabled.""" @@ -142,8 +142,8 @@ def test_constrain_heads_enabled(self): # Change front head area per molecule - back head should follow p.front_head_layer.area_per_molecule = 60.0 - assert p.front_head_layer.area_per_molecule == 60.0 - assert p.back_head_layer.area_per_molecule == 60.0 + assert p.front_head_layer.area_per_molecule.value == 60.0 + assert p.back_head_layer.area_per_molecule.value == 60.0 def test_constrain_heads_disabled(self): """Test heads are independent when constraint disabled.""" @@ -190,8 +190,8 @@ def test_head_hydration_independent(self): p.back_head_layer.solvent_fraction = 0.5 # They should remain independent - assert p.front_head_layer.solvent_fraction == 0.3 - assert p.back_head_layer.solvent_fraction == 0.5 + assert p.front_head_layer.solvent_fraction.value == 0.3 + assert p.back_head_layer.solvent_fraction.value == 0.5 def test_conformal_roughness_enabled(self): """Test all roughnesses are linked when conformal roughness enabled.""" diff --git a/tests/sample/assemblies/test_surfactant_layer.py b/tests/sample/assemblies/test_surfactant_layer.py index 62b92689..653b19ba 100644 --- a/tests/sample/assemblies/test_surfactant_layer.py +++ b/tests/sample/assemblies/test_surfactant_layer.py @@ -41,31 +41,31 @@ def test_from_pars(self): assert p.tail_layer.molecular_formula == 'C8O10H12P' assert p.tail_layer.thickness.value == 12 assert p.tail_layer.solvent.as_dict() == h2o.as_dict() - assert p.tail_layer.solvent_fraction == 0.5 - assert p.tail_layer.area_per_molecule == 50 + assert p.tail_layer.solvent_fraction.value == 0.5 + assert p.tail_layer.area_per_molecule.value == 50 assert p.tail_layer.roughness.value == 2 assert p.layers[1].name == 'A Test Head Layer' assert p.head_layer.name == 'A Test Head Layer' assert p.head_layer.molecular_formula == 'C10H24' assert p.head_layer.thickness.value == 10 assert p.head_layer.solvent.as_dict() == noth2o.as_dict() - assert p.head_layer.solvent_fraction == 0.2 - assert p.head_layer.area_per_molecule == 40 + assert p.head_layer.solvent_fraction.value == 0.2 + assert p.head_layer.area_per_molecule.value == 40 assert p.name == 'A Test' def test_constraint_area_per_molecule(self): p = SurfactantLayer() p.tail_layer._area_per_molecule.value = 30 - assert p.tail_layer.area_per_molecule == 30.0 - assert p.head_layer.area_per_molecule == 48.2 + assert p.tail_layer.area_per_molecule.value == 30.0 + assert p.head_layer.area_per_molecule.value == 48.2 assert p.constrain_area_per_molecule is False p.constrain_area_per_molecule = True - assert p.tail_layer.area_per_molecule == 30 - assert p.head_layer.area_per_molecule == 30 + assert p.tail_layer.area_per_molecule.value == 30 + assert p.head_layer.area_per_molecule.value == 30 assert p.constrain_area_per_molecule is True p.tail_layer._area_per_molecule.value = 40 - assert p.tail_layer.area_per_molecule == 40 - assert p.head_layer.area_per_molecule == 40 + assert p.tail_layer.area_per_molecule.value == 40 + assert p.head_layer.area_per_molecule.value == 40 def test_conformal_roughness(self): p = SurfactantLayer() diff --git a/tests/sample/collections/test_base_collection.py b/tests/sample/collections/test_base_collection.py index 44b8e208..50489240 100644 --- a/tests/sample/collections/test_base_collection.py +++ b/tests/sample/collections/test_base_collection.py @@ -160,7 +160,7 @@ def test_remove(self): p.append(Layer(name='layer_4')) # Then - p.remove(1) + p.remove_at(1) # Then assert len(p) == 3 diff --git a/tests/sample/elements/layers/test_layer_area_per_molecule.py b/tests/sample/elements/layers/test_layer_area_per_molecule.py index 4d0abf7d..0d466be7 100644 --- a/tests/sample/elements/layers/test_layer_area_per_molecule.py +++ b/tests/sample/elements/layers/test_layer_area_per_molecule.py @@ -18,7 +18,7 @@ class TestLayerAreaPerMolecule(unittest.TestCase): def test_default(self): p = LayerAreaPerMolecule() assert p.molecular_formula == 'C10H18NO8P' - assert p.area_per_molecule == 48.2 + assert p.area_per_molecule.value == 48.2 assert str(p._area_per_molecule.unit) == 'Å^2' assert p._area_per_molecule.fixed is True assert p.thickness.value == 10.0 @@ -33,7 +33,7 @@ def test_default(self): assert p.solvent.sld.value == 6.36 assert p.solvent.isld.value == 0 assert p.solvent.name == 'D2O' - assert p.solvent_fraction == 0.2 + assert p.solvent_fraction.value == 0.2 assert str(p.material._fraction.unit) == 'dimensionless' assert p.material._fraction.fixed is True @@ -49,12 +49,12 @@ def test_from_pars(self): name='PG/H2O', ) assert p.molecular_formula == 'C8O10H12P' - assert p.area_per_molecule == 50 + assert p.area_per_molecule.value == 50 assert p.thickness.value == 12 assert p.roughness.value == 2 assert p.solvent.sld.value == -0.561 assert p.solvent.isld.value == 0 - assert p.solvent_fraction == 0.5 + assert p.solvent_fraction.value == 0.5 def test_from_pars_constraint(self): h2o = Material(-0.561, 0, 'H2O') @@ -68,15 +68,15 @@ def test_from_pars_constraint(self): name='PG/H2O', ) assert p.molecular_formula == 'C8O10H12P' - assert p.area_per_molecule == 50 + assert p.area_per_molecule.value == 50 assert_almost_equal(p.material.sld, 0.31494833333333333) assert p.thickness.value == 12 assert p.roughness.value == 2 assert p.solvent.sld.value == -0.561 assert p.solvent.isld.value == 0 - assert p.solvent_fraction == 0.5 + assert p.solvent_fraction.value == 0.5 p.area_per_molecule = 30 - assert p.area_per_molecule == 30 + assert p.area_per_molecule.value == 30 assert_almost_equal(p.material.sld, 0.7119138888888887) p.thickness.value = 10 assert p.thickness.value == 10 @@ -95,24 +95,24 @@ def test_solvent_change(self): name='PG/H2O', ) assert p.molecular_formula == 'C8O10H12P' - assert p.area_per_molecule == 50 + assert p.area_per_molecule.value == 50 print(p.material) assert_almost_equal(p.material.sld, 0.31494833333333333) assert p.thickness.value == 12 assert p.roughness.value == 2 assert p.solvent.sld.value == -0.561 assert p.solvent.isld.value == 0 - assert p.solvent_fraction == 0.5 + assert p.solvent_fraction.value == 0.5 d2o = Material(6.335, 0, 'D2O') p.solvent = d2o assert p.molecular_formula == 'C8O10H12P' - assert p.area_per_molecule == 50 + assert p.area_per_molecule.value == 50 assert_almost_equal(p.material.sld, 3.762948333333333) assert p.thickness.value == 12 assert p.roughness.value == 2 assert p.solvent.sld.value == 6.335 assert p.solvent.isld.value == 0 - assert p.solvent_fraction == 0.5 + assert p.solvent_fraction.value == 0.5 def test_molecular_formula_change(self): h2o = Material(-0.561, 0, 'H2O') @@ -126,24 +126,24 @@ def test_molecular_formula_change(self): name='PG/H2O', ) assert p.molecular_formula == 'C8O10H12P' - assert p.area_per_molecule == 50 + assert p.area_per_molecule.value == 50 assert_almost_equal(p.material.sld, 0.31494833333333333) assert p.thickness.value == 12 assert p.roughness.value == 2 assert p.solvent.sld.value == -0.561 assert p.solvent.isld.value == 0 - assert p.solvent_fraction == 0.5 + assert p.solvent_fraction.value == 0.5 assert p.material.name == 'C8O10H12P in H2O' p.molecular_formula = 'C8O10D12P' assert p.molecular_formula == 'C8O10D12P' - assert p.area_per_molecule == 50 + assert p.area_per_molecule.value == 50 assert_almost_equal(p.material.sld, 1.3558483333333333) assert p.thickness.value == 12 assert p.roughness.value == 2 assert p.solvent.sld.value == -0.561 assert p.solvent.isld.value == 0 - assert p.solvent_fraction == 0.5 + assert p.solvent_fraction.value == 0.5 assert p.material.name == 'C8O10D12P in H2O' def test_dict_repr(self): @@ -185,3 +185,48 @@ def test_dict_round_trip(self): # Expect assert sorted(p.as_dict()) == sorted(q.as_dict()) + + def test_solvent_fraction_metadata_and_mutation_after_round_trip(self): + """Regression covering two bugs at once: + + - ``solvent_fraction`` is a constructor argument but its backing + storage is ``self.material._fraction`` (delegated through + ``MaterialSolvated``). Without an override, ``ModelBase.from_dict`` + would put the saved Parameter on an orphan ``_solvent_fraction`` + attribute and reset the live one to constructor defaults. + - ``__init__`` builds the molecule SLD constraint against the + *temporary* thickness / area_per_molecule Parameters; after + ``from_dict`` reattaches the saved ones, mutating them must still + propagate to ``material.material.sld``. + """ + p = LayerAreaPerMolecule( + molecular_formula='C10H18NO8P', + thickness=12.0, + solvent_fraction=0.3, + area_per_molecule=50.0, + roughness=2.0, + ) + p.solvent_fraction.fixed = False + p.solvent_fraction.min = 0.12 + + original_mol_sld = p.material.material.sld.value + p_dict = p.as_dict() + global_object.map._clear() + + q = LayerAreaPerMolecule.from_dict(p_dict) + + # solvent_fraction metadata preserved, no orphan field. + assert q.solvent_fraction.value == 0.3 + assert q.solvent_fraction.fixed is False + assert q.solvent_fraction.min == 0.12 + assert '_solvent_fraction' not in q.__dict__ + + # Molecule SLD constraint preserved. + assert_almost_equal(q.material.material.sld.value, original_mol_sld) + + # Mutate the independent parameters and verify the constraint chain + # propagates to the derived molecule SLD. + q.area_per_molecule = 25.0 # half APM doubles SLD + assert_almost_equal(q.material.material.sld.value, 2 * original_mol_sld) + q.thickness.value = 6.0 # half thickness doubles SLD again + assert_almost_equal(q.material.material.sld.value, 4 * original_mol_sld) diff --git a/tests/sample/elements/materials/test_material_density.py b/tests/sample/elements/materials/test_material_density.py index 99c6615a..0b424cad 100644 --- a/tests/sample/elements/materials/test_material_density.py +++ b/tests/sample/elements/materials/test_material_density.py @@ -64,3 +64,22 @@ def test_dict_round_trip(self): q = MaterialDensity.from_dict(p_dict) assert sorted(p.as_dict()) == sorted(q.as_dict()) + + def test_density_mutation_propagates_after_round_trip(self): + """Regression: after ``from_dict`` reattaches the saved ``_density`` + Parameter, mutating it must propagate to ``sld`` / ``isld`` (which + are constrained off it). The ``__init__``-time constraint references + the temporary constructor Parameter; ``from_dict`` rebuilds the + graph so subsequent mutations propagate correctly. + """ + p = MaterialDensity(chemical_structure='Si', density=2.33) + original_sld = p.sld.value + p_dict = p.as_dict() + global_object.map._clear() + + q = MaterialDensity.from_dict(p_dict) + assert_almost_equal(q.sld.value, original_sld) + + q.density = 4.66 + # SLD scales linearly with density (constraint: d * sl / mw, etc.) + assert_almost_equal(q.sld.value, 2 * original_sld) diff --git a/tests/sample/elements/materials/test_material_mixture.py b/tests/sample/elements/materials/test_material_mixture.py index 0e70a2b7..3c2a3f65 100644 --- a/tests/sample/elements/materials/test_material_mixture.py +++ b/tests/sample/elements/materials/test_material_mixture.py @@ -13,7 +13,7 @@ class TestMaterialMixture: def test_default(self) -> None: material_mixture = MaterialMixture() - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 4.186) assert_almost_equal(material_mixture.isld, 0) @@ -22,7 +22,7 @@ def test_default(self) -> None: def test_default_constraint(self) -> None: material_mixture = MaterialMixture() - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 4.186) assert_almost_equal(material_mixture.isld, 0) @@ -37,57 +37,57 @@ def test_fraction_constraint(self): p = Material() q = Material(6.908, -0.278, 'Boron') material_mixture = MaterialMixture(p, q, 0.2) - assert material_mixture.fraction == 0.2 + assert material_mixture.fraction.value == 0.2 assert_almost_equal(material_mixture.sld, 4.7304) assert_almost_equal(material_mixture.isld, -0.0556) material_mixture._fraction.value = 0.5 - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert_almost_equal(material_mixture.sld, 5.54700) assert_almost_equal(material_mixture.isld, -0.1390) def test_material_a_change(self) -> None: material_mixture = MaterialMixture() - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 4.186) assert_almost_equal(material_mixture.isld, 0) q = Material(6.908, -0.278, 'Boron') material_mixture.material_a = q - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 5.54700) assert_almost_equal(material_mixture.isld, -0.1390) def test_material_b_change(self) -> None: material_mixture = MaterialMixture() - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 4.186) assert_almost_equal(material_mixture.isld, 0) q = Material(6.908, -0.278, 'Boron') material_mixture.material_b = q - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 5.54700) assert_almost_equal(material_mixture.isld, -0.1390) def test_material_b_change_double(self) -> None: material_mixture = MaterialMixture() - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 4.186) assert_almost_equal(material_mixture.isld, 0) q = Material(6.908, -0.278, 'Boron') material_mixture.material_b = q assert material_mixture.name == 'EasyMaterial/Boron' - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 5.54700) assert_almost_equal(material_mixture.isld, -0.1390) r = Material(0.00, 0.00, 'ACMW') material_mixture.material_b = r assert material_mixture.name == 'EasyMaterial/ACMW' - assert material_mixture.fraction == 0.5 + assert material_mixture.fraction.value == 0.5 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 2.0930) assert_almost_equal(material_mixture.isld, 0.0000) @@ -96,7 +96,7 @@ def test_from_pars(self): p = Material() q = Material(6.908, -0.278, 'Boron') material_mixture = MaterialMixture(p, q, 0.2) - assert material_mixture.fraction == 0.2 + assert material_mixture.fraction.value == 0.2 assert str(material_mixture._fraction.unit) == 'dimensionless' assert_almost_equal(material_mixture.sld, 4.7304) assert_almost_equal(material_mixture.isld, -0.0556) @@ -142,3 +142,39 @@ def test_update_name(self) -> None: # Expect assert material_mixture.name == 'name_a/name_b' + + def test_calculator_binding_uses_mixed_sld(self) -> None: + """Regression: the calculator wrapper must bind to the mixture's own + ``_sld``/``_isld`` (the weighted average), not to either child material's + sld/isld parameter. Without an explicit ``_get_linkable_attributes`` + override the inherited dir-walk picks up the first matching child + parameter and the wrapper silently gets the wrong SLD. + """ + from easyreflectometry.calculators import CalculatorFactory + + interface = CalculatorFactory() + material_a = Material(sld=2.0, isld=0.0) + material_b = Material(sld=6.0, isld=0.0) + mixture = MaterialMixture(material_a, material_b, fraction=0.25, interface=interface) + + # 2 * 0.75 + 6 * 0.25 = 1.5 + 1.5 = 3.0 + assert_almost_equal(mixture.sld, 3.0) + wrapper_material = interface()._wrapper.storage['material'][mixture.unique_name] + assert_almost_equal(wrapper_material.real.value, 3.0) + assert_almost_equal(wrapper_material.imag.value, 0.0) + + def test_mutation_propagates_after_round_trip(self) -> None: + """Regression: after ``from_dict`` swaps in the saved ``_fraction`` + Parameter, the dependency graph for ``_sld``/``_isld`` must point at + the live ``_fraction`` (not the temp Parameter created from the + float kwarg in ``__init__``).""" + p = MaterialMixture(Material(sld=2.0), Material(sld=6.0), fraction=0.25) + p_dict = p.as_dict() + global_object.map._clear() + + q = MaterialMixture.from_dict(p_dict) + assert_almost_equal(q.sld, 3.0) + + q.fraction = 0.8 + # 2 * 0.2 + 6 * 0.8 = 0.4 + 4.8 = 5.2 + assert_almost_equal(q.sld, 5.2) diff --git a/tests/sample/elements/materials/test_material_solvated.py b/tests/sample/elements/materials/test_material_solvated.py index 3d4a9917..3af8e379 100644 --- a/tests/sample/elements/materials/test_material_solvated.py +++ b/tests/sample/elements/materials/test_material_solvated.py @@ -32,7 +32,7 @@ def test_init(self, material_solvated: MaterialSolvated) -> None: # When Then Expect assert material_solvated.material_a == self.material assert material_solvated.material_b == self.solvent - assert material_solvated.fraction == 0.1 + assert material_solvated.fraction.value == 0.1 assert material_solvated.name == 'name' assert material_solvated.interface == self.mock_interface self.mock_interface.generate_bindings.call_count == 2 @@ -69,14 +69,14 @@ def test_set_solvent(self, material_solvated: MaterialSolvated) -> None: def test_solvent_fraction(self, material_solvated: MaterialSolvated) -> None: # When Then Expect - assert material_solvated.solvent_fraction == 0.1 + assert material_solvated.solvent_fraction.value == 0.1 def test_set_solvent_fraction(self, material_solvated: MaterialSolvated) -> None: # When Then material_solvated.solvent_fraction = 1.0 # Expect - assert material_solvated.solvent_fraction == 1.0 + assert material_solvated.solvent_fraction.value == 1.0 def test_set_solvent_fraction_exception(self, material_solvated: MaterialSolvated) -> None: # When Then Expect @@ -133,3 +133,29 @@ def test_update_name(self, material_solvated: MaterialSolvated) -> None: # Expect assert material_solvated.name == 'name_a in name_b' + + def test_solvent_fraction_metadata_survives_round_trip(self) -> None: + """Regression: ``solvent_fraction`` is a constructor argument, but its + backing storage is ``_fraction`` (inherited from MaterialMixture). + ``ModelBase.from_dict`` would write the saved Parameter to + ``_solvent_fraction`` (an orphan), silently resetting the active + parameter to constructor defaults. We re-route to ``_fraction`` in + ``MaterialSolvated.from_dict``. + """ + material = Material(sld=6.36, isld=0, name='D2O') + solvent = Material(sld=-0.561, isld=0, name='H2O') + p = MaterialSolvated(material=material, solvent=solvent, solvent_fraction=0.3) + # Tweak fit metadata that the default would not have. + p.solvent_fraction.fixed = False + p.solvent_fraction.min = 0.12 + + p_dict = p.as_dict() + global_object.map._clear() + + q = MaterialSolvated.from_dict(p_dict) + + assert q.solvent_fraction.value == 0.3 + assert q.solvent_fraction.fixed is False + assert q.solvent_fraction.min == 0.12 + # No orphan field. + assert '_solvent_fraction' not in q.__dict__ diff --git a/tests/test_project.py b/tests/test_project.py index 73aeca85..96577d0c 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -377,11 +377,13 @@ def test_as_dict(self): keys.sort() assert keys == [ 'calculator', + 'file_format', 'fitter_minimizer', 'info', 'models', 'with_experiments', ] + assert project_dict['file_format'] == Project.FILE_FORMAT assert project_dict['info'] == { 'name': 'DefaultEasyReflectometryProject', 'short_description': 'Reflectometry, 1D', From a1eb36afe3bfc8929a27fd19d335e98da4793bd4 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Sun, 24 May 2026 22:42:15 +0200 Subject: [PATCH 2/3] PR review --- .../sample/assemblies/bilayer.py | 28 +++++++++++++------ .../layers/layer_area_per_molecule.py | 26 ++++++++--------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/easyreflectometry/sample/assemblies/bilayer.py b/src/easyreflectometry/sample/assemblies/bilayer.py index cf6edc75..995128ed 100644 --- a/src/easyreflectometry/sample/assemblies/bilayer.py +++ b/src/easyreflectometry/sample/assemblies/bilayer.py @@ -59,6 +59,7 @@ def __init__( front_head_layer: LayerAreaPerMolecule | None = None, front_tail_layer: LayerAreaPerMolecule | None = None, back_head_layer: LayerAreaPerMolecule | None = None, + back_tail_layer: LayerAreaPerMolecule | None = None, name: str = 'EasyBilayer', unique_name: str | None = None, constrain_heads: bool = True, @@ -73,10 +74,16 @@ def __init__( Layer representing the front head part of the bilayer. By default, None. front_tail_layer : LayerAreaPerMolecule | None, optional Layer representing the front tail part of the bilayer. - A back tail layer is created internally with its thickness, area per molecule, - and solvent fraction constrained to match this layer. By default, None. + The back tail layer's thickness, area per molecule, and solvent fraction are + constrained to match this layer. By default, None. back_head_layer : LayerAreaPerMolecule | None, optional Layer representing the back head part of the bilayer. By default, None. + back_tail_layer : LayerAreaPerMolecule | None, optional + Layer representing the back tail part of the bilayer. If omitted, a back tail + is created from the front tail (same molecular_formula, solvent, roughness, etc.). + Independent state (solvent, molecular_formula, name, roughness when + ``conformal_roughness`` is False) is preserved across serialization; the + structural parameters listed above are derived from the front tail. By default, None. name : str, optional Name for bilayer. By default, 'EasyBilayer'. unique_name : str | None, optional @@ -109,13 +116,16 @@ def __init__( interface=interface, ) - # Create back tail layer with initial values copied from the front tail. - # Its parameters will be constrained to the front tail after construction. - back_tail_layer = self._create_back_tail_layer( - front_tail_layer=front_tail_layer, - unique_name=unique_name, - interface=interface, - ) + # If no back tail is supplied, derive one from the front tail. The structural + # parameters (thickness, area_per_molecule, solvent_fraction) get constrained to + # the front tail in `_setup_tail_constraints` below regardless of which path + # produced this layer. + if back_tail_layer is None: + back_tail_layer = self._create_back_tail_layer( + front_tail_layer=front_tail_layer, + unique_name=unique_name, + interface=interface, + ) if back_head_layer is None: back_head_layer = self._create_default_head_layer( diff --git a/src/easyreflectometry/sample/elements/layers/layer_area_per_molecule.py b/src/easyreflectometry/sample/elements/layers/layer_area_per_molecule.py index 14b07b8c..32161dcc 100644 --- a/src/easyreflectometry/sample/elements/layers/layer_area_per_molecule.py +++ b/src/easyreflectometry/sample/elements/layers/layer_area_per_molecule.py @@ -137,22 +137,18 @@ def __init__( unique_name_prefix=f'{unique_name}_Isl', ) - # Constrain molecule.sld via scattering length / (thickness * area_per_molecule) - dependency_expression = 'scattering_length / (thickness * area_per_molecule) * 1e6' - dependency_map = { - 'scattering_length': scattering_length_real, - 'thickness': thickness, - 'area_per_molecule': area_per_molecule_param, - } - molecule_material.sld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) - - # Same dependency under short variable names + # Constrain molecule.sld / .isld to scattering_length / (thickness * area_per_molecule). + # `_setup_sld_constraints` rebuilds the same expression after from_dict, so keep the + # variable names (`a`, `b`, `p`) consistent with that path. dependency_expression = 'a / (b*p) * 1e6' - dependency_map = {'a': scattering_length_real, 'b': thickness, 'p': area_per_molecule_param} - molecule_material.sld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) - - dependency_map = {'a': scattering_length_imag, 'b': thickness, 'p': area_per_molecule_param} - molecule_material.isld.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) + molecule_material.sld.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map={'a': scattering_length_real, 'b': thickness, 'p': area_per_molecule_param}, + ) + molecule_material.isld.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map={'a': scattering_length_imag, 'b': thickness, 'p': area_per_molecule_param}, + ) solvated_molecule_material = MaterialSolvated( material=molecule_material, From 1a167a1cc564559cfd0afb3f060c9544168ea9c6 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 28 May 2026 10:56:53 +0200 Subject: [PATCH 3/3] more tests --- tests/model/test_model.py | 117 ++++++++ tests/model/test_model_collection.py | 39 +++ tests/sample/assemblies/test_base_assembly.py | 15 + .../collections/test_base_collection.py | 177 +++++++++++ tests/sample/test_base_core.py | 284 ++++++++++++++++++ tests/test_project.py | 38 +++ 6 files changed, 670 insertions(+) create mode 100644 tests/sample/test_base_core.py diff --git a/tests/model/test_model.py b/tests/model/test_model.py index 2d623512..b11149c1 100644 --- a/tests/model/test_model.py +++ b/tests/model/test_model.py @@ -430,3 +430,120 @@ def test_dict_round_trip(interface): model.interface().reflectity_profile([0.3], model.unique_name), model_from_dict.interface().reflectity_profile([0.3], model_from_dict.unique_name), ) + + +class TestModelPropertyAccessors: + """Tests for the new @property accessors introduced in the ModelBase/EasyList migration.""" + + def test_scale_setter_updates_value(self): + model = Model() + model.scale = 3.0 + assert model.scale.value == 3.0 + + def test_scale_getter_returns_parameter(self): + model = Model(scale=2.5) + from easyscience.variable import Parameter + + assert isinstance(model.scale, Parameter) + assert model.scale.value == 2.5 + + def test_background_setter_updates_value(self): + model = Model() + model.background = 1e-6 + assert model.background.value == 1e-6 + + def test_background_getter_returns_parameter(self): + model = Model(background=5e-6) + from easyscience.variable import Parameter + + assert isinstance(model.background, Parameter) + assert model.background.value == 5e-6 + + def test_sample_setter(self): + model = Model() + new_sample = Sample(name='NewSample') + model.sample = new_sample + assert model.sample.name == 'NewSample' + + def test_to_dict_includes_sample_and_resolution(self): + model = Model() + d = model.to_dict() + assert 'sample' in d + assert 'resolution_function' in d + assert 'interface' in d # interface is None, encoded as None + assert 'name' in d + + def test_to_dict_with_interface_name(self): + interface = CalculatorFactory() + model = Model(interface=interface) + d = model.to_dict() + assert d['interface'] == 'refnx' + + def test_to_dict_excludes_derived_fields(self): + model = Model() + d = model.to_dict() + # sample, resolution_function, interface are handled separately + assert 'sample' in d + # The super().to_dict() skip prevents these from being top-level + assert 'resolution_function' in d + assert 'interface' in d + + def test_as_dict_alias(self): + model = Model() + assert model.as_dict() == model.to_dict() + + def test_is_default_property(self): + model = Model() + assert model.is_default is False + model.is_default = True + assert model.is_default is True + + +class TestModelRoundTrip: + """Tests verifying serialization round-trip for the Model class.""" + + def test_basic_round_trip_preserves_name(self): + global_object.map._clear() + model = Model(name='MyModel') + d = model.as_dict() + global_object.map._clear() + restored = Model.from_dict(d) + assert restored.name == 'MyModel' + + def test_round_trip_preserves_scale_and_background(self): + global_object.map._clear() + model = Model(scale=2.0, background=1e-7) + d = model.as_dict() + global_object.map._clear() + restored = Model.from_dict(d) + assert restored.scale.value == 2.0 + assert restored.background.value == 1e-7 + + def test_round_trip_preserves_resolution_function(self): + global_object.map._clear() + model = Model(resolution_function=PercentageFwhm(3.0)) + d = model.as_dict() + global_object.map._clear() + restored = Model.from_dict(d) + assert restored._resolution_function.smearing(100) == 3.0 + + def test_round_trip_preserves_interface(self): + global_object.map._clear() + interface = CalculatorFactory() + model = Model(interface=interface) + d = model.as_dict() + global_object.map._clear() + restored = Model.from_dict(d) + assert restored.interface().name == 'refnx' + + def test_round_trip_preserves_is_default(self): + global_object.map._clear() + model = Model() + model.is_default = True + d = model.as_dict() + global_object.map._clear() + restored = Model.from_dict(d) + # Note: is_default is a runtime flag that may not survive round-trip + # because from_dict reconstructs via __init__ which resets _is_default. + # This test documents the current behaviour. + assert restored.is_default is False diff --git a/tests/model/test_model_collection.py b/tests/model/test_model_collection.py index 68c47e17..dda554b1 100644 --- a/tests/model/test_model_collection.py +++ b/tests/model/test_model_collection.py @@ -163,3 +163,42 @@ def test_legacy_from_dict_sets_color_index(self): restored.add_model() assert [model.color for model in restored] == [COLORS[0], COLORS[1]] + + def test_next_color_index_property(self): + """next_color_index should be accessible as a property for serialization.""" + collection = ModelCollection(populate_if_none=False) + collection.add_model() + idx = collection.next_color_index + assert isinstance(idx, int) + assert idx >= 0 + + def test_next_color_index_none_when_no_colors(self): + """When COLORS is empty, next_color_index returns 0.""" + # We can test the None case when COLORS has entries, it wraps + collection = ModelCollection(populate_if_none=False) + # Without adding models, the index should still be accessible + assert collection.next_color_index is not None + + def test_from_dict_preserves_data_count(self): + """from_dict should reconstruct the exact number of models.""" + global_object.map._clear() + model_1 = Model(name='M1') + model_2 = Model(name='M2') + p = ModelCollection(model_1, model_2) + d = p.as_dict() + global_object.map._clear() + q = ModelCollection.from_dict(d) + assert len(q) == 2 + + def test_from_dict_with_extra_data_entries(self): + """from_dict should handle data entries correctly.""" + global_object.map._clear() + model_1 = Model(name='M1') + model_2 = Model(name='M2') + p = ModelCollection(model_1, model_2) + d = p.as_dict() + global_object.map._clear() + q = ModelCollection.from_dict(d) + assert len(q) == 2 + assert q[0].name == 'M1' + assert q[1].name == 'M2' diff --git a/tests/sample/assemblies/test_base_assembly.py b/tests/sample/assemblies/test_base_assembly.py index 482cf669..84228eca 100644 --- a/tests/sample/assemblies/test_base_assembly.py +++ b/tests/sample/assemblies/test_base_assembly.py @@ -182,3 +182,18 @@ def test_set_back_layer_with_front(self, base_assembly: BaseAssembly) -> None: # Expect assert base_assembly.layers == [self.mock_layer_0, self.mock_layer_1] + + def test_layers_setter(self) -> None: + """The layers property setter should replace the layer list.""" + global_object.map._clear() + BaseAssembly.__abstractmethods__ = set() + assembly = BaseAssembly( + name='test', + type='type', + interface=MagicMock(), + layers=[MagicMock(), MagicMock()], + ) + new_layers = [MagicMock(), MagicMock(), MagicMock()] + assembly.layers = new_layers + assert assembly.layers == new_layers + assert len(assembly.layers) == 3 diff --git a/tests/sample/collections/test_base_collection.py b/tests/sample/collections/test_base_collection.py index 50489240..c3036bc9 100644 --- a/tests/sample/collections/test_base_collection.py +++ b/tests/sample/collections/test_base_collection.py @@ -3,6 +3,8 @@ from unittest.mock import MagicMock +import pytest + from easyreflectometry.sample.collections.base_collection import BaseCollection from easyreflectometry.sample.elements.layers.layer import Layer @@ -167,3 +169,178 @@ def test_remove(self): assert p[0].name == 'layer_1' assert p[1].name == 'layer_3' assert p[2].name == 'layer_4' + + # ---- new BaseCollection (EasyList-based) specific tests ---- + + def test_name_getter_and_setter(self): + """name property should be readable and writable.""" + p = BaseCollection('original', MagicMock()) + assert p.name == 'original' + p.name = 'changed' + assert p.name == 'changed' + + def test_data_property(self): + """data property should return a read-only copy of the internal list.""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + data = p.data + assert len(data) == 1 + assert data[0].name == 'layer' + # Mutating the returned copy must not affect the collection + data.append(Layer(name='extra')) + assert len(p) == 1 + + def test_interface_propagates_to_existing_items(self): + """Setting interface after construction should propagate to all items.""" + mock_iface = MagicMock() + elem = Layer(name='layer') + # Pass interface=None explicitly and items as positional args + p = BaseCollection('name', None, elem) + assert p.interface is None + p.interface = mock_iface + # The interface setter propagates to items then calls generate_bindings on the mock + assert elem.interface is mock_iface + mock_iface.generate_bindings.assert_called() + + def test_interface_propagates_to_inserted_items(self): + """Items inserted after interface is set should receive the interface.""" + mock_iface = MagicMock() + p = BaseCollection('name', mock_iface) + elem = Layer(name='new_layer') + p.append(elem) + assert elem.interface is mock_iface + + def test_get_all_variables(self): + """get_all_variables should collect parameters from all items.""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + variables = p.get_all_variables() + # A Layer has thickness, roughness, and the material's sld/isld + names = {v.name for v in variables if hasattr(v, 'name')} + assert 'thickness' in names + assert 'roughness' in names + + def test_get_all_parameters(self): + """get_all_parameters should filter to only Parameter instances.""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + params = p.get_all_parameters() + for param in params: + assert param.__class__.__name__ == 'Parameter' + + def test_get_free_parameters(self): + """get_free_parameters should return only independent, non-fixed parameters.""" + elem = Layer(name='layer') + # By default thickness/roughness are fixed + p = BaseCollection('name', MagicMock(), elem) + free = p.get_free_parameters() + # By default all params are fixed, so empty + assert len(free) == 0 + # Unfix one + elem.thickness.fixed = False + free = p.get_free_parameters() + assert len(free) == 1 + assert free[0].name == 'thickness' + + def test_get_fit_parameters_alias(self): + """get_fit_parameters should be an alias for get_free_parameters.""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + assert p.get_fit_parameters() == p.get_free_parameters() + + def test_get_parameters_shim(self): + """get_parameters should be a compatibility alias for get_all_parameters.""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + assert p.get_parameters() == p.get_all_parameters() + + def test_get_linkable_attributes(self): + """_get_linkable_attributes should return get_all_variables.""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + assert p._get_linkable_attributes() == p.get_all_variables() + + def test_to_dict_includes_data_and_name(self): + """to_dict should serialize data items and collection metadata.""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + d = p.to_dict() + assert d['name'] == 'name' + assert len(d['data']) == 1 + assert d['data'][0]['name'] == 'layer' + + def test_to_dict_skips_interface(self): + """to_dict should exclude the interface field.""" + mock_iface = MagicMock() + p = BaseCollection('name', mock_iface) + d = p.to_dict() + assert 'interface' not in d + + def test_to_dict_skips_unique_name_by_default(self): + """to_dict should drop unique_name (matching legacy behaviour).""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + d = p.to_dict() + assert 'unique_name' not in d + + def test_as_dict_is_alias_for_to_dict(self): + """as_dict should delegate to to_dict.""" + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + assert p.as_dict() == p.to_dict() + + def test_deepcopy_round_trips(self): + """__deepcopy__ should produce an equivalent collection via from_dict.""" + import copy + + elem = Layer(name='layer') + # Use a concrete subclass (LayerCollection) that properly supports deepcopy + from easyreflectometry.sample.collections.layer_collection import LayerCollection + + p = LayerCollection(elem, name='test_layers') + p_copy = copy.deepcopy(p) + assert len(p_copy) == len(p) + assert p_copy[0].name == p[0].name + + def test_repr_handles_exception_gracefully(self): + """__repr__ should not crash even with items lacking _dict_repr.""" + mock_item = MagicMock() + # Deliberately make _dict_repr raise + del mock_item._dict_repr + p = BaseCollection('name', interface=None) + # Manually insert the mock item bypassing normal insert + p._data.append(mock_item) + # Should not raise + result = repr(p) + assert isinstance(result, str) + + def test_insert_rejects_non_integer_index(self): + """insert should raise TypeError for non-integer indices.""" + p = BaseCollection('name', interface=None) + with pytest.raises(TypeError, match='Index must be an integer'): + p.insert('not_an_int', Layer(name='x')) + + def test_duplicate_insert_is_warned(self): + """Inserting an already-present item should warn and skip.""" + import warnings + + elem = Layer(name='layer') + p = BaseCollection('name', MagicMock(), elem) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + p.append(elem) + assert len(w) == 1 + assert 'already in collection' in str(w[0].message) + # Length unchanged + assert len(p) == 1 + + def test_get_key_uses_name(self): + """_get_key should use the item's name property.""" + elem = Layer(name='mylayer') + p = BaseCollection('name', MagicMock(), elem) + assert p._get_key(elem) == 'mylayer' + + def test_has_interface_setter(self): + """_has_interface_setter should correctly detect interface-writable types.""" + assert BaseCollection._has_interface_setter(Layer) is True + assert BaseCollection._has_interface_setter(int) is False diff --git a/tests/sample/test_base_core.py b/tests/sample/test_base_core.py new file mode 100644 index 00000000..e9b32df7 --- /dev/null +++ b/tests/sample/test_base_core.py @@ -0,0 +1,284 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for BaseCore class — the new ModelBase-based foundation for sample-tree objects.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from easyscience import global_object +from easyscience.variable import Parameter + +from easyreflectometry.sample.base_core import BaseCore + +# --------------------------------------------------------------------------- +# Minimal concrete subclass for testing the abstract BaseCore +# --------------------------------------------------------------------------- + + +class _ConcreteCore(BaseCore): + """A non-abstract BaseCore that exposes a simple ``_dict_repr``.""" + + def __init__(self, name='TestCore', interface=None, unique_name=None, **kwargs): + super().__init__(name=name, interface=interface, unique_name=unique_name, **kwargs) + + @property + def _dict_repr(self) -> dict[str, str]: + return {self.name: {'type': 'concrete'}} + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestBaseCore: + """Direct unit tests for the BaseCore abstract base class.""" + + # ---- construction ---- + + def test_default_construction(self) -> None: + """A minimal concrete subclass should construct without errors.""" + obj = _ConcreteCore(name='Test') + assert obj.name == 'Test' + assert obj.interface is None + assert obj.user_data == {} + + def test_construction_with_interface(self) -> None: + """Passing an interface should trigger generate_bindings.""" + mock_iface = MagicMock() + obj = _ConcreteCore(name='WithIface', interface=mock_iface) + assert obj.interface is mock_iface + mock_iface.generate_bindings.assert_called_once_with(obj) + + def test_construction_with_unique_name(self) -> None: + """unique_name is passed through to ModelBase.""" + obj = _ConcreteCore(name='Uniq', unique_name='my_unique') + assert obj.unique_name == 'my_unique' + + def test_construction_kwargs_stored_as_attributes(self) -> None: + """Transitional kwargs path: extra kwargs become plain instance attrs.""" + child = Parameter('extra_param', 5.0) + obj = _ConcreteCore(name='Kwargs', extra=child, extra2=42) + assert obj.extra is child + assert obj.extra2 == 42 + + # ---- name property ---- + + def test_name_getter(self) -> None: + obj = _ConcreteCore(name='MyName') + assert obj.name == 'MyName' + + def test_name_setter(self) -> None: + obj = _ConcreteCore(name='Original') + obj.name = 'Changed' + assert obj.name == 'Changed' + + # ---- interface property ---- + + def test_interface_set_to_none(self) -> None: + obj = _ConcreteCore(name='NoIface') + obj.interface = None + assert obj.interface is None + + def test_interface_set_triggers_bindings(self) -> None: + obj = _ConcreteCore(name='Late') + mock_iface = MagicMock() + obj.interface = mock_iface + mock_iface.generate_bindings.assert_called_once_with(obj) + + def test_interface_setter_does_not_call_generate_bindings_for_none(self) -> None: + obj = _ConcreteCore(name='NoneIface') + # Setting to None should be safe (no generate_bindings call) + obj.interface = None + assert obj.interface is None + + # ---- generate_bindings ---- + + def test_generate_bindings_raises_when_interface_is_none(self) -> None: + obj = _ConcreteCore(name='NoIface') + with pytest.raises(AttributeError, match='Interface error'): + obj.generate_bindings() + + def test_generate_bindings_propagates_to_children(self) -> None: + """Children with an interface setter receive the parent's interface.""" + mock_iface = MagicMock() + child = _ConcreteCore(name='Child') + child._interface = None # reset so we can observe propagation + obj = _ConcreteCore(name='Parent', child=child) + obj.interface = mock_iface + # The child should have received the interface too. + assert child.interface is mock_iface + + def test_generate_bindings_propagates_to_parameter_children(self) -> None: + """Parameters stored as plain attrs should not break binding propagation.""" + mock_iface = MagicMock() + param = Parameter('p', 1.0) + obj = _ConcreteCore(name='WithParam', p=param) + obj.interface = mock_iface + mock_iface.generate_bindings.assert_called_once_with(obj) + + # ---- _iter_public_children ---- + + def test_iter_public_children_includes_class_attrs(self) -> None: + child = _ConcreteCore(name='Child') + obj = _ConcreteCore(name='Parent', child=child) + children = list(obj._iter_public_children()) + assert child in children + + def test_iter_public_children_includes_instance_attrs(self) -> None: + param = Parameter('p', 1.0) + obj = _ConcreteCore(name='Parent', p=param) + children = list(obj._iter_public_children()) + assert param in children + + def test_iter_public_children_excludes_private(self) -> None: + obj = _ConcreteCore(name='Parent') + obj._private_thing = 'secret' + children = list(obj._iter_public_children()) + names = [getattr(c, 'name', c) for c in children] + assert 'secret' not in names + + def test_iter_public_children_excludes_interface_and_name(self) -> None: + obj = _ConcreteCore(name='Parent') + children = list(obj._iter_public_children()) + assert obj.interface not in children + + def test_iter_public_children_no_duplicates(self) -> None: + """If a child appears both as a class attr and instance attr, only one copy.""" + child = _ConcreteCore(name='Child') + obj = _ConcreteCore(name='Parent', child=child) + # Also set as attr with same id + obj.duplicate_ref = child + children = list(obj._iter_public_children()) + # child should appear only once + assert children.count(child) == 1 + + # ---- _has_interface_setter ---- + + def test_has_interface_setter_true(self) -> None: + assert BaseCore._has_interface_setter(_ConcreteCore) is True + + def test_has_interface_setter_false_for_bare_object(self) -> None: + assert BaseCore._has_interface_setter(object) is False + + def test_has_interface_setter_false_for_parameter(self) -> None: + """Parameter doesn't have an interface property.""" + assert BaseCore._has_interface_setter(Parameter) is False + + # ---- compatibility shims ---- + + def test_get_linkable_attributes(self) -> None: + param = Parameter('p', 1.0) + obj = _ConcreteCore(name='Core', p=param) + result = obj._get_linkable_attributes() + assert param in result + + def test_get_parameters_shim(self) -> None: + param = Parameter('p', 1.0) + obj = _ConcreteCore(name='Core', p=param) + result = obj.get_parameters() + assert param in result + + def test_add_component(self) -> None: + obj = _ConcreteCore(name='Core') + comp = Parameter('comp', 42.0) + obj._add_component('my_comp', comp) + assert obj.my_comp is comp + + # ---- get_all_variables ---- + + def test_get_all_variables_includes_descriptors(self) -> None: + param = Parameter('p', 1.0) + obj = _ConcreteCore(name='Core', p=param) + result = obj.get_all_variables() + assert param in result + + def test_get_all_variables_recurses_into_children(self) -> None: + inner_param = Parameter('inner', 2.0) + child = _ConcreteCore(name='Child', p=inner_param) + obj = _ConcreteCore(name='Parent', child=child) + result = obj.get_all_variables() + assert inner_param in result + + def test_get_all_variables_no_duplicates_across_children(self) -> None: + param = Parameter('shared', 1.0) + child_a = _ConcreteCore(name='A', p=param) + child_b = _ConcreteCore(name='B', p=param) + obj = _ConcreteCore(name='Parent', a=child_a, b=child_b) + result = obj.get_all_variables() + assert result.count(param) == 1 + + # ---- to_dict / as_dict ---- + + def test_to_dict_skips_interface(self) -> None: + mock_iface = MagicMock() + obj = _ConcreteCore(name='Core', interface=mock_iface) + d = obj.to_dict() + assert 'interface' not in d + + def test_to_dict_skips_unique_name_by_default(self) -> None: + obj = _ConcreteCore(name='Core', unique_name='my_unique') + d = obj.to_dict() + assert 'unique_name' not in d + + def test_to_dict_includes_name(self) -> None: + obj = _ConcreteCore(name='MyName') + d = obj.to_dict() + assert d.get('name') == 'MyName' + + def test_as_dict_is_alias_for_to_dict(self) -> None: + obj = _ConcreteCore(name='Core') + assert obj.as_dict() == obj.to_dict() + + def test_to_dict_respects_custom_skip(self) -> None: + obj = _ConcreteCore(name='Core') + d = obj.to_dict(skip=['name']) + assert 'name' not in d + + def test_to_dict_skip_not_mutated_by_callee(self) -> None: + """Caller's skip list must not be mutated.""" + obj = _ConcreteCore(name='Core') + skip = ['name'] + obj.to_dict(skip=skip) + assert skip == ['name'] # not appended-to + + # ---- repr ---- + + def test_repr_returns_yaml_string(self) -> None: + obj = _ConcreteCore(name='Test') + r = repr(obj) + assert 'Test' in r + assert 'concrete' in r + + # ---- user_data ---- + + def test_user_data_is_dict(self) -> None: + obj = _ConcreteCore(name='Core') + obj.user_data['key'] = 'value' + assert obj.user_data['key'] == 'value' + + # ---- round-trip ---- + + def test_basic_round_trip_via_material(self) -> None: + """Round-trip through a real subclass (Material) to verify BaseCore serialization.""" + from easyreflectometry.sample.elements.materials.material import Material + + global_object.map._clear() + obj = Material(sld=2.0, isld=0.5, name='TestMat') + d = obj.to_dict() + global_object.map._clear() + + restored = Material.from_dict(d) + assert restored.name == 'TestMat' + assert restored.sld.value == 2.0 + assert restored.isld.value == 0.5 + + def test_round_trip_skips_interface(self) -> None: + """Round-trip via to_dict → from_dict should strip the interface.""" + global_object.map._clear() + obj = _ConcreteCore(name='WithIface') + d = obj.to_dict() + assert 'interface' not in d diff --git a/tests/test_project.py b/tests/test_project.py index 96577d0c..ec85b3f2 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -419,6 +419,44 @@ def remove_interface(d): remove_interface(project_dict['models']) assert project_dict['models'] == models_dict + def test_from_dict_missing_file_format_raises(self): + """Loading a dict without file_format should raise ValueError.""" + project = Project() + bad_dict = {'info': {}, 'with_experiments': False, 'models': {'data': []}} + with pytest.raises(ValueError, match='predates file_format=2'): + project.from_dict(bad_dict) + + def test_from_dict_wrong_file_format_raises(self): + """Loading a dict with an unsupported file_format should raise ValueError.""" + project = Project() + bad_dict = { + 'file_format': 99, + 'info': {}, + 'with_experiments': False, + 'models': {'data': []}, + } + with pytest.raises(ValueError, match='Unsupported project file_format'): + project.from_dict(bad_dict) + + def test_from_dict_correct_file_format_succeeds(self): + """Loading a dict with the correct file_format should work.""" + global_object.map._clear() + # Build a valid project dict with at least one model + src_project = Project() + src_project._info['name'] = 'Test' + src_project._info['short_description'] = 'Desc' + src_project._info['modified'] = '01.01.2025 00:00' + src_project.default_model() # ensures at least one model exists + src_project._with_experiments = False + good_dict = src_project.as_dict() + global_object.map._clear() + + project = Project() + project.from_dict(good_dict) + assert project._info['name'] == 'Test' + assert project._with_experiments is False + assert len(project._models) >= 1 + def test_as_dict_materials_not_in_model(self): # When project = Project()