Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion docs/docs/tutorials/components.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,30 @@
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be79f3fc",
"metadata": {},
"outputs": [],
"source": [
"expr = ExpressionComponent(\n",
" 'A*erf(B*x)',\n",
")\n",
"\n",
"expr.A = 1.0\n",
"expr.B = 0.5\n",
"\n",
"\n",
"x = np.linspace(-5, 5, 100)\n",
"y = expr.evaluate(x)\n",
"\n",
"plt.figure()\n",
"plt.plot(x, y, label='erf')\n",
"plt.legend()\n",
"plt.show()"
]
}
],
"metadata": {
Expand All @@ -168,7 +192,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.13"
"version": "3.14.4"
}
},
"nbformat": 4,
Expand Down
60 changes: 45 additions & 15 deletions src/easydynamics/sample_model/components/expression_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import sympy as sp
from easyscience.variable import Parameter
from scipy.special import erf

from easydynamics.sample_model.components.model_component import ModelComponent
from easydynamics.utils.utils import Numeric
Expand All @@ -20,11 +21,6 @@
class ExpressionComponent(ModelComponent):
"""
Model component defined by a symbolic expression.

Example: expr = ExpressionComponent( "A * exp(-(x - x0)**2 / (2*sigma**2))", parameters={"A":
10, "x0": 0, "sigma": 1}, )

expr.A = 5 y = expr.evaluate(x)
"""

# -------------------------
Expand Down Expand Up @@ -100,6 +96,18 @@ def __init__(
If the expression is invalid or does not contain 'x'.
TypeError
If any parameter value is not numeric.

Examples
--------
>>> expr = ExpressionComponent(
... 'A * exp(-(x - x0)**2 / (2*sigma**2))',
... parameters={'A': 10, 'x0': 0, 'sigma': 1},
... unit='meV',
... display_name='Gaussian Peak',
... )

>>> expr.A = 5
>>> y = expr.evaluate(x)
"""
super().__init__(unit=unit, display_name=display_name, unique_name=unique_name)

Expand Down Expand Up @@ -157,8 +165,11 @@ def __init__(

if parameters is not None:
for name, value in parameters.items():
if not isinstance(value, Numeric):
raise TypeError(f"Parameter '{name}' must be numeric")
if not isinstance(value, (Numeric, Parameter, dict)):
raise TypeError(
f"Parameter '{name}' must be numeric, "
f'a Parameter instance, or a dictionary, got {type(value).__name__}'
)
parameters = parameters or {}
self._parameters: dict[str, Parameter] = {}

Expand All @@ -168,20 +179,25 @@ def __init__(
continue

value = parameters.get(name, 1.0)
if isinstance(value, Parameter):
self._parameters[name] = value

self._parameters[name] = Parameter(
name=name,
value=value,
unit=self._unit,
)
elif isinstance(value, dict) and value.get('@class') == 'Parameter':
self._parameters[name] = Parameter.from_dict(value)
Comment on lines +185 to +186
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part is missing test coverage and is important (deserialization from external data)

Copy link
Copy Markdown
Member Author

@henrikjacobsenfys henrikjacobsenfys May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not covered by the test of copy? Copy uses the to and from dict methods.

I added this line specifically because copy fails, which it does because I take parameters as a dict.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, this is what copy does

    def __copy__(self) -> NewBase:
        """Return a copy of the object."""
        temp = self.to_dict(skip=['unique_name'])
        new_obj = self.__class__.from_dict(temp)
        return new_obj

else:
self._parameters[name] = Parameter(
name=name,
value=value,
unit=self._unit,
)

# Create numerical function
ordered_symbols = [sp.Symbol(name) for name in self._symbol_names]

self._func = sp.lambdify(
ordered_symbols,
self._expr,
modules=['numpy'],
modules=[{'erf': erf}, 'numpy'],
)

# -------------------------
Expand All @@ -190,7 +206,14 @@ def __init__(

@property
def expression(self) -> str:
"""Return the original expression string."""
"""
Return the original expression string.

Returns
-------
str
The original expression string provided at initialization.
"""
return self._expression_str

@expression.setter
Expand Down Expand Up @@ -334,7 +357,14 @@ def __dir__(self) -> list[str]:
return super().__dir__() + list(self._parameters.keys())

def __repr__(self) -> str:
"""Repr function."""
"""
Return a string representation of the ExpressionComponent.

Returns
-------
str
String representation of the ExpressionComponent.
"""
param_str = ', '.join(f'{k}={v.value}' for k, v in self._parameters.items())
return (
f'{self.__class__.__name__}(\n'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors
# SPDX-License-Identifier: BSD-3-Clause

from copy import copy

import numpy as np
import pytest
from easyscience.variable import Parameter
Expand Down Expand Up @@ -34,6 +36,14 @@ def test_init_without_parameters(self):
# EXPECT
assert expr.A.value == pytest.approx(1.0) # default

def test_init_with_parameter(self):
# WHEN THEN
A = Parameter('A', 3.0)
expr = ExpressionComponent('A * x', parameters={'A': A})

# EXPECT
assert expr.A.value == pytest.approx(3.0)

def test_invalid_expression_raises(self):
# WHEN THEN EXPECT
with pytest.raises(ValueError, match='Invalid expression'):
Expand Down Expand Up @@ -172,3 +182,30 @@ def test_reserved_name_not_parameter(self):

assert 'A' in names
assert 'x' not in names # x is reserved

def test_copy(self, expr: ExpressionComponent):
# WHEN THEN
expr_copy = copy(expr)

# EXPECT the copy is a new instance with the same properties
assert expr_copy is not expr
assert isinstance(expr_copy, ExpressionComponent)
assert expr_copy.expression == expr.expression
assert expr_copy.unit == expr.unit
assert expr_copy.display_name == expr.display_name

assert expr_copy.A.value == pytest.approx(expr.A.value)
assert expr_copy.x0.value == pytest.approx(expr.x0.value)
assert expr_copy.sigma.value == pytest.approx(expr.sigma.value)

def test_erf(self):
# WHEN
expr = ExpressionComponent('erf(x)')
x = np.array([-1.0, 0.0, 1.0])

# THEN
result = expr.evaluate(x)

# EXPECT
expected = np.array([-0.84270079, 0.0, 0.84270079]) # erf(-1), erf(0), erf(1)
np.testing.assert_allclose(result, expected, rtol=1e-5)
Loading