From d28a2e4a0be54ed4e8421ac63c44922021228a3a Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Mon, 18 May 2026 18:36:17 +0200 Subject: [PATCH 1/3] alpha earth embbedings reader and configs --- config/config_forecasting_synop.yml | 250 ++++++++++++++++++ .../config_forecasting_synop_alphaearth.yml | 250 ++++++++++++++++++ .../era5.yml | 4 +- .../synop.yml | 4 +- .../era5_forecast_synop_alphaearth/era5.yml | 37 +++ .../era5_forecast_synop_alphaearth/synop.yml | 39 +++ .../datasets/data_reader_alphaearth.py | 247 +++++++++++++++++ 7 files changed, 826 insertions(+), 5 deletions(-) create mode 100644 config/config_forecasting_synop.yml create mode 100644 config/config_forecasting_synop_alphaearth.yml rename config/streams/{era5_decoding_synop => era5_forecast_synop}/era5.yml (96%) rename config/streams/{era5_decoding_synop => era5_forecast_synop}/synop.yml (89%) create mode 100644 config/streams/era5_forecast_synop_alphaearth/era5.yml create mode 100644 config/streams/era5_forecast_synop_alphaearth/synop.yml create mode 100644 src/weathergen/datasets/data_reader_alphaearth.py diff --git a/config/config_forecasting_synop.yml b/config/config_forecasting_synop.yml new file mode 100644 index 000000000..c402c380a --- /dev/null +++ b/config/config_forecasting_synop.yml @@ -0,0 +1,250 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_forecast_synop/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "surface_fc_no_aef" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 32 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2017-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 2 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-01-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 600 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: aef_synop_forecast + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: no_alphaearth diff --git a/config/config_forecasting_synop_alphaearth.yml b/config/config_forecasting_synop_alphaearth.yml new file mode 100644 index 000000000..f0a3ee5b5 --- /dev/null +++ b/config/config_forecasting_synop_alphaearth.yml @@ -0,0 +1,250 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_forecast_synop_alphaearth/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "surface_fc_aef" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 32 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2017-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 2 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-01-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 600 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: aef_synop_forecast + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: alphaearth diff --git a/config/streams/era5_decoding_synop/era5.yml b/config/streams/era5_forecast_synop/era5.yml similarity index 96% rename from config/streams/era5_decoding_synop/era5.yml rename to config/streams/era5_forecast_synop/era5.yml index c04f77a24..0b3e8bbc1 100644 --- a/config/streams/era5_decoding_synop/era5.yml +++ b/config/streams/era5_forecast_synop/era5.yml @@ -14,10 +14,10 @@ ERA5 : loss_weight : 1. source_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp'] target: [] - diagnostic : False + forcing : True masking_rate : 0.6 masking_rate_none : 0.05 - token_size : 32 + token_size : 8 tokenize_spacetime : True max_num_targets: -1 embed : diff --git a/config/streams/era5_decoding_synop/synop.yml b/config/streams/era5_forecast_synop/synop.yml similarity index 89% rename from config/streams/era5_decoding_synop/synop.yml rename to config/streams/era5_forecast_synop/synop.yml index 60ab8134a..3504c2feb 100644 --- a/config/streams/era5_decoding_synop/synop.yml +++ b/config/streams/era5_forecast_synop/synop.yml @@ -7,11 +7,9 @@ SurfaceCombined : type : obs stream_id : 2 # source: [] - is_diagnostic: True + diagnostic: True filenames : ['observations-ea-ofb-0001-1979-2023-combined-surface-v2.zarr'] loss_weight : 1.0 - masking_rate : 0.6 - masking_rate_none : 0.05 token_size : 64 tokenize_spacetime : True max_num_targets: -1 diff --git a/config/streams/era5_forecast_synop_alphaearth/era5.yml b/config/streams/era5_forecast_synop_alphaearth/era5.yml new file mode 100644 index 000000000..baa413a9b --- /dev/null +++ b/config/streams/era5_forecast_synop_alphaearth/era5.yml @@ -0,0 +1,37 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + stream_id : 0 + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + loss_weight : 1. + source_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp'] + target: [] + forcing : True + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 4 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 4 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/era5_forecast_synop_alphaearth/synop.yml b/config/streams/era5_forecast_synop_alphaearth/synop.yml new file mode 100644 index 000000000..4bac047a2 --- /dev/null +++ b/config/streams/era5_forecast_synop_alphaearth/synop.yml @@ -0,0 +1,39 @@ +# obs_types +# 0 : polar orbiting satellites +# 1 : geostationay satellites +# 2 : conventional observations + +SurfaceCombined : + type : obs + stream_id : 2 + # source: [] + diagnostic: True + filenames : ['observations-ea-ofb-0001-1979-2023-combined-surface-v2.zarr'] + geoinfo_sources : + - type: alphaearth + filename: /e/data1/slmet/ml_training/aef_synops_stations_9x9_mod2.zarr + patch_mode: center + # Only attach AlphaEarth features for close coordinate matches; unmatched rows get the + # AlphaEarth mean, which normalizes to neutral zeros in the decoder geoinfo vector. + max_distance_deg: 0.01 + stats_sample_size: 2048 + missing_value: mean + loss_weight : 1.0 + token_size : 64 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 128 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 128 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/src/weathergen/datasets/data_reader_alphaearth.py b/src/weathergen/datasets/data_reader_alphaearth.py new file mode 100644 index 000000000..90ca1436d --- /dev/null +++ b/src/weathergen/datasets/data_reader_alphaearth.py @@ -0,0 +1,247 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from collections import defaultdict +from pathlib import Path + +import numpy as np +import zarr +from numpy.typing import NDArray + +_logger = logging.getLogger(__name__) + + +class DataReaderAlphaEarthGeoinfo: + """Read AlphaEarth embeddings and expose them as per-observation geoinfo features.""" + + def __init__(self, filename: Path, config: dict) -> None: + self.filename = filename + self.config = config + self.z = zarr.open(filename, mode="r") + self.data = self.z["data"] + self.dates = np.asarray(self.z["dates"][:]).astype("datetime64[ns]") + + self.patch_mode = str(config.get("patch_mode", "center")).lower() + self.max_distance_deg = float(config.get("max_distance_deg", 0.05)) + self.lookup_cell_size_deg = float( + config.get("lookup_cell_size_deg", self.max_distance_deg) + ) + self.stats_sample_size = int(config.get("stats_sample_size", 2048)) + self.missing_value = config.get("missing_value", "mean") + self.prefix = str(config.get("prefix", "alphaearth")) + + if self.max_distance_deg <= 0.0: + raise ValueError("AlphaEarth geoinfo max_distance_deg must be greater than zero") + if self.lookup_cell_size_deg <= 0.0: + raise ValueError("AlphaEarth geoinfo lookup_cell_size_deg must be greater than zero") + + if len(self.data.shape) != 5: + raise ValueError( + "AlphaEarth geoinfo data must have shape " + "(station, date, channel, y, x), got " + f"{self.data.shape}" + ) + + metadata = self.z["metadata"][:] + self.station_coords = np.column_stack( + [metadata["lat"], self._normalize_longitudes(metadata["lon"])] + ).astype(np.float32) + + self.num_channels = int(self.data.shape[2]) + self.patch_y = int(self.data.shape[3]) + self.patch_x = int(self.data.shape[4]) + center = config.get("patch_center", [self.patch_y // 2, self.patch_x // 2]) + self.center_y = int(center[0]) + self.center_x = int(center[1]) + + self.channel_names = self._build_channel_names() + self.feature_size = len(self.channel_names) + self._station_bins = self._build_station_bins() + self._coord_cache: dict[tuple[float, float], int] = {} + self.mean, self.stdev = self._compute_stats() + + _logger.info( + "Loaded AlphaEarth geoinfos from %s with %s features, patch_mode=%s, " + "max_distance_deg=%s", + filename, + self.feature_size, + self.patch_mode, + self.max_distance_deg, + ) + + @staticmethod + def _normalize_longitudes(longitudes: NDArray[np.floating]) -> NDArray[np.float32]: + return (((longitudes + 180.0) % 360.0) - 180.0).astype(np.float32) + + def _build_channel_names(self) -> list[str]: + if self.patch_mode in ("center", "mean"): + return [f"{self.prefix}_{channel_idx:03d}" for channel_idx in range(self.num_channels)] + if self.patch_mode == "flatten": + return [ + f"{self.prefix}_{channel_idx:03d}_y{y_idx}_x{x_idx}" + for channel_idx in range(self.num_channels) + for y_idx in range(self.patch_y) + for x_idx in range(self.patch_x) + ] + raise ValueError( + f"Unknown AlphaEarth geoinfo patch_mode {self.patch_mode}. " + "Expected one of: center, mean, flatten." + ) + + def _lat_bin(self, lat: float) -> int: + return int(np.floor((lat + 90.0) / self.lookup_cell_size_deg)) + + def _lon_bin(self, lon: float) -> int: + return int(np.floor((lon + 180.0) / self.lookup_cell_size_deg)) % self._num_lon_bins + + @property + def _num_lon_bins(self) -> int: + return int(np.ceil(360.0 / self.lookup_cell_size_deg)) + + @property + def _num_lat_bins(self) -> int: + return int(np.ceil(180.0 / self.lookup_cell_size_deg)) + 1 + + def _build_station_bins(self) -> dict[tuple[int, int], list[int]]: + station_bins: dict[tuple[int, int], list[int]] = defaultdict(list) + for station_idx, (lat, lon) in enumerate(self.station_coords): + if np.isnan(lat) or np.isnan(lon): + continue + station_bins[(self._lat_bin(float(lat)), self._lon_bin(float(lon)))].append( + station_idx + ) + return dict(station_bins) + + def _compute_stats(self) -> tuple[NDArray[np.float32], NDArray[np.float32]]: + if not bool(self.config.get("normalize", True)) or self.stats_sample_size <= 0: + return ( + np.zeros(self.feature_size, dtype=np.float32), + np.ones(self.feature_size, dtype=np.float32), + ) + + num_stations = int(self.data.shape[0]) + sample_size = min(self.stats_sample_size, num_stations) + station_indices = np.linspace(0, num_stations - 1, sample_size, dtype=np.int64) + sample_features = [ + self._read_features(station_indices, date_idx) + for date_idx in range(int(self.data.shape[1])) + ] + sample = np.concatenate(sample_features, axis=0) + mean = np.mean(sample, axis=0, dtype=np.float64).astype(np.float32) + stdev = np.std(sample, axis=0, dtype=np.float64).astype(np.float32) + stdev[np.isclose(stdev, 0.0)] = 1.0 + return mean, stdev + + def _read_features( + self, station_indices: NDArray[np.int64], date_idx: int + ) -> NDArray[np.float32]: + if self.patch_mode == "center": + features = self.data.oindex[ + station_indices, date_idx, slice(None), self.center_y, self.center_x + ] + elif self.patch_mode == "mean": + patch = self.data.oindex[station_indices, date_idx, :, :, :].astype(np.float32) + features = patch.mean(axis=(-1, -2)) + elif self.patch_mode == "flatten": + patch = self.data.oindex[station_indices, date_idx, :, :, :] + features = patch.reshape((len(station_indices), self.feature_size)) + else: + raise ValueError(f"Unknown AlphaEarth geoinfo patch_mode {self.patch_mode}") + + return np.asarray(features, dtype=np.float32) + + def _date_indices(self, datetimes: NDArray[np.datetime64]) -> NDArray[np.int64]: + datetimes = np.asarray(datetimes).astype("datetime64[ns]") + right = np.searchsorted(self.dates, datetimes, side="left") + right = np.clip(right, 0, len(self.dates) - 1) + left = np.clip(right - 1, 0, len(self.dates) - 1) + + left_delta = np.abs(datetimes - self.dates[left]) + right_delta = np.abs(datetimes - self.dates[right]) + return np.where(right_delta < left_delta, right, left).astype(np.int64) + + def _station_indices(self, coords: NDArray[np.float32]) -> NDArray[np.int64]: + station_indices = np.full(coords.shape[0], -1, dtype=np.int64) + search_radius = int(np.ceil(self.max_distance_deg / self.lookup_cell_size_deg)) + + coords = np.asarray(coords, dtype=np.float32) + lats = coords[:, 0] + lons = self._normalize_longitudes(coords[:, 1]) + + for coord_idx, (lat, lon) in enumerate(zip(lats, lons, strict=True)): + if np.isnan(lat) or np.isnan(lon): + continue + + cache_key = (float(lat), float(lon)) + cached_station_idx = self._coord_cache.get(cache_key) + if cached_station_idx is not None: + station_indices[coord_idx] = cached_station_idx + continue + + lat_bin = self._lat_bin(float(lat)) + lon_bin = self._lon_bin(float(lon)) + candidate_indices = [] + for lat_offset in range(-search_radius, search_radius + 1): + candidate_lat_bin = lat_bin + lat_offset + if candidate_lat_bin < 0 or candidate_lat_bin >= self._num_lat_bins: + continue + for lon_offset in range(-search_radius, search_radius + 1): + candidate_lon_bin = (lon_bin + lon_offset) % self._num_lon_bins + candidate_indices.extend( + self._station_bins.get((candidate_lat_bin, candidate_lon_bin), []) + ) + + if not candidate_indices: + self._coord_cache[cache_key] = -1 + continue + + candidates = self.station_coords[np.asarray(candidate_indices, dtype=np.int64)] + dlat = candidates[:, 0] - lat + dlon = np.abs(candidates[:, 1] - lon) + dlon = np.minimum(dlon, 360.0 - dlon) + distances = np.sqrt(dlat * dlat + dlon * dlon) + nearest = int(np.argmin(distances)) + if distances[nearest] <= self.max_distance_deg: + station_indices[coord_idx] = candidate_indices[nearest] + self._coord_cache[cache_key] = int(station_indices[coord_idx]) + + return station_indices + + def _missing_features(self, num_rows: int) -> NDArray[np.float32]: + if isinstance(self.missing_value, str): + if self.missing_value == "mean": + return np.broadcast_to(self.mean, (num_rows, self.feature_size)).copy() + if self.missing_value in ("zero", "zeros"): + return np.zeros((num_rows, self.feature_size), dtype=np.float32) + raise ValueError( + f"Unknown AlphaEarth geoinfo missing_value {self.missing_value}. " + "Expected 'mean', 'zero', or a numeric fill value." + ) + + return np.full((num_rows, self.feature_size), float(self.missing_value), dtype=np.float32) + + def get( + self, coords: NDArray[np.float32], datetimes: NDArray[np.datetime64] + ) -> NDArray[np.float32]: + features = self._missing_features(coords.shape[0]) + if coords.shape[0] == 0: + return features + + station_indices = self._station_indices(coords) + valid_station_mask = station_indices >= 0 + if not valid_station_mask.any(): + return features + + date_indices = self._date_indices(datetimes) + for date_idx in np.unique(date_indices[valid_station_mask]): + row_mask = valid_station_mask & (date_indices == date_idx) + features[row_mask] = self._read_features(station_indices[row_mask], int(date_idx)) + + return features From cd4950fd939be762376fec34e0cbdb7bc26b033c Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Mon, 18 May 2026 18:36:43 +0200 Subject: [PATCH 2/3] changes to obs datareader --- src/weathergen/datasets/data_reader_obs.py | 72 ++++++++++++++++++---- tests/test_data_reader_alphaearth.py | 60 ++++++++++++++++++ 2 files changed, 120 insertions(+), 12 deletions(-) create mode 100644 tests/test_data_reader_alphaearth.py diff --git a/src/weathergen/datasets/data_reader_obs.py b/src/weathergen/datasets/data_reader_obs.py index 896638979..91f78ae94 100644 --- a/src/weathergen/datasets/data_reader_obs.py +++ b/src/weathergen/datasets/data_reader_obs.py @@ -14,7 +14,9 @@ import numpy as np import zarr +from numpy.typing import NDArray +from weathergen.datasets.data_reader_alphaearth import DataReaderAlphaEarthGeoinfo from weathergen.datasets.data_reader_base import ( DataReaderBase, ReaderData, @@ -75,15 +77,27 @@ def __init__( # determine idx for coords and geoinfos self.coords_idx = [self.colnames.index("lat"), self.colnames.index("lon")] - self.geoinfo_idx = list(range(self.coords_idx[-1] + 1, data_idx[0])) - self.geoinfo_channels = [self.colnames[i] for i in self.geoinfo_idx] + self.native_geoinfo_idx = list(range(self.coords_idx[-1] + 1, data_idx[0])) + native_geoinfo_channels = [self.colnames[i] for i in self.native_geoinfo_idx] + self.external_geoinfo_readers = self._init_external_geoinfo_readers(stream_info) + external_geoinfo_channels = [ + channel + for reader in self.external_geoinfo_readers + for channel in reader.channel_names + ] + self.geoinfo_channels = native_geoinfo_channels + external_geoinfo_channels + self.geoinfo_idx = list(range(len(self.geoinfo_channels))) # load additional properties (mean, var) self._load_properties() self.mean = np.array(self.properties["means"]) # [data_idx] self.stdev = np.sqrt(np.array(self.properties["vars"])) # [data_idx]) - self.mean_geoinfo = np.array(self.properties["means"])[self.geoinfo_idx] - self.stdev_geoinfo = np.sqrt(np.array(self.properties["vars"])[self.geoinfo_idx]) + native_mean_geoinfo = np.array(self.properties["means"])[self.native_geoinfo_idx] + native_stdev_geoinfo = np.sqrt(np.array(self.properties["vars"])[self.native_geoinfo_idx]) + external_mean_geoinfo = [reader.mean for reader in self.external_geoinfo_readers] + external_stdev_geoinfo = [reader.stdev for reader in self.external_geoinfo_readers] + self.mean_geoinfo = np.concatenate([native_mean_geoinfo, *external_mean_geoinfo]) + self.stdev_geoinfo = np.concatenate([native_stdev_geoinfo, *external_stdev_geoinfo]) # Create index for samples self._setup_sample_index() @@ -215,6 +229,40 @@ def _load_properties(self) -> None: self.properties["means"] = self.data.attrs["means"] self.properties["vars"] = self.data.attrs["vars"] + def _init_external_geoinfo_readers( + self, stream_info: dict + ) -> list[DataReaderAlphaEarthGeoinfo]: + readers = [] + for geoinfo_source in stream_info.get("geoinfo_sources", []): + source_type = str(geoinfo_source.get("type", "")).lower() + if source_type != "alphaearth": + raise ValueError( + "Unknown geoinfo source type " + f"{source_type} for stream {stream_info.get('name', '')}" + ) + + geoinfo_filename = Path(geoinfo_source["filename"]) + if not geoinfo_filename.is_absolute(): + geoinfo_filename = self.filename.parent / geoinfo_filename + readers.append(DataReaderAlphaEarthGeoinfo(geoinfo_filename, geoinfo_source)) + + return readers + + def _append_external_geoinfos( + self, + geoinfos: NDArray, + coords: NDArray, + datetimes: NDArray, + ) -> NDArray: + if not self.external_geoinfo_readers: + return geoinfos + + external_geoinfos = [ + reader.get(coords.astype(np.float32, copy=False), datetimes) + for reader in self.external_geoinfo_readers + ] + return np.concatenate([geoinfos, *external_geoinfos], axis=1) + @override def _get(self, idx: int, channels_idx: list[int]) -> ReaderData: """ @@ -242,8 +290,8 @@ def _get(self, idx: int, channels_idx: list[int]) -> ReaderData: coords = self.data.oindex[start_row:end_row, self.coords_idx] geoinfos = ( - self.data.oindex[start_row:end_row, self.geoinfo_idx] - if len(self.geoinfo_idx) > 0 + self.data.oindex[start_row:end_row, self.native_geoinfo_idx] + if len(self.native_geoinfo_idx) > 0 else np.zeros((coords.shape[0], 0), np.float32) ) @@ -256,12 +304,12 @@ def _get(self, idx: int, channels_idx: list[int]) -> ReaderData: t_win = self.time_window_handler.window(idx) t_mask = np.logical_and(datetimes >= t_win.start, datetimes < t_win.end) - rdata = ReaderData( - coords=coords[t_mask], - geoinfos=geoinfos[t_mask], - data=data[t_mask], - datetimes=datetimes[t_mask], - ) + coords = coords[t_mask] + geoinfos = self._append_external_geoinfos(geoinfos[t_mask], coords, datetimes[t_mask]) + data = data[t_mask] + datetimes = datetimes[t_mask] + + rdata = ReaderData(coords=coords, geoinfos=geoinfos, data=data, datetimes=datetimes) dtr = self.time_window_handler.window(idx) check_reader_data(rdata, dtr) diff --git a/tests/test_data_reader_alphaearth.py b/tests/test_data_reader_alphaearth.py new file mode 100644 index 000000000..2f30fe7fb --- /dev/null +++ b/tests/test_data_reader_alphaearth.py @@ -0,0 +1,60 @@ +from pathlib import Path + +import numpy as np +import zarr + +from weathergen.datasets.data_reader_alphaearth import DataReaderAlphaEarthGeoinfo + + +def _create_alphaearth_zarr(path: Path) -> None: + root = zarr.open(path, mode="w") + dates = np.array(["2020-01-01", "2021-01-01"], dtype="datetime64[ns]") + metadata = np.array( + [ + (10.0, 20.0, 0.0, 0.0, 1.0, 1.0, 4326), + (-30.0, 170.0, 0.0, 0.0, 1.0, 1.0, 4326), + ], + dtype=[ + ("lat", "f4"), + ("lon", "f4"), + ("bbox_west", "f4"), + ("bbox_south", "f4"), + ("bbox_east", "f4"), + ("bbox_north", "f4"), + ("crs_code", "i4"), + ], + ) + data = np.zeros((2, 2, 3, 3, 3), dtype=np.int8) + data[0, 0, :, 1, 1] = [1, 2, 3] + data[0, 1, :, 1, 1] = [4, 5, 6] + data[1, 1, :, 1, 1] = [7, 8, 9] + + root.create_array("dates", data=dates) + root.create_array("metadata", data=metadata) + root.create_array("station_id", data=np.array(["station_0", "station_1"])) + root.create_array("data", data=data) + + +def test_alphaearth_geoinfo_reader_matches_station_and_date(tmp_path: Path) -> None: + alphaearth_path = tmp_path / "alphaearth.zarr" + _create_alphaearth_zarr(alphaearth_path) + + reader = DataReaderAlphaEarthGeoinfo( + alphaearth_path, + { + "patch_mode": "center", + "max_distance_deg": 0.1, + "stats_sample_size": 0, + }, + ) + + coords = np.array([[10.02, 20.01], [10.02, 20.01], [0.0, 0.0]], dtype=np.float32) + datetimes = np.array( + ["2020-02-01", "2021-02-01", "2021-02-01"], dtype="datetime64[ns]" + ) + + features = reader.get(coords, datetimes) + + np.testing.assert_array_equal(features[0], np.array([1, 2, 3], dtype=np.float32)) + np.testing.assert_array_equal(features[1], np.array([4, 5, 6], dtype=np.float32)) + np.testing.assert_array_equal(features[2], np.zeros(3, dtype=np.float32)) From b4297d1fe0979056fd38e555e88632b7a36fafcc Mon Sep 17 00:00:00 2001 From: Sebastian Hickman Date: Tue, 19 May 2026 11:50:24 +0200 Subject: [PATCH 3/3] add configs for alphaeath embeds --- ...onfig_forecasting_synop_alphaearth_005.yml | 250 ++++++++++++++++++ .../era5.yml | 37 +++ .../synop.yml | 39 +++ 3 files changed, 326 insertions(+) create mode 100644 config/config_forecasting_synop_alphaearth_005.yml create mode 100644 config/streams/era5_forecast_synop_alphaearth_005/era5.yml create mode 100644 config/streams/era5_forecast_synop_alphaearth_005/synop.yml diff --git a/config/config_forecasting_synop_alphaearth_005.yml b/config/config_forecasting_synop_alphaearth_005.yml new file mode 100644 index 000000000..1feb8b834 --- /dev/null +++ b/config/config_forecasting_synop_alphaearth_005.yml @@ -0,0 +1,250 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_forecast_synop_alphaearth_005/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "surface_fc_aef" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 32 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2017-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 2 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-01-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 600 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: aef_synop_forecast + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: alphaearth diff --git a/config/streams/era5_forecast_synop_alphaearth_005/era5.yml b/config/streams/era5_forecast_synop_alphaearth_005/era5.yml new file mode 100644 index 000000000..baa413a9b --- /dev/null +++ b/config/streams/era5_forecast_synop_alphaearth_005/era5.yml @@ -0,0 +1,37 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + stream_id : 0 + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + loss_weight : 1. + source_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp'] + target: [] + forcing : True + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 4 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 4 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/era5_forecast_synop_alphaearth_005/synop.yml b/config/streams/era5_forecast_synop_alphaearth_005/synop.yml new file mode 100644 index 000000000..28f65635e --- /dev/null +++ b/config/streams/era5_forecast_synop_alphaearth_005/synop.yml @@ -0,0 +1,39 @@ +# obs_types +# 0 : polar orbiting satellites +# 1 : geostationay satellites +# 2 : conventional observations + +SurfaceCombined : + type : obs + stream_id : 2 + # source: [] + diagnostic: True + filenames : ['observations-ea-ofb-0001-1979-2023-combined-surface-v2.zarr'] + geoinfo_sources : + - type: alphaearth + filename: /e/data1/slmet/ml_training/aef_synops_stations_9x9_mod2.zarr + patch_mode: center + # Only attach AlphaEarth features for close coordinate matches; unmatched rows get the + # AlphaEarth mean, which normalizes to neutral zeros in the decoder geoinfo vector. + max_distance_deg: 0.05 + stats_sample_size: 2048 + missing_value: mean + loss_weight : 1.0 + token_size : 64 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 128 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 128 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file