diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d07ee68..158ca083 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # HEAD +- Refactor RNGHierarchy to enable tying the random number generators to detectors [#523](https://github.com/litebird/litebird_sim/pull/523) + - Allow center pointings in get_pointings, remove _get_centered_pointings function and move its centering logic inside _get_pointings_array [#506](https://github.com/litebird/litebird_sim/pull/506) - Add complete HWP Jones formalism, including band integration [#499](https://github.com/litebird/litebird_sim/pull/499) diff --git a/docs/source/non_linearity.rst b/docs/source/non_linearity.rst index e3cf42df..ce0687f7 100644 --- a/docs/source/non_linearity.rst +++ b/docs/source/non_linearity.rst @@ -29,7 +29,7 @@ To simulate a quadratic non-linearity, one can use the method of :class:`.Simulation` class :meth:`.Simulation.apply_quadratic_nonlin()`, or any of the low-level functions: :func:`.apply_quadratic_nonlin_to_observations()`, -:func:`.apply_quadratic_nonlin_for_one_detector()`. +:func:`.apply_quadratic_nonlin_for_one_detector()`. The examples below skip the simulation and observation creation for brevity. If needed, the implementation for those parts is explained in other sections of the docs. @@ -49,20 +49,20 @@ If needed, the implementation for those parts is explained in other sections of ) sim.prepare_pointings() - + (... fill your TOD scanning a map, or add dipole, noise, etc...) - - # Define nonlinearity amplitude distribution for the detectors, in 1/K units. + + # Define nonlinearity amplitude distribution for the detectors, in 1/K units. # The value is randomized for each detector, using either a user-provided seed or a list of # pre-initialized RNGs. 
- + nl_params = lbs.NonLinParams(sampling_gaussian_loc=-1.0, sampling_gaussian_scale=0.01) - # Applying nonlinearity using the `Simulation` class method + # Applying nonlinearity using the `Simulation` class method with a required user seed # By default, it modifies the ``Observation.tod``. If you want to apply it to some # other field of the :class:`.Observation` class, use `component` - sim.apply_quadratic_nonlin(nl_params = nl_params) + sim.apply_quadratic_nonlin(nl_params = nl_params, user_seed = 1234) .. _nl-api-reference: diff --git a/docs/source/random_numbers.rst b/docs/source/random_numbers.rst index 089454f9..46401e3a 100644 --- a/docs/source/random_numbers.rst +++ b/docs/source/random_numbers.rst @@ -157,12 +157,29 @@ Once the :class:`.RNGHierarchy` is instanciated and RNGs are created, it is suff generators_of_first_rank = rng_hierarchy.get_detector_level_generators_on_rank(rank) indeces = (1, 2) - generetor_second_rank_third_det = rng_hierarchy.get_generator(indeces) + generator_second_rank_third_det = rng_hierarchy.get_generator(indeces) Furthermore, :meth:`.RNGHierarchy.add_extra_layer` allows you to define a new level of the hierarchy from the current leaf layer. Finally, the :meth:`.RNGHierarchy.regenerate_or_check_detector_generators` allows you to quickly regenerate a set of new detector-level generators, or to check the validity of the existing one. This is useful if you passes a random seed specific for each systematics. This needs a list of :class:`.Observation` to work. + +Passing communicators to RNGHierarchy +===================================== + +RNGHierarchy accepts an MPI communicator object via the ``comm`` argument. When a +communicator is provided, the class infers the communicator size and rank from it +internally. This allows creating RNG hierarchies that are consistent across ranks +of a supplied sub-communicator (for example, time-block or detector-block communicators). 
+ +Passing a communicator is useful for example when you want RNGs that are identical for a given +detector across different global ranks, when sampling random numbers for values that are +detector-specific, such as non-linearities. When TODs are distributed such +that the same detector appears on multiple global ranks in different time blocks, +creating the RNGHierarchy with a time-block communicator will produce the same +detector-level RNGs on each rank in that time-block communicator, ensuring, for example, that +the random non linearity factor sampled for a given detector is identical across those ranks. + Saving and loading ------------------ diff --git a/litebird_sim/gaindrifts.py b/litebird_sim/gaindrifts.py index c68f47e3..85047b34 100644 --- a/litebird_sim/gaindrifts.py +++ b/litebird_sim/gaindrifts.py @@ -558,6 +558,7 @@ def apply_gaindrift_to_observations( ) dets_random = regenerate_or_check_detector_generators( observations=obs_list, + comm=None, user_seed=user_seed, dets_random=dets_random, ) diff --git a/litebird_sim/noise.py b/litebird_sim/noise.py index 8deaf97b..9dd0e972 100644 --- a/litebird_sim/noise.py +++ b/litebird_sim/noise.py @@ -401,6 +401,7 @@ def add_noise_to_observations( dets_random = regenerate_or_check_detector_generators( observations=obs_list, + comm=None, user_seed=user_seed, dets_random=dets_random, ) diff --git a/litebird_sim/non_linearity.py b/litebird_sim/non_linearity.py index dbea2fd1..15767dd2 100644 --- a/litebird_sim/non_linearity.py +++ b/litebird_sim/non_linearity.py @@ -1,6 +1,7 @@ -import numpy as np from dataclasses import dataclass +import numpy as np + from .observations import Observation from .seeding import regenerate_or_check_detector_generators @@ -65,20 +66,12 @@ def apply_quadratic_nonlin_for_one_detector( Args: - det_tod (np.ndarray): The TOD array corresponding to only one + tod_det (np.ndarray): The TOD array corresponding to only one detector. 
- det_name (str): The name of the detector to which the TOD belongs. - This name is used with ``user_seed`` to generate hash. This hash is used to - set random slope in case of linear drift, and randomized detector mismatch - in case of thermal gain drift. - nl_params (:class:`.NonLinParams`, optional): The non-linearity injection parameters object. Defaults to None. - user_seed (int, optional): A seed provided by the user. Defaults - to None. - random (np.random.Generator, optional): A random number generator. Defaults to None. """ @@ -143,7 +136,7 @@ def apply_quadratic_nonlin_to_observations( # Ask `apply_quadratic_nonlin_to_observations` to store the nonlinear TOD # in `observations.nl_tod` - apply_quadratic_nonlin_to_observations(sim.observations, component="nl_tod") + apply_quadratic_nonlin_to_observations(sim.observations, component="nl_tod", user_seed=1234) Parameters ---------- @@ -154,10 +147,10 @@ def apply_quadratic_nonlin_to_observations( provided, a default configuration is used. component : str, optional Name of the TOD attribute to modify. Defaults to `"tod"`. - user_seed : int, optional - Base seed to build the RNG hierarchy and generate detector-level RNGs - that overwrite any eventual `dets_random`. Required if `dets_random` - is not provided. + user_seed : int or None + Base seed to build the RNG hierarchy and generate samples of the detector non-linearity factor, + independently of the MPI distribution across detectors and time. + The user is required to set this parameter. dets_random : list of np.random.Generator, optional List of per-detector random number generators. If not provided, and `user_seed` is given, generators are created internally. One of @@ -172,6 +165,11 @@ def apply_quadratic_nonlin_to_observations( AssertionError If the number of random generators does not match the number of detectors. """ + + assert user_seed is not None, ( + "user_seed must be given in apply_quadratic_nonlin_to_observations." 
+ ) + if nl_params is None: nl_params = NonLinParams() @@ -185,6 +183,7 @@ def apply_quadratic_nonlin_to_observations( ) dets_random = regenerate_or_check_detector_generators( observations=obs_list, + comm=obs_list[0].comm_time_block, user_seed=user_seed, dets_random=dets_random, ) diff --git a/litebird_sim/seeding.py b/litebird_sim/seeding.py index 52a7f425..5f97409e 100644 --- a/litebird_sim/seeding.py +++ b/litebird_sim/seeding.py @@ -6,6 +6,8 @@ import numpy as np from numpy.random import PCG64, Generator, SeedSequence +from .mpi import MPI_COMM_WORLD, _SerialMpiCommunicator + from .observations import Observation @@ -145,6 +147,7 @@ def get_detector_level_generators_from_hierarchy( def regenerate_or_check_detector_generators( observations: list[Observation], + comm=None, user_seed: int | None = None, dets_random: list[Generator] | None = None, ) -> list[Generator]: @@ -176,17 +179,14 @@ def regenerate_or_check_detector_generators( AssertionError If the number of generators does not match the number of detectors. 
""" - comm = observations[0].comm - if comm is not None: - rank = comm.rank - comm_size = comm.size - else: - rank = 0 - comm_size = 1 + if comm is None: + comm = _SerialMpiCommunicator() if user_seed is not None: - RNG_hierarchy = RNGHierarchy(user_seed, comm_size, observations[0].n_detectors) - dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(rank=rank) + RNG_hierarchy = RNGHierarchy(user_seed, comm, observations[0].n_detectors) + dets_random = RNG_hierarchy.get_detector_level_generators_on_rank( + rank=comm.rank + ) if user_seed is None and dets_random is None: raise ValueError("You should pass either `user_seed` or `dets_random`.") assert dets_random is not None # Guaranteed by the check above @@ -204,11 +204,11 @@ class RNGHierarchy: def __init__( self, base_seed: int, - comm_size: int | None = None, + comm=MPI_COMM_WORLD, num_detectors_per_rank: int | None = None, ): self.base_seed = base_seed - self.comm_size = None + self.comm_size = comm.size self.num_detectors_per_rank = None # self.root_seq = SeedSequence(base_seed) self.metadata = { @@ -216,8 +216,8 @@ def __init__( "save_format_version": self.SAVE_FORMAT_VERSION, } self.hierarchy = {} - if comm_size is not None: - self.build_mpi_layer(comm_size) + if self.comm_size is not None: + self.build_mpi_layer(self.comm_size) if num_detectors_per_rank is not None: self.build_detector_layer(num_detectors_per_rank) diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index d9408137..6e7b46d8 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -539,7 +539,7 @@ def init_rng_hierarchy(self, random_seed): the user-provided `random_seed`, but can be invoked again to reseed the simulation at any point. 
""" - self.rng_hierarchy = RNGHierarchy(random_seed) + self.rng_hierarchy = RNGHierarchy(random_seed, self.mpi_comm) self.rng_hierarchy.build_mpi_layer(self.mpi_comm.size) self.random = self.rng_hierarchy.get_generator(self.mpi_comm.rank) @@ -2103,14 +2103,13 @@ def apply_quadratic_nonlin( This is a wrapper around :func:`.apply_quadratic_nonlin_to_observations()` that - applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained from the detector-level layer. As default it uses - the `dets_random` field of a :class:`.Simulation` object for this. + applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained for each detector, independently of the MPI distribution across detectors and time. Setting the `user_seed` argument is required for this. """ - if rng_hierarchy is None: - rng_hierarchy = self.rng_hierarchy - dets_random = rng_hierarchy.get_detector_level_generators_on_rank( - self.mpi_comm.rank + + assert user_seed is not None, ( + "user_seed must be given in apply_quadratic_nonlin." 
) + if nl_params is None: nl_params = NonLinParams() @@ -2119,7 +2118,6 @@ def apply_quadratic_nonlin( nl_params=nl_params, user_seed=user_seed, component=component, - dets_random=dets_random, ) if append_to_report and MPI_COMM_WORLD.rank == 0: diff --git a/test/test_gaindrifts.py b/test/test_gaindrifts.py index 4eaae8c1..70fa0ba0 100644 --- a/test/test_gaindrifts.py +++ b/test/test_gaindrifts.py @@ -5,7 +5,7 @@ def get_RNG_hierarchy(seed, num_of_dets): rng_hierarchy = lbs.RNGHierarchy( - seed, comm_size=1, num_detectors_per_rank=num_of_dets + seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=num_of_dets ) return rng_hierarchy diff --git a/test/test_mpi.py b/test/test_mpi.py index cd5cfdcf..dbc3efa2 100644 --- a/test/test_mpi.py +++ b/test/test_mpi.py @@ -6,9 +6,8 @@ from typing import Callable import astropy.time as astrotime -import numpy as np - import litebird_sim as lbs +import numpy as np from litebird_sim import MPI_COMM_WORLD @@ -677,6 +676,62 @@ def test_nullify_mpi(tmp_path): sim.flush() +def test_non_linearity_seeding(): + """Check if the seed for each detector is consistent even when + different MPI tasks share the same detector in different time samples + """ + + if lbs.MPI_COMM_WORLD.size < 2: + return + + rank = lbs.MPI_COMM_WORLD.rank + + start_time = 0 + time_span_s = 4 + sampling_hz = 1 + + sim = lbs.Simulation( + start_time=start_time, + duration_s=time_span_s, + random_seed=12345, + imo=lbs.Imo(flatfile_location=lbs.PTEP_IMO_LOCATION), + ) + + det1 = lbs.DetectorInfo( + name="Dummy detector", + sampling_rate_hz=sampling_hz, + bandcenter_ghz=100.0, + quat=[0.0, 0.0, 0.0, 1.0], + ) + + det2 = lbs.DetectorInfo( + name="Dummy detector", + sampling_rate_hz=sampling_hz, + bandcenter_ghz=100.0, + quat=[0.0, 0.0, 0.0, 1.0], + ) + + sim.create_observations( + detectors=[det1, det2], + n_blocks_time=2, + n_blocks_det=1, + split_list_over_processes=False, + ) + + nl_params = lbs.NonLinParams( + sampling_gaussian_loc=0.0, sampling_gaussian_scale=2e-3 + 
) + + sim.observations[0].tod = np.ones_like(sim.observations[0].tod) + + sim.apply_quadratic_nonlin(nl_params=nl_params, user_seed=1234) + + tods = lbs.MPI_COMM_WORLD.gather(sim.observations[0].tod, root=0) + + if rank == 0: + np.testing.assert_equal(tods[0], tods[1]) + + if __name__ == "__main__": test_functions = [ test_observation_time, @@ -686,6 +741,7 @@ def test_nullify_mpi(tmp_path): test_observation_tod_two_block_det, test_observation_tod_set_blocks, test_simulation_random, + test_non_linearity_seeding, ] for cur_test_fn in test_functions: diff --git a/test/test_nonlinearity.py b/test/test_nonlinearity.py index 86f4c226..4e1cc300 100644 --- a/test/test_nonlinearity.py +++ b/test/test_nonlinearity.py @@ -1,6 +1,7 @@ from copy import deepcopy -import numpy as np + import litebird_sim as lbs +import numpy as np from astropy.time import Time @@ -39,11 +40,12 @@ def test_add_quadratic_nonlinearity(): sim.apply_quadratic_nonlin( nl_params=nl_params, component="nl_2_self", + user_seed=random_seed, ) # Applying non-linearity on the given TOD component of an `Observation` object RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) + random_seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=len(dets) ) dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) lbs.apply_quadratic_nonlin_to_observations( @@ -51,11 +53,12 @@ def test_add_quadratic_nonlinearity(): nl_params=nl_params, component="nl_2_obs", dets_random=dets_random, + user_seed=random_seed, ) # Applying non-linearity on the TOD arrays of the individual detectors. 
RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) + random_seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=len(dets) ) dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) for idx, tod in enumerate(sim.observations[0].nl_2_det): @@ -76,7 +79,7 @@ def test_add_quadratic_nonlinearity(): # Check if non-linearity is applied correctly RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) + random_seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=len(dets) ) dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) sim.observations[0].tod_origin = np.ones_like(sim.observations[0].tod) diff --git a/test/test_seeding.py b/test/test_seeding.py index 0456bf47..2773fbc1 100644 --- a/test/test_seeding.py +++ b/test/test_seeding.py @@ -289,24 +289,29 @@ def test_detector_generators_regeneration(tmp_path): with pytest.raises(ValueError): _ = lbs.regenerate_or_check_detector_generators( observations=sim.observations, + comm=None, user_seed=None, dets_random=None, ) with pytest.raises(AssertionError): _ = lbs.regenerate_or_check_detector_generators( observations=sim.observations, + comm=None, user_seed=None, dets_random=[sim.dets_random[0], sim.dets_random[1]], ) # `user_seed` takes priority over the generators of the `Simulation`` regenerated_dets_random = lbs.regenerate_or_check_detector_generators( observations=sim.observations, + comm=None, user_seed=987654321, dets_random=sim.dets_random, ) rng_hierarchy = lbs.RNGHierarchy( - 987654321, comm_size=1, num_detectors_per_rank=sim.observations[0].n_detectors + 987654321, + comm=lbs.MPI_COMM_WORLD, + num_detectors_per_rank=sim.observations[0].n_detectors, ) fresh_dets_random = rng_hierarchy.get_detector_level_generators_on_rank(0)