Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# HEAD

- Refactor RNGHierarchy to enable tying the random number generators to detectors [#523](https://github.com/litebird/litebird_sim/pull/523)

- Allow center pointings in get_pointings, remove _get_centered_pointings function and move its centering logic inside _get_pointings_array [#506](https://github.com/litebird/litebird_sim/pull/506)

- Add complete HWP Jones formalism, including band integration [#499](https://github.com/litebird/litebird_sim/pull/499)
Expand Down
14 changes: 7 additions & 7 deletions docs/source/non_linearity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ To simulate a quadratic non-linearity, one can use the method of
:class:`.Simulation` class
:meth:`.Simulation.apply_quadratic_nonlin()`, or any of the low-level
functions: :func:`.apply_quadratic_nonlin_to_observations()`,
:func:`.apply_quadratic_nonlin_for_one_detector()`.
:func:`.apply_quadratic_nonlin_for_one_detector()`.

The examples below skip the simulation and observation creation for brevity.
If needed, the implementation for those parts is explained in other sections of the docs.
Expand All @@ -49,20 +49,20 @@ If needed, the implementation for those parts is explained in other sections of
)

sim.prepare_pointings()

(... fill your TOD scanning a map, or add dipole, noise, etc...)
# Define nonlinearity amplitude distribution for the detectors, in 1/K units.

# Define nonlinearity amplitude distribution for the detectors, in 1/K units.
# The value is randomized for each detector, using either a user-provided seed or a list of
# pre-initialized RNGs.

nl_params = lbs.NonLinParams(sampling_gaussian_loc=-1.0, sampling_gaussian_scale=0.01)

# Applying nonlinearity using the `Simulation` class method
# Applying nonlinearity using the `Simulation` class method with a required user seed
# By default, it modifies the ``Observation.tod``. If you want to apply it to some
# other field of the :class:`.Observation` class, use `component`

sim.apply_quadratic_nonlin(nl_params = nl_params)
sim.apply_quadratic_nonlin(nl_params = nl_params, user_seed = 1234)


.. _nl-api-reference:
Expand Down
19 changes: 18 additions & 1 deletion docs/source/random_numbers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,29 @@ Once the :class:`.RNGHierarchy` is instanciated and RNGs are created, it is suff
generators_of_first_rank = rng_hierarchy.get_detector_level_generators_on_rank(rank)

indeces = (1, 2)
generetor_second_rank_third_det = rng_hierarchy.get_generator(indeces)
generator_second_rank_third_det = rng_hierarchy.get_generator(indeces)

Furthermore, :meth:`.RNGHierarchy.add_extra_layer` allows you to define a new level of the hierarchy from the current leaf layer.

Finally, the :meth:`.RNGHierarchy.regenerate_or_check_detector_generators` allows you to quickly regenerate a set of new detector-level generators, or to check the validity of the existing one. This is useful if you passes a random seed specific for each systematics. This needs a list of :class:`.Observation` to work.


Passing communicators to RNGHierarchy
============================

RNGHierarchy accepts an MPI communicator object via the comm argument. When a
communicator is provided, the class infers the communicator size and rank from it
internally. This allows creating RNG hierarchies that are consistent across ranks
of a supplied sub-communicator (for example, time-block or detector-block communicators).

Passing a communicator is useful for example when you want RNGs that are identical for a given
detector across different global ranks, when sampling random numbers for values that are
detector-specific, such as non-linearities. When TODs are distributed such
that the same detector appears on multiple global ranks in different time blocks,
creating the RNGHierarchy with a time-block communicator will produce the same
detector-level RNGs on each rank in that time-block communicator, ensuring, for example, that
the random non linearity factor sampled for a given detector is identical across those ranks.

Saving and loading
------------------

Expand Down
1 change: 1 addition & 0 deletions litebird_sim/gaindrifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def apply_gaindrift_to_observations(
)
dets_random = regenerate_or_check_detector_generators(
observations=obs_list,
comm=None,
user_seed=user_seed,
dets_random=dets_random,
)
Expand Down
1 change: 1 addition & 0 deletions litebird_sim/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def add_noise_to_observations(

dets_random = regenerate_or_check_detector_generators(
observations=obs_list,
comm=None,
user_seed=user_seed,
dets_random=dets_random,
)
Expand Down
29 changes: 14 additions & 15 deletions litebird_sim/non_linearity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from dataclasses import dataclass

import numpy as np

from .observations import Observation
from .seeding import regenerate_or_check_detector_generators

Expand Down Expand Up @@ -65,20 +66,12 @@ def apply_quadratic_nonlin_for_one_detector(

Args:

det_tod (np.ndarray): The TOD array corresponding to only one
tod_det (np.ndarray): The TOD array corresponding to only one
detector.

det_name (str): The name of the detector to which the TOD belongs.
This name is used with ``user_seed`` to generate hash. This hash is used to
set random slope in case of linear drift, and randomized detector mismatch
in case of thermal gain drift.

nl_params (:class:`.NonLinParams`, optional): The non-linearity
injection parameters object. Defaults to None.

user_seed (int, optional): A seed provided by the user. Defaults
to None.

random (np.random.Generator, optional): A random number generator.
Defaults to None.
"""
Expand Down Expand Up @@ -143,7 +136,7 @@ def apply_quadratic_nonlin_to_observations(

# Ask `apply_quadratic_nonlin_to_observations` to store the nonlinear TOD
# in `observations.nl_tod`
apply_quadratic_nonlin_to_observations(sim.observations, component="nl_tod")
apply_quadratic_nonlin_to_observations(sim.observations, component="nl_tod", user_seed=1234)

Parameters
----------
Expand All @@ -154,10 +147,10 @@ def apply_quadratic_nonlin_to_observations(
provided, a default configuration is used.
component : str, optional
Name of the TOD attribute to modify. Defaults to `"tod"`.
user_seed : int, optional
Base seed to build the RNG hierarchy and generate detector-level RNGs
that overwrite any eventual `dets_random`. Required if `dets_random`
is not provided.
user_seed : int or None
Base seed to build the RNG hierarchy and generate samples of the detector non-linearity factor,
independently of the MPI distribution across detectors and time.
The user is required to set this parameter.
dets_random : list of np.random.Generator, optional
List of per-detector random number generators. If not provided, and
`user_seed` is given, generators are created internally. One of
Expand All @@ -172,6 +165,11 @@ def apply_quadratic_nonlin_to_observations(
AssertionError
If the number of random generators does not match the number of detectors.
"""

assert user_seed is not None, (
"user_seed must be given in apply_quadratic_nonlin_to_observations."
)

if nl_params is None:
nl_params = NonLinParams()

Expand All @@ -185,6 +183,7 @@ def apply_quadratic_nonlin_to_observations(
)
dets_random = regenerate_or_check_detector_generators(
observations=obs_list,
comm=obs_list[0].comm_time_block,
user_seed=user_seed,
dets_random=dets_random,
)
Expand Down
26 changes: 13 additions & 13 deletions litebird_sim/seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from numpy.random import PCG64, Generator, SeedSequence

from .mpi import MPI_COMM_WORLD, _SerialMpiCommunicator

from .observations import Observation


Expand Down Expand Up @@ -145,6 +147,7 @@ def get_detector_level_generators_from_hierarchy(

def regenerate_or_check_detector_generators(
observations: list[Observation],
comm=None,
user_seed: int | None = None,
dets_random: list[Generator] | None = None,
) -> list[Generator]:
Expand Down Expand Up @@ -176,17 +179,14 @@ def regenerate_or_check_detector_generators(
AssertionError
If the number of generators does not match the number of detectors.
"""
comm = observations[0].comm
if comm is not None:
rank = comm.rank
comm_size = comm.size
else:
rank = 0
comm_size = 1
if comm is None:
comm = _SerialMpiCommunicator()

if user_seed is not None:
RNG_hierarchy = RNGHierarchy(user_seed, comm_size, observations[0].n_detectors)
dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(rank=rank)
RNG_hierarchy = RNGHierarchy(user_seed, comm, observations[0].n_detectors)
dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(
rank=comm.rank
)
if user_seed is None and dets_random is None:
raise ValueError("You should pass either `user_seed` or `dets_random`.")
assert dets_random is not None # Guaranteed by the check above
Expand All @@ -204,20 +204,20 @@ class RNGHierarchy:
def __init__(
self,
base_seed: int,
comm_size: int | None = None,
comm=MPI_COMM_WORLD,
num_detectors_per_rank: int | None = None,
):
self.base_seed = base_seed
self.comm_size = None
self.comm_size = comm.size
self.num_detectors_per_rank = None
# self.root_seq = SeedSequence(base_seed)
self.metadata = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"save_format_version": self.SAVE_FORMAT_VERSION,
}
self.hierarchy = {}
if comm_size is not None:
self.build_mpi_layer(comm_size)
if self.comm_size is not None:
self.build_mpi_layer(self.comm_size)
if num_detectors_per_rank is not None:
self.build_detector_layer(num_detectors_per_rank)

Expand Down
14 changes: 6 additions & 8 deletions litebird_sim/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def init_rng_hierarchy(self, random_seed):
the user-provided `random_seed`, but can be invoked again to
reseed the simulation at any point.
"""
self.rng_hierarchy = RNGHierarchy(random_seed)
self.rng_hierarchy = RNGHierarchy(random_seed, self.mpi_comm)
self.rng_hierarchy.build_mpi_layer(self.mpi_comm.size)

self.random = self.rng_hierarchy.get_generator(self.mpi_comm.rank)
Expand Down Expand Up @@ -2103,14 +2103,13 @@ def apply_quadratic_nonlin(

This is a wrapper around
:func:`.apply_quadratic_nonlin_to_observations()` that
applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained from the detector-level layer. As default it uses
the `dets_random` field of a :class:`.Simulation` object for this.
applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained for each detector, independently of the MPI distribution across detectors and time. Setting the `user_seed` argument is required for this.
"""
if rng_hierarchy is None:
rng_hierarchy = self.rng_hierarchy
dets_random = rng_hierarchy.get_detector_level_generators_on_rank(
self.mpi_comm.rank

assert user_seed is not None, (
"user_seed must be given in apply_quadratic_nonlin."
)

if nl_params is None:
nl_params = NonLinParams()

Expand All @@ -2119,7 +2118,6 @@ def apply_quadratic_nonlin(
nl_params=nl_params,
user_seed=user_seed,
component=component,
dets_random=dets_random,
)

if append_to_report and MPI_COMM_WORLD.rank == 0:
Expand Down
2 changes: 1 addition & 1 deletion test/test_gaindrifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def get_RNG_hierarchy(seed, num_of_dets):
rng_hierarchy = lbs.RNGHierarchy(
seed, comm_size=1, num_detectors_per_rank=num_of_dets
seed, comm=lbs.MPI_COMM_WORLD, num_detectors_per_rank=num_of_dets
)
return rng_hierarchy

Expand Down
60 changes: 58 additions & 2 deletions test/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from typing import Callable

import astropy.time as astrotime
import numpy as np

import litebird_sim as lbs
import numpy as np
from litebird_sim import MPI_COMM_WORLD


Expand Down Expand Up @@ -677,6 +676,62 @@ def test_nullify_mpi(tmp_path):
sim.flush()


def test_non_linearity_seeding():
"""Check if the seed for each detector is consistent even when
different MPI tasks share the same detector in different time samples
"""

if lbs.MPI_COMM_WORLD.size < 2:
return

rank = lbs.MPI_COMM_WORLD.rank

start_time = 0
time_span_s = 4
sampling_hz = 1

sim = lbs.Simulation(
start_time=start_time,
duration_s=time_span_s,
random_seed=12345,
imo=lbs.Imo(flatfile_location=lbs.PTEP_IMO_LOCATION),
)

det1 = lbs.DetectorInfo(
name="Dummy detector",
sampling_rate_hz=sampling_hz,
bandcenter_ghz=100.0,
quat=[0.0, 0.0, 0.0, 1.0],
)

det2 = lbs.DetectorInfo(
name="Dummy detector",
sampling_rate_hz=sampling_hz,
bandcenter_ghz=100.0,
quat=[0.0, 0.0, 0.0, 1.0],
)

sim.create_observations(
detectors=[det1, det2],
n_blocks_time=2,
n_blocks_det=1,
split_list_over_processes=False,
)

nl_params = lbs.NonLinParams(
sampling_gaussian_loc=0.0, sampling_gaussian_scale=2e-3
)

sim.observations[0].tod = np.ones_like(sim.observations[0].tod)

sim.apply_quadratic_nonlin(nl_params=nl_params, user_seed=1234)

tods = lbs.MPI_COMM_WORLD.gather(sim.observations[0].tod, root=0)

if rank == 0:
np.testing.assert_equal(tods[0], tods[1])


if __name__ == "__main__":
test_functions = [
test_observation_time,
Expand All @@ -686,6 +741,7 @@ def test_nullify_mpi(tmp_path):
test_observation_tod_two_block_det,
test_observation_tod_set_blocks,
test_simulation_random,
test_non_linearity_seeding,
]

for cur_test_fn in test_functions:
Expand Down
Loading
Loading