From 986fb2a600f92ef4a0c73df11f329fc6a1a76681 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 20 May 2026 11:09:19 +0200 Subject: [PATCH 01/15] updated the signature of RNGHierarchy class to accept MPI communicator as an argument --- litebird_sim/seeding.py | 26 ++++++++++++++------------ litebird_sim/simulations.py | 2 +- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/litebird_sim/seeding.py b/litebird_sim/seeding.py index 52a7f4254..d839b9226 100644 --- a/litebird_sim/seeding.py +++ b/litebird_sim/seeding.py @@ -6,6 +6,10 @@ import numpy as np from numpy.random import PCG64, Generator, SeedSequence +from mpi4py.MPI import Comm + +from .mpi import MPI_COMM_WORLD, _SerialMpiCommunicator + from .observations import Observation @@ -177,16 +181,14 @@ def regenerate_or_check_detector_generators( If the number of generators does not match the number of detectors. """ comm = observations[0].comm - if comm is not None: - rank = comm.rank - comm_size = comm.size - else: - rank = 0 - comm_size = 1 + if comm is None: + comm = _SerialMpiCommunicator() if user_seed is not None: - RNG_hierarchy = RNGHierarchy(user_seed, comm_size, observations[0].n_detectors) - dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(rank=rank) + RNG_hierarchy = RNGHierarchy(user_seed, comm, observations[0].n_detectors) + dets_random = RNG_hierarchy.get_detector_level_generators_on_rank( + rank=comm.rank + ) if user_seed is None and dets_random is None: raise ValueError("You should pass either `user_seed` or `dets_random`.") assert dets_random is not None # Guaranteed by the check above @@ -204,11 +206,11 @@ class RNGHierarchy: def __init__( self, base_seed: int, - comm_size: int | None = None, + comm: Comm = MPI_COMM_WORLD, num_detectors_per_rank: int | None = None, ): self.base_seed = base_seed - self.comm_size = None + self.comm_size = comm.size self.num_detectors_per_rank = None # self.root_seq = SeedSequence(base_seed) self.metadata = { @@ -216,8 +218,8 @@ def __init__( "save_format_version": self.SAVE_FORMAT_VERSION, } self.hierarchy = {} - if comm_size is not None: - self.build_mpi_layer(comm_size) + if self.comm_size is not None: + self.build_mpi_layer(self.comm_size) if num_detectors_per_rank is not None: self.build_detector_layer(num_detectors_per_rank) diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index 5ebada2da..f2547322f 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -534,7 +534,7 @@ def init_rng_hierarchy(self, random_seed): the user-provided `random_seed`, but can be invoked again to reseed the simulation at any point. """ - self.rng_hierarchy = RNGHierarchy(random_seed) + self.rng_hierarchy = RNGHierarchy(random_seed, self.mpi_comm) self.rng_hierarchy.build_mpi_layer(self.mpi_comm.size) self.random = self.rng_hierarchy.get_generator(self.mpi_comm.rank) From 1c38bec0f2ea42e6cc74334a372f81dafaa1faa3 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 20 May 2026 11:10:31 +0200 Subject: [PATCH 02/15] fixed the tests to use updated RNGHierarchy class --- test/test_gaindrifts.py | 2 +- test/test_nonlinearity.py | 6 +++--- test/test_seeding.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/test_gaindrifts.py b/test/test_gaindrifts.py index 4eaae8c1b..70fa0ba03 100644 --- a/test/test_gaindrifts.py +++ b/test/test_gaindrifts.py @@ -5,7 +5,7 @@ def get_RNG_hierarchy(seed, num_of_dets): rng_hierarchy = lbs.RNGHierarchy( - seed, comm_size=1, num_detectors_per_rank=num_of_dets + seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=num_of_dets ) return rng_hierarchy diff --git a/test/test_nonlinearity.py b/test/test_nonlinearity.py index 86f4c226a..12f818958 100644 --- a/test/test_nonlinearity.py +++ b/test/test_nonlinearity.py @@ -43,7 +43,7 @@ def test_add_quadratic_nonlinearity(): # Applying non-linearity on the given TOD component of an `Observation` object RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) + random_seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=len(dets) ) dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) lbs.apply_quadratic_nonlin_to_observations( @@ -55,7 +55,7 @@ def test_add_quadratic_nonlinearity(): # Applying non-linearity on the TOD arrays of the individual detectors. RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) + random_seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=len(dets) ) dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) for idx, tod in enumerate(sim.observations[0].nl_2_det): @@ -76,7 +76,7 @@ def test_add_quadratic_nonlinearity(): # Check if non-linearity is applied correctly RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) + random_seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=len(dets) ) dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) sim.observations[0].tod_origin = np.ones_like(sim.observations[0].tod) diff --git a/test/test_seeding.py b/test/test_seeding.py index 0456bf472..e80acbfcb 100644 --- a/test/test_seeding.py +++ b/test/test_seeding.py @@ -306,7 +306,9 @@ def test_detector_generators_regeneration(tmp_path): ) rng_hierarchy = lbs.RNGHierarchy( - 987654321, comm_size=1, num_detectors_per_rank=sim.observations[0].n_detectors + 987654321, + comm=lbs.MPI_COMM_WORLD, + num_detectors_per_rank=sim.observations[0].n_detectors, ) fresh_dets_random = rng_hierarchy.get_detector_level_generators_on_rank(0) From bc966de17cc995596dd1ae8e0c3382fa1bb22628 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 20 May 2026 11:35:24 +0200 Subject: [PATCH 03/15] avoid importing mpi4py when it is not installed --- litebird_sim/seeding.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/litebird_sim/seeding.py b/litebird_sim/seeding.py index d839b9226..64ecc7d2d 100644 --- a/litebird_sim/seeding.py +++ b/litebird_sim/seeding.py @@ -6,8 +6,6 @@ import numpy as np from numpy.random import PCG64, Generator, SeedSequence -from mpi4py.MPI import Comm - from .mpi import MPI_COMM_WORLD, _SerialMpiCommunicator from .observations import Observation @@ -206,7 +204,7 @@ class RNGHierarchy: def __init__( self, base_seed: int, - comm: Comm = MPI_COMM_WORLD, + comm=MPI_COMM_WORLD, num_detectors_per_rank: int | None = None, ): self.base_seed = base_seed From ca4819455856cd6a7160ebdd125904f8924f667b Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 27 May 2026 16:24:51 +0200 Subject: [PATCH 04/15] updated regenerate_or_check_detector_generators() so that it accepts an MPI communicator as an argument --- litebird_sim/seeding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litebird_sim/seeding.py b/litebird_sim/seeding.py index 64ecc7d2d..5f97409e1 100644 --- a/litebird_sim/seeding.py +++ b/litebird_sim/seeding.py @@ -147,6 +147,7 @@ def get_detector_level_generators_from_hierarchy( def regenerate_or_check_detector_generators( observations: list[Observation], + comm=None, user_seed: int | None = None, dets_random: list[Generator] | None = None, ) -> list[Generator]: @@ -178,7 +179,6 @@ def regenerate_or_check_detector_generators( AssertionError If the number of generators does not match the number of detectors. """ - comm = observations[0].comm if comm is None: comm = _SerialMpiCommunicator() From cbec4ad4aa518be5b012d0cb69ca1f140ec32758 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 27 May 2026 16:25:52 +0200 Subject: [PATCH 05/15] updated the calls to regenerate_or_check_detector_generators() to supply suitable comm argument --- litebird_sim/gaindrifts.py | 1 + litebird_sim/noise.py | 1 + litebird_sim/non_linearity.py | 1 + test/test_seeding.py | 3 +++ 4 files changed, 6 insertions(+) diff --git a/litebird_sim/gaindrifts.py b/litebird_sim/gaindrifts.py index c68f47e3a..85047b342 100644 --- a/litebird_sim/gaindrifts.py +++ b/litebird_sim/gaindrifts.py @@ -558,6 +558,7 @@ def apply_gaindrift_to_observations( ) dets_random = regenerate_or_check_detector_generators( observations=obs_list, + comm=None, user_seed=user_seed, dets_random=dets_random, ) diff --git a/litebird_sim/noise.py b/litebird_sim/noise.py index 8deaf97ba..9dd0e9722 100644 --- a/litebird_sim/noise.py +++ b/litebird_sim/noise.py @@ -401,6 +401,7 @@ def add_noise_to_observations( dets_random = regenerate_or_check_detector_generators( observations=obs_list, + comm=None, user_seed=user_seed, dets_random=dets_random, ) diff --git a/litebird_sim/non_linearity.py b/litebird_sim/non_linearity.py index dbea2fd14..b434d8624 100644 --- a/litebird_sim/non_linearity.py +++ b/litebird_sim/non_linearity.py @@ -185,6 +185,7 @@ def apply_quadratic_nonlin_to_observations( ) dets_random = regenerate_or_check_detector_generators( observations=obs_list, + comm=obs_list[0].comm_time_block, user_seed=user_seed, dets_random=dets_random, ) diff --git a/test/test_seeding.py b/test/test_seeding.py index e80acbfcb..2773fbc13 100644 --- a/test/test_seeding.py +++ b/test/test_seeding.py @@ -289,18 +289,21 @@ def test_detector_generators_regeneration(tmp_path): with pytest.raises(ValueError): _ = lbs.regenerate_or_check_detector_generators( observations=sim.observations, + comm=None, user_seed=None, dets_random=None, ) with pytest.raises(AssertionError): _ = lbs.regenerate_or_check_detector_generators( observations=sim.observations, + comm=None, user_seed=None, dets_random=[sim.dets_random[0], sim.dets_random[1]], ) # `user_seed` takes priority over the generators of the `Simulation`` regenerated_dets_random = lbs.regenerate_or_check_detector_generators( observations=sim.observations, + comm=None, user_seed=987654321, dets_random=sim.dets_random, ) From a871519b2458efd6e3a8cad9201279f269f92f22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Thu, 28 May 2026 20:54:33 +0200 Subject: [PATCH 06/15] make user seed mandatory in sim.apply_quadratic_nonlin, remove default rng hierarchy, and add test for this feature in test_mpi.py --- litebird_sim/simulations.py | 25 +++++++-------- test/test_mpi.py | 62 +++++++++++++++++++++++++++++++++++-- test/test_nonlinearity.py | 4 ++- 3 files changed, 75 insertions(+), 16 deletions(-) diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index f2547322f..8665b034c 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -33,15 +33,17 @@ add_convolved_sky_to_observations, ) from .beam_synthesis import generate_gauss_beam_alms +from .constants import NUMBA_NUM_THREADS_ENVVAR from .coordinates import CoordinateSystem -from .detectors import DetectorInfo, FreqChannelInfo, InstrumentInfo, UUID +from .detectors import UUID, DetectorInfo, FreqChannelInfo, InstrumentInfo from .dipole import DipoleType, add_dipole_to_observations from .distribute import distribute_evenly, distribute_optimally -from .gaindrifts import GainDriftType, GainDriftParams, apply_gaindrift_to_observations +from .gaindrifts import GainDriftParams, GainDriftType, apply_gaindrift_to_observations from .healpix import write_healpix_map_to_file from .hwp import HWP from .hwp_diff_emiss import add_2f_to_observations from .imo.imo import Imo +from .input_sky import SkyGenerationParams, SkyGenerator from .io import read_list_of_observations, write_list_of_observations from .mapmaking import ( BinnerResult, @@ -56,8 +58,8 @@ make_pair_differenced_map, save_destriper_results, ) -from .input_sky import SkyGenerator, SkyGenerationParams -from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID +from .maps_and_harmonics import HealpixMap, SphericalHarmonics +from .mpi import MPI_COMM_GRID, MPI_COMM_WORLD, MPI_ENABLED from .noise import add_noise_to_observations from .non_linearity import NonLinParams, apply_quadratic_nonlin_to_observations from .observations import Observation, TodDescription @@ -67,13 +69,13 @@ from .scanning import ScanningStrategy, SpinningScanningStrategy from .seeding import RNGHierarchy from .spacecraft import SpacecraftOrbit, spacecraft_pos_and_vel -from .maps_and_harmonics import SphericalHarmonics, HealpixMap from .units import Units from .version import ( - __version__ as litebird_sim_version, __author__ as litebird_sim_author, ) -from .constants import NUMBA_NUM_THREADS_ENVVAR +from .version import ( + __version__ as litebird_sim_version, +) DEFAULT_BASE_IMO_URL = "https://litebirdimo.ssdc.asi.it" @@ -2080,7 +2082,6 @@ def apply_quadratic_nonlin( user_seed: int | None = None, component: str = "tod", append_to_report: bool = False, - rng_hierarchy: RNGHierarchy | None = None, ): """A method to apply non-linearity to the observation. @@ -2089,10 +2090,9 @@ def apply_quadratic_nonlin( applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained from the detector-level layer. As default it uses the `dets_random` field of a :class:`.Simulation` object for this. """ - if rng_hierarchy is None: - rng_hierarchy = self.rng_hierarchy - dets_random = rng_hierarchy.get_detector_level_generators_on_rank( - self.mpi_comm.rank + + assert user_seed is not None, ( + "user_seed must be given in apply_quadratic_nonlin." ) if nl_params is None: nl_params = NonLinParams() @@ -2102,7 +2102,6 @@ def apply_quadratic_nonlin( nl_params=nl_params, user_seed=user_seed, component=component, - dets_random=dets_random, ) if append_to_report and MPI_COMM_WORLD.rank == 0: diff --git a/test/test_mpi.py b/test/test_mpi.py index cd5cfdcf9..af8be3f89 100644 --- a/test/test_mpi.py +++ b/test/test_mpi.py @@ -6,9 +6,8 @@ from typing import Callable import astropy.time as astrotime -import numpy as np - import litebird_sim as lbs +import numpy as np from litebird_sim import MPI_COMM_WORLD @@ -703,3 +702,62 @@ def test_nullify_mpi(tmp_path): if MPI_COMM_WORLD.rank == 0: print("Running test function {}".format(str(cur_test_fn)), file=stderr) __run_test_in_same_folder(cur_test_fn) + + +def test_non_linearity_seeding(tmp_path): + """Check if the seed for each detector is consistent even when + different MPI tasks share the same detector + """ + + if not lbs.MPI_ENABLED: + return + + rank = lbs.MPI_COMM_WORLD.rank + + tmp_path = Path(tmp_path) + + start_time = 0 + time_span_s = 4 + sampling_hz = 1 + + sim = lbs.Simulation( + base_path=tmp_path, + start_time=start_time, + duration_s=time_span_s, + random_seed=12345, + imo=lbs.Imo(flatfile_location=lbs.PTEP_IMO_LOCATION), + ) + + det1 = lbs.DetectorInfo( + name="Dummy detector", + sampling_rate_hz=sampling_hz, + bandcenter_ghz=100.0, + quat=[0.0, 0.0, 0.0, 1.0], + ) + + det2 = lbs.DetectorInfo( + name="Dummy detector", + sampling_rate_hz=sampling_hz, + bandcenter_ghz=100.0, + quat=[0.0, 0.0, 0.0, 1.0], + ) + + sim.create_observations( + detectors=[det1, det2], + n_blocks_time=2, + n_blocks_det=1, + split_list_over_processes=False, + ) + + nl_params = lbs.NonLinParams( + sampling_gaussian_loc=0.0, sampling_gaussian_scale=2e-3 + ) + + sim.observations[0].tod = np.ones_like(sim.observations[0].tod) + + sim.apply_quadratic_nonlin(nl_params=nl_params, user_seed=1234) + + tods = lbs.MPI_COMM_WORLD.gather(sim.observations[0].tod, root=0) + + if rank == 0: + np.testing.assert_equal(tods[0], tods[1]) diff --git a/test/test_nonlinearity.py b/test/test_nonlinearity.py index 12f818958..2af7af0ca 100644 --- a/test/test_nonlinearity.py +++ b/test/test_nonlinearity.py @@ -1,6 +1,7 @@ from copy import deepcopy -import numpy as np + import litebird_sim as lbs +import numpy as np from astropy.time import Time @@ -39,6 +40,7 @@ def test_add_quadratic_nonlinearity(): sim.apply_quadratic_nonlin( nl_params=nl_params, component="nl_2_self", + user_seed=1234, ) # Applying non-linearity on the given TOD component of an `Observation` object From fc34257d9c1fe32d2d31b4631c55c0e037d6a74e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Thu, 28 May 2026 21:02:00 +0200 Subject: [PATCH 07/15] add new test to main() --- test/test_mpi.py | 59 ++++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/test/test_mpi.py b/test/test_mpi.py index af8be3f89..4c2c66510 100644 --- a/test/test_mpi.py +++ b/test/test_mpi.py @@ -676,37 +676,9 @@ def test_nullify_mpi(tmp_path): sim.flush() -if __name__ == "__main__": - test_functions = [ - test_observation_time, - test_construction_from_detectors, - test_observation_tod_single_block, - test_observation_tod_two_block_time, - test_observation_tod_two_block_det, - test_observation_tod_set_blocks, - test_simulation_random, - ] - - for cur_test_fn in test_functions: - if MPI_COMM_WORLD.rank == 0: - print("Running test function {}".format(str(cur_test_fn)), file=stderr) - cur_test_fn() - - same_folder_test_functions = [ - test_write_hdf5_mpi, - test_issue314, - test_nullify_mpi, - ] - - for cur_test_fn in same_folder_test_functions: - if MPI_COMM_WORLD.rank == 0: - print("Running test function {}".format(str(cur_test_fn)), file=stderr) - __run_test_in_same_folder(cur_test_fn) - - def test_non_linearity_seeding(tmp_path): """Check if the seed for each detector is consistent even when - different MPI tasks share the same detector + different MPI tasks share the same detector in different time samples """ if not lbs.MPI_ENABLED: @@ -761,3 +733,32 @@ def test_non_linearity_seeding(tmp_path): if rank == 0: np.testing.assert_equal(tods[0], tods[1]) + + +if __name__ == "__main__": + test_functions = [ + test_observation_time, + test_construction_from_detectors, + test_observation_tod_single_block, + test_observation_tod_two_block_time, + test_observation_tod_two_block_det, + test_observation_tod_set_blocks, + test_simulation_random, + test_non_linearity_seeding, + ] + + for cur_test_fn in test_functions: + if MPI_COMM_WORLD.rank == 0: + print("Running test function {}".format(str(cur_test_fn)), file=stderr) + cur_test_fn() + + same_folder_test_functions = [ + test_write_hdf5_mpi, + test_issue314, + test_nullify_mpi, + ] + + for cur_test_fn in same_folder_test_functions: + if MPI_COMM_WORLD.rank == 0: + print("Running test function {}".format(str(cur_test_fn)), file=stderr) + __run_test_in_same_folder(cur_test_fn) From d79ca4a3720b329e0fb025da761042ae5d30db7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Thu, 28 May 2026 21:07:51 +0200 Subject: [PATCH 08/15] remove tmp_path arg from test --- test/test_mpi.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/test_mpi.py b/test/test_mpi.py index 4c2c66510..396247d07 100644 --- a/test/test_mpi.py +++ b/test/test_mpi.py @@ -676,7 +676,7 @@ def test_nullify_mpi(tmp_path): sim.flush() -def test_non_linearity_seeding(tmp_path): +def test_non_linearity_seeding(): """Check if the seed for each detector is consistent even when different MPI tasks share the same detector in different time samples """ @@ -686,14 +686,11 @@ def test_non_linearity_seeding(tmp_path): rank = lbs.MPI_COMM_WORLD.rank - tmp_path = Path(tmp_path) - start_time = 0 time_span_s = 4 sampling_hz = 1 sim = lbs.Simulation( - base_path=tmp_path, start_time=start_time, duration_s=time_span_s, random_seed=12345, From b313038255736eb8d1cb564b6b0d11459eb1a0ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Thu, 28 May 2026 21:13:30 +0200 Subject: [PATCH 09/15] fix test_nonlinearity.py --- test/test_nonlinearity.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_nonlinearity.py b/test/test_nonlinearity.py index 2af7af0ca..9604853be 100644 --- a/test/test_nonlinearity.py +++ b/test/test_nonlinearity.py @@ -53,6 +53,7 @@ def test_add_quadratic_nonlinearity(): nl_params=nl_params, component="nl_2_obs", dets_random=dets_random, + user_seed=1234, ) # Applying non-linearity on the TOD arrays of the individual detectors. @@ -65,6 +66,7 @@ def test_add_quadratic_nonlinearity(): tod_det=tod, nl_params=nl_params, random=dets_random[idx], + user_seed=1234, ) # Check if the three non-linear tods are equal From 8b489589a640eff6ff97be8c4031ecd7ff41b3e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Thu, 28 May 2026 20:25:29 +0100 Subject: [PATCH 10/15] fix test_nonlinearity.py --- test/test_nonlinearity.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_nonlinearity.py b/test/test_nonlinearity.py index 9604853be..4e1cc300f 100644 --- a/test/test_nonlinearity.py +++ b/test/test_nonlinearity.py @@ -40,7 +40,7 @@ def test_add_quadratic_nonlinearity(): sim.apply_quadratic_nonlin( nl_params=nl_params, component="nl_2_self", - user_seed=1234, + user_seed=random_seed, ) # Applying non-linearity on the given TOD component of an `Observation` object @@ -53,7 +53,7 @@ def test_add_quadratic_nonlinearity(): nl_params=nl_params, component="nl_2_obs", dets_random=dets_random, - user_seed=1234, + user_seed=random_seed, ) # Applying non-linearity on the TOD arrays of the individual detectors. @@ -66,7 +66,6 @@ def test_add_quadratic_nonlinearity(): tod_det=tod, nl_params=nl_params, random=dets_random[idx], - user_seed=1234, ) # Check if the three non-linear tods are equal From 12aabe758534ed1adab696c02cee9a6e7cf3f0b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Fri, 29 May 2026 10:03:59 +0100 Subject: [PATCH 11/15] skip test if communicator size is not enough --- test/test_mpi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_mpi.py b/test/test_mpi.py index 396247d07..dbc3efa2b 100644 --- a/test/test_mpi.py +++ b/test/test_mpi.py @@ -681,7 +681,7 @@ def test_non_linearity_seeding(): different MPI tasks share the same detector in different time samples """ - if not lbs.MPI_ENABLED: + if lbs.MPI_COMM_WORLD.size < 2: return rank = lbs.MPI_COMM_WORLD.rank From 35cb10f24937e2b35f606625082a774d573f3dea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Fri, 29 May 2026 10:42:25 +0100 Subject: [PATCH 12/15] update docs and docstrings --- docs/source/non_linearity.rst | 14 +++++++------- docs/source/random_numbers.rst | 19 ++++++++++++++++++- litebird_sim/non_linearity.py | 25 +++++++++++++------------ litebird_sim/simulations.py | 5 +++-- 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/docs/source/non_linearity.rst b/docs/source/non_linearity.rst index e3cf42dfe..ce0687f7d 100644 --- a/docs/source/non_linearity.rst +++ b/docs/source/non_linearity.rst @@ -29,7 +29,7 @@ To simulate a quadratic non-linearity, one can use the method of :class:`.Simulation` class :meth:`.Simulation.apply_quadratic_nonlin()`, or any of the low-level functions: :func:`.apply_quadratic_nonlin_to_observations()`, -:func:`.apply_quadratic_nonlin_for_one_detector()`. +:func:`.apply_quadratic_nonlin_for_one_detector()`. The examples below skip the simulation and observation creation for brevity. If needed, the implementation for those parts is explained in other sections of the docs. @@ -49,20 +49,20 @@ If needed, the implementation for those parts is explained in other sections of ) sim.prepare_pointings() - + (... fill your TOD scanning a map, or add dipole, noise, etc...) - - # Define nonlinearity amplitude distribution for the detectors, in 1/K units. + + # Define nonlinearity amplitude distribution for the detectors, in 1/K units. # The value is randomized for each detector, using either a user-provided seed or a list of # pre-initialized RNGs. - + nl_params = lbs.NonLinParams(sampling_gaussian_loc=-1.0, sampling_gaussian_scale=0.01) - # Applying nonlinearity using the `Simulation` class method + # Applying nonlinearity using the `Simulation` class method with a required user seed # By default, it modifies the ``Observation.tod``. If you want to apply it to some # other field of the :class:`.Observation` class, use `component` - sim.apply_quadratic_nonlin(nl_params = nl_params) + sim.apply_quadratic_nonlin(nl_params = nl_params, user_seed = 1234) .. _nl-api-reference: diff --git a/docs/source/random_numbers.rst b/docs/source/random_numbers.rst index 089454f90..b1d725278 100644 --- a/docs/source/random_numbers.rst +++ b/docs/source/random_numbers.rst @@ -157,12 +157,29 @@ Once the :class:`.RNGHierarchy` is instanciated and RNGs are created, it is suff generators_of_first_rank = rng_hierarchy.get_detector_level_generators_on_rank(rank) indeces = (1, 2) - generetor_second_rank_third_det = rng_hierarchy.get_generator(indeces) + generator_second_rank_third_det = rng_hierarchy.get_generator(indeces) Furthermore, :meth:`.RNGHierarchy.add_extra_layer` allows you to define a new level of the hierarchy from the current leaf layer. Finally, the :meth:`.RNGHierarchy.regenerate_or_check_detector_generators` allows you to quickly regenerate a set of new detector-level generators, or to check the validity of the existing one. This is useful if you passes a random seed specific for each systematics. This needs a list of :class:`.Observation` to work. + +Passing communicators to RNGHierarchy +============================ + +RNGHierarchy accepts an MPI communicator object via the comm argument. When a +communicator is provided, the class infers the communicator size and rank from it +internally. This allows creating RNG hierarchies that are consistent across ranks +of a supplied sub-communicator (for example, time-block or detector-block communicators). + +Passing a communicator is useful for example when you want RNGs that are identical for a given +detector across different global ranks, when sampling random numbers for values that are +detector-specific, such as non-linearities. When TODs are distributed such +that the same detector appears on multiple global ranks in different time blocks, +creating the RNGHierarchy with a time-block communicator will produce the same +detector-level RNGs on each rank in that time-block communicator, ensuring that +the random non linearity factor sampled for a given detector is identical across those ranks. + Saving and loading ------------------ diff --git a/litebird_sim/non_linearity.py b/litebird_sim/non_linearity.py index b434d8624..a63920c4d 100644 --- a/litebird_sim/non_linearity.py +++ b/litebird_sim/non_linearity.py @@ -1,6 +1,7 @@ -import numpy as np from dataclasses import dataclass +import numpy as np + from .observations import Observation from .seeding import regenerate_or_check_detector_generators @@ -123,7 +124,6 @@ def apply_quadratic_nonlin_to_observations( nl_params: NonLinParams | None = None, component: str = "tod", user_seed: int | None = None, - dets_random: list[np.random.Generator] | None = None, ): """ Apply a quadratic nonlinearity to some time-ordered data @@ -143,7 +143,7 @@ def apply_quadratic_nonlin_to_observations( # Ask `apply_quadratic_nonlin_to_observations` to store the nonlinear TOD # in `observations.nl_tod` - apply_quadratic_nonlin_to_observations(sim.observations, component="nl_tod") + apply_quadratic_nonlin_to_observations(sim.observations, component="nl_tod", user_seed=1234) Parameters ---------- @@ -154,14 +154,10 @@ def apply_quadratic_nonlin_to_observations( provided, a default configuration is used. component : str, optional Name of the TOD attribute to modify. Defaults to `"tod"`. - user_seed : int, optional - Base seed to build the RNG hierarchy and generate detector-level RNGs - that overwrite any eventual `dets_random`. Required if `dets_random` - is not provided. - dets_random : list of np.random.Generator, optional - List of per-detector random number generators. If not provided, and - `user_seed` is given, generators are created internally. One of - `user_seed` or `dets_random` must be provided. + user_seed : int or None + Base seed to build the RNG hierarchy and generate samples of the detector non-linearity factor, + independently of the MPI distribution across detectors and time. + The user is required to set this parameter. Raises ------ @@ -172,6 +168,11 @@ def apply_quadratic_nonlin_to_observations( AssertionError If the number of random generators does not match the number of detectors. """ + + assert user_seed is not None, ( + "user_seed must be given in apply_quadratic_nonlin_to_observations." + ) + if nl_params is None: nl_params = NonLinParams() @@ -183,11 +184,11 @@ def apply_quadratic_nonlin_to_observations( raise TypeError( "The parameter `observations` must be an `Observation` or a list of `Observation`." ) + dets_random = regenerate_or_check_detector_generators( observations=obs_list, comm=obs_list[0].comm_time_block, user_seed=user_seed, - dets_random=dets_random, ) # iterate through each observation diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index 8665b034c..13658b4d9 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -2082,18 +2082,19 @@ def apply_quadratic_nonlin( user_seed: int | None = None, component: str = "tod", append_to_report: bool = False, + rng_hierarchy: RNGHierarchy | None = None, ): """A method to apply non-linearity to the observation. This is a wrapper around :func:`.apply_quadratic_nonlin_to_observations()` that - applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained from the detector-level layer. As default it uses - the `dets_random` field of a :class:`.Simulation` object for this. + applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained for each detector, independently of the MPI distribution across detectors and time. Setting the `user_seed` argument is required for this. """ assert user_seed is not None, ( "user_seed must be given in apply_quadratic_nonlin." ) + if nl_params is None: nl_params = NonLinParams() From f35e0b02ed4d582605ca2a159d2705cdefb9054c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Fri, 29 May 2026 10:50:13 +0100 Subject: [PATCH 13/15] fix bug and remove unused docstrings --- litebird_sim/non_linearity.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/litebird_sim/non_linearity.py b/litebird_sim/non_linearity.py index a63920c4d..15767dd2e 100644 --- a/litebird_sim/non_linearity.py +++ b/litebird_sim/non_linearity.py @@ -66,20 +66,12 @@ def apply_quadratic_nonlin_for_one_detector( Args: - det_tod (np.ndarray): The TOD array corresponding to only one + tod_det (np.ndarray): The TOD array corresponding to only one detector. - det_name (str): The name of the detector to which the TOD belongs. - This name is used with ``user_seed`` to generate hash. This hash is used to - set random slope in case of linear drift, and randomized detector mismatch - in case of thermal gain drift. - nl_params (:class:`.NonLinParams`, optional): The non-linearity injection parameters object. Defaults to None. - user_seed (int, optional): A seed provided by the user. Defaults - to None. - random (np.random.Generator, optional): A random number generator. Defaults to None. """ @@ -124,6 +116,7 @@ def apply_quadratic_nonlin_to_observations( nl_params: NonLinParams | None = None, component: str = "tod", user_seed: int | None = None, + dets_random: list[np.random.Generator] | None = None, ): """ Apply a quadratic nonlinearity to some time-ordered data @@ -155,9 +148,13 @@ def apply_quadratic_nonlin_to_observations( component : str, optional Name of the TOD attribute to modify. Defaults to `"tod"`. user_seed : int or None - Base seed to build the RNG hierarchy and generate samples of the detector non-linearity factor, - independently of the MPI distribution across detectors and time. - The user is required to set this parameter. + Base seed to build the RNG hierarchy and generate samples of the detector non-linearity factor, + independently of the MPI distribution across detectors and time. + The user is required to set this parameter. + dets_random : list of np.random.Generator, optional + List of per-detector random number generators. If not provided, and + `user_seed` is given, generators are created internally. One of + `user_seed` or `dets_random` must be provided. Raises ------ @@ -184,11 +181,11 @@ def apply_quadratic_nonlin_to_observations( raise TypeError( "The parameter `observations` must be an `Observation` or a list of `Observation`." ) - dets_random = regenerate_or_check_detector_generators( observations=obs_list, comm=obs_list[0].comm_time_block, user_seed=user_seed, + dets_random=dets_random, ) # iterate through each observation From 2f19eee7888b66e602b68aa06c4f0ffae2079ca1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Jos=C3=A9=20Gomes?= Date: Fri, 29 May 2026 10:55:52 +0100 Subject: [PATCH 14/15] [skip ci] update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d07ee68b..158ca0836 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # HEAD +- Refactor RNGHierarchy to enable tying the random number generators to detectors [#523](https://github.com/litebird/litebird_sim/pull/523) + - Allow center pointings in get_pointings, remove _get_centered_pointings function and move its centering logic inside _get_pointings_array [#506](https://github.com/litebird/litebird_sim/pull/506) - Add complete HWP Jones formalism, including band integration [#499](https://github.com/litebird/litebird_sim/pull/499) From 83fe7eeb731ea24dc06abb2ca81daade6be015d4 Mon Sep 17 00:00:00 2001 From: Miguel Gomes <83699808+mj-gomes@users.noreply.github.com> Date: Fri, 29 May 2026 12:13:33 +0200 Subject: [PATCH 15/15] minor change in docs --- docs/source/random_numbers.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/random_numbers.rst b/docs/source/random_numbers.rst index b1d725278..46401e3a3 100644 --- a/docs/source/random_numbers.rst +++ b/docs/source/random_numbers.rst @@ -177,7 +177,7 @@ detector across different global ranks, when sampling random numbers for values detector-specific, such as non-linearities. When TODs are distributed such that the same detector appears on multiple global ranks in different time blocks, creating the RNGHierarchy with a time-block communicator will produce the same -detector-level RNGs on each rank in that time-block communicator, ensuring that +detector-level RNGs on each rank in that time-block communicator, ensuring, for example, that the random non linearity factor sampled for a given detector is identical across those ranks. Saving and loading