Skip to content

Commit 8e12050

Browse files
authored
Merge pull request #52 from MunchLab/smooth-ect
Adding the smooth ect
2 parents 5b128a2 + bf6da2d commit 8e12050

7 files changed

Lines changed: 156 additions & 7 deletions

File tree

doc_source/directions.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Directions
2+
3+
```{eval-rst}
4+
.. automodule:: ect.directions
5+
:members:
6+
```

doc_source/ect_on_graphs.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,8 @@
44
.. automodule:: ect.ect_graph
55
:members:
66
```
7+
8+
```{eval-rst}
9+
.. automodule:: ect.sect
10+
:members:
11+
```

doc_source/modules.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ Table of Contents
77

88
Embedded graphs <embed_graph.md>
99
Embedded CW complex <embed_cw.md>
10-
ECT on graphs <ect_on_graphs.md>
10+
ECT on graphs <ect_on_graphs.md>
11+
Directions <directions.md>

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ect"
3-
version = "1.0.1"
3+
version = "1.0.2"
44
authors = [
55
{ name="Liz Munch", email="muncheli@msu.edu" },
66
]

src/ect/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
from .embed_graph import EmbeddedGraph
1313
from .embed_cw import EmbeddedCW
1414
from .directions import Directions
15+
from .sect import SECT
1516
from .utils import examples
1617

1718
__all__ = [
18-
'ECT',
19-
'EmbeddedGraph',
20-
'EmbeddedCW',
21-
'Directions',
22-
'examples',
19+
"ECT",
20+
"SECT",
21+
"EmbeddedGraph",
22+
"EmbeddedCW",
23+
"Directions",
24+
"examples",
2325
]

src/ect/sect.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from ect import ECT
2+
from .embed_graph import EmbeddedGraph
3+
from .embed_cw import EmbeddedCW
4+
from .directions import Directions
5+
from .results import ECTResult
6+
from typing import Optional, Union
7+
import numpy as np
8+
9+
10+
class SECT(ECT):
11+
"""
12+
A class to calculate the Smooth Euler Characteristic Transform (SECT).
13+
Inherits from ECT and applies smoothing to the final result.
14+
"""
15+
16+
def __init__(
17+
self,
18+
directions: Optional[Directions] = None,
19+
num_dirs: Optional[int] = None,
20+
num_thresh: Optional[int] = None,
21+
bound_radius: Optional[float] = None,
22+
thresholds: Optional[np.ndarray] = None,
23+
dtype=np.float32,
24+
):
25+
"""Initialize SECT calculator with smoothing parameter
26+
27+
Args:
28+
directions: Optional pre-configured Directions object
29+
num_dirs: Number of directions to sample (ignored if directions provided)
30+
num_thresh: Number of threshold values (required if directions not provided)
31+
bound_radius: Optional radius for bounding circle
32+
thresholds: Optional array of thresholds
33+
dtype: Data type for output array
34+
"""
35+
super().__init__(
36+
directions, num_dirs, num_thresh, bound_radius, thresholds, dtype
37+
)
38+
39+
def calculate(
40+
self,
41+
graph: Union[EmbeddedGraph, EmbeddedCW],
42+
theta: Optional[float] = None,
43+
override_bound_radius: Optional[float] = None,
44+
) -> ECTResult:
45+
"""Calculate Smooth Euler Characteristic Transform (SECT)
46+
47+
Args:
48+
graph: The input graph to calculate the SECT for
49+
theta: The angle in [0,2π] for the direction to calculate the SECT
50+
override_bound_radius: Optional override for bounding radius
51+
52+
Returns:
53+
ECTResult: The smoothed transform result containing the matrix,
54+
directions, and thresholds
55+
"""
56+
ect_result = super().calculate(graph, theta, override_bound_radius)
57+
return ECTResult(
58+
ect_result, ect_result.directions, ect_result.thresholds
59+
).smooth()

tests/test_sect.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import unittest
2+
import numpy as np
3+
from ect import SECT, ECT
4+
from ect.utils.examples import create_example_graph
5+
from ect.directions import Directions
6+
7+
8+
class TestSECT(unittest.TestCase):
9+
def setUp(self):
10+
"""Set up test fixtures"""
11+
self.graph = create_example_graph()
12+
self.num_dirs = 8
13+
self.num_thresh = 10
14+
self.sect = SECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)
15+
16+
def test_inheritance(self):
17+
"""Test that SECT properly inherits from ECT"""
18+
self.assertIsInstance(self.sect, ECT)
19+
self.assertTrue(hasattr(self.sect, "calculate"))
20+
21+
def test_calculate_output_shape(self):
22+
"""Test that SECT calculation returns correct shape"""
23+
result = self.sect.calculate(self.graph)
24+
25+
self.assertEqual(result.shape[0], self.num_dirs)
26+
self.assertEqual(result.shape[1], self.num_thresh)
27+
self.assertEqual(len(result.thresholds), self.num_thresh)
28+
self.assertEqual(len(result.directions), self.num_dirs)
29+
30+
def test_smoothing_effect(self):
31+
"""Test that smoothing is actually applied"""
32+
# Calculate both ECT and SECT
33+
ect = ECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)
34+
sect = SECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)
35+
36+
ect_result = ect.calculate(self.graph)
37+
sect_result = sect.calculate(self.graph)
38+
39+
# Verify results are different due to smoothing
40+
self.assertFalse(np.allclose(ect_result, sect_result))
41+
42+
# Verify smoothing preserves direction count
43+
self.assertEqual(
44+
np.sum(ect_result, axis=1).shape,
45+
np.sum(sect_result, axis=1).shape,
46+
)
47+
48+
def test_with_theta(self):
49+
"""Test SECT calculation with specific theta value"""
50+
theta = np.pi / 4
51+
result = self.sect.calculate(self.graph, theta=theta)
52+
53+
# Should only have one direction when theta is specified
54+
self.assertEqual(result.shape[0], 1)
55+
self.assertEqual(result.shape[1], self.num_thresh)
56+
57+
def test_with_override_radius(self):
58+
"""Test SECT calculation with override_bound_radius"""
59+
override_radius = 2.0
60+
result = self.sect.calculate(self.graph, override_bound_radius=override_radius)
61+
62+
# Check that thresholds are within the override radius
63+
self.assertLessEqual(np.max(np.abs(result.thresholds)), override_radius)
64+
65+
def test_smooth_matrix_properties(self):
66+
"""Test properties of the smoothed matrix"""
67+
result = self.sect.calculate(self.graph)
68+
69+
# Smoothed values should be finite
70+
self.assertTrue(np.all(np.isfinite(result)))
71+
72+
# Shape should be preserved after smoothing
73+
self.assertEqual(result.shape, (self.num_dirs, self.num_thresh))
74+
75+
# Verify result is float type after smoothing
76+
self.assertTrue(np.issubdtype(result.dtype, np.floating))

0 commit comments

Comments
 (0)