diff --git a/config/config_diffusion_d2048_ERA5.yml b/config/config_diffusion_d2048_ERA5.yml new file mode 100644 index 000000000..e00c78b0c --- /dev/null +++ b/config/config_diffusion_d2048_ERA5.yml @@ -0,0 +1,345 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. 
+# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: None # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: None # options: "cross_attn", "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. 
Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: "" +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 
enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': 
-1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? 
+ +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: 
DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). 
NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048_date_time.yml b/config/config_diffusion_d2048_date_time.yml new file mode 100644 index 000000000..062327899 --- /dev/null +++ b/config/config_diffusion_d2048_date_time.yml @@ -0,0 +1,346 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "date_time" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" # options: "cross_attn", "ada_ln" +diffusion_conditioning_embed_dim: 32 # only used if fe_diffusion_model_conditioning_type is "ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. 
Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: "" +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 
enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': 
-1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? 
+ +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: 
DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). 
NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_diffusion_d2048.yml b/config/config_diffusion_d2048_forecast.yml similarity index 100% rename from config/config_diffusion_d2048.yml rename to config/config_diffusion_d2048_forecast.yml diff --git a/config/config_diffusion_d2048_time.yml b/config/config_diffusion_d2048_time.yml new file mode 100644 index 000000000..babf253b3 --- /dev/null +++ b/config/config_diffusion_d2048_time.yml @@ -0,0 +1,346 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diffusion_model: True +fe_diffusion_model_conditioning: "time" # options: "date_time", "time", "forecast" +fe_diffusion_model_conditioning_type: "ada_ln" # options: "cross_attn", "ada_ln" +diffusion_conditioning_embed_dim: 32 # only used if fe_diffusion_model_conditioning_type is 
"ada_ln" +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False +# Diffusion related parameters +frequency_embedding_dim: 256 +embedding_dim: 512 +sigma_min: 0.002 +sigma_max: 80 +sigma_data: 1.0 +rho: 7 +p_mean: 1.5 +p_std: 1.2 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: True +mlp_type: swiglu +use_xsa: True +# mlp_type: mlp +# use_xsa: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + + +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*pred_heads.*|.*target_token_engines.*|.*embed_target_coords.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*fe.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +freeze_modules: ".*latent_pre_norm.*|.*latent_heads.*|.*encoder.*|.*StreamEmbedder_ERA5.*|.*embed_engine.*|.*embed_engine.*|.*ae_local_engine.*|.*ae_local_global_engine.*|.*ae_global_engine.*" +# freeze_modules: "" +# load_chkpt: {'run_id': 't0bdz7qn', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.7 +# load_chkpt: {'run_id': 'dcl584vo', 'epoch': -1} # z500 d2048 hl5, sigma_data=159.08 +# 
load_chkpt: {'run_id': 'tvkicam9', 'epoch': -1} # z500 d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'q9grso75', 'epoch': -1} # z500 d2048 hl3, sigma_data=39.2936 +# load_chkpt: {'run_id': 'qxivdyqz', 'epoch': -1} # z500 d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'h8x1qgz3', 'epoch': -1} # z500 d128 hl5, sigma_data=12.93 +# load_chkpt: {'run_id': '', 'epoch': -1} # z500 d128 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'wvpb76ai', 'epoch': -1} # multi-var d2048 hl3 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'ae4wlc5m', 'epoch': -1} # multi-var d2048 hl3, sigma_data=2.7047 +# load_chkpt: {'run_id': 'r45iwyns', 'epoch': -1} # multi-var d512 hl3, sigma_data=1.1785 +# load_chkpt: {'run_id': 'ydka6uql', 'epoch': -1} # multi-var d512 hl4, sigma_data=0.827 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5, sigma_data=0.5789 +# load_chkpt: {'run_id': 'v8kd6xc1', 'epoch': -1} # multi-var d512 hl5 nopos, sigma_data=0.6481 +# load_chkpt: {'run_id': 'lwjkb3y4', 'epoch': -1} # multi-var d512 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'y1gu5md8', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, diffusion-full-pipeline +# load_chkpt: {'run_id': 'mal6u4gc', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, geoinfos 64 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'zrpncqb0', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, geoinfos 196 epochs, diffusion-full-pipeline +# load_chkpt: {'run_id': 'm6fs8wvj', 'epoch': -1} # multi-var d512 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'cgxt9imf', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=0.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'wo5mf2z4', 'epoch': -1} # diffusion model d512 to fine-tune decoder, p_mean=1.5, SwiGLU+XSA+geoinfos, based on m6fs8wvj backbone +# load_chkpt: {'run_id': 'zf6wnmpe', 'epoch': -1} # multi-var d2048 
hl5, sigma_data=1.832 +# load_chkpt: {'run_id': 'mivw6jda', 'epoch': -1} # multi-var d2048 hl5 enc-lnorm, sigma_data=1.0 +# load_chkpt: {'run_id': 'l3rxe29i', 'epoch': -1} # multi-var d2048 hl5, sigma_data=1.0, swiglu xsa geoinfos, diffusion-full-pipeline +# load_chkpt: {'run_id': 'riyz96d4', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=0.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'bokn5d2w', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=1.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'uwyv1zdh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'dvslhdp3', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=2.5, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'mtlfdgvh', 'epoch': -1} # diffusion model d2048 to fine-tune decoder, p_mean=3.0, based on l3rxe29i backbone +# load_chkpt: {'run_id': 'p6cxplfd', 'epoch': -1} # d2048 hl5, p_mean=-2.5, forecasting +# load_chkpt: {'run_id': 'c2y0srpq', 'epoch': -1} # d2048 hl5, p_mean=-2.0, forecasting +# load_chkpt: {'run_id': 'c0zk6oli', 'epoch': -1} # d2048 hl5, p_mean=-1.5, forecasting +# load_chkpt: {'run_id': 'wt2leaf4', 'epoch': -1} # d2048 hl5, p_mean=-1.2, forecasting +# load_chkpt: {'run_id': 'iq5p8ujf', 'epoch': -1} # d2048 hl5, p_mean=-0.5, forecasting +# load_chkpt: {'run_id': 'u2r8b4z6', 'epoch': -1} # d2048 hl5, p_mean=0.5, forecasting +load_chkpt: {'run_id': 'ug7huxi2', 'epoch': -1} # d2048 hl5, p_mean=1.5, forecasting +# load_chkpt: {'run_id': 'i3y5fhda', 'epoch': -1} # d2048 hl5, p_mean=2.5, forecasting +# load_chkpt: {'run_id': 'wd0u4he8', 'epoch': -1} # d2048 hl5, p_mean=3.5, forecasting + + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting_d2048/" +# streams_directory: "./config/streams/era5_1deg_forecasting_z500/" +streams: ??? 
+ +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"] + + num_mini_epochs: 128 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T18:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 #5e-5 + lr_max: 1e-5 #1e-4 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 64 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + weight: 0.0, + loss_fcts: { + "mse": {}, + }, + target_and_aux_calc: "Physical", + }, + "latent_diff": { + type: LossLatentDiffusion, + weight: 1.0, + target_and_aux_calc: 
DiffusionLatentTargetEncoder, + loss_fcts: { "mse": { }, }, + } + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + target_input: { + "forecasting" : { + masking_strategy: "forecast", + masking_strategy_config: {diffusion_rn: True}, + num_steps_input: 1, + num_samples: 1, + } + } + + forecast : + time_step: 06:00:00 + num_steps: 1 + offset: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + # Noise levels (eta values in standard normal space) at which to evaluate the + # diffusion model during validation. sigma = exp(eta * p_std + p_mean). + # Each value produces a separate validation pass with independently logged metrics. + validation_noise_levels: [1.0, 2.0, 3.0, 4.0] + + samples_per_mini_epoch: 256 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T18:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: True + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). 
NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLflow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The GitHub issue corresponding to this run (number such as 1234) + # GitHub issues are the central point when running experiments and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a GitHub issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLflow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by MLflow, so treat all extra tags as simple string key: value pairs. 
+ grid: null diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 0f75910b1..7909534d1 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -59,6 +59,11 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"fe_diffusion_model_conditioning is '{self.conditioning}' " f"(got '{self.conditioning_type}')" ) + _ada_ln = self.conditioning_type == "ada_ln" + assert self.cf.get("diffusion_conditioning_embed_dim", None) is not None or not _ada_ln, ( + f"diffusion_conditioning_embed_dim must be set when " + f"fe_diffusion_model_conditioning_type is 'ada_ln'" + ) _offset = self.cf.get("training_config", {}).get("forecast", {}).get("offset", 0) assert self.conditioning not in _date_time_modes or _offset == 0, ( f"forecast.offset must be 0 when fe_diffusion_model_conditioning is " @@ -69,6 +74,10 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is " f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" ) + assert self.conditioning not in ["date_time", "date", "time"] or _input_num_steps == 1, ( + f"forecast.input_num_steps must be 1 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" + ) assert self.conditioning != "forecast" or self.conditioning_type in {"cross_attn"}, ( f"fe_diffusion_model_conditioning_type must be 'cross_attn' when " f"fe_diffusion_model_conditioning is 'forecast' " @@ -191,9 +200,9 @@ def training_forward( ) c = None - if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: + if self.conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] - elif self.cf.get("fe_diffusion_model_conditioning", None) == "forecast": + elif self.conditioning == "forecast": c = meta_info["ERA5"].params["conditioning_tokens"] # X_{t-1} 
as conditioning (model.py extracts last step as target, passes second-to-last here) if self.training: @@ -234,7 +243,7 @@ def denoise( noise_emb = self.noise_embedder(c_noise) # Precondition input and feed through network - if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: + if self.conditioning in ["date_time", "date", "time"]: c = self.datetime_embedder(c).to(x.device) net_input = c_in * x @@ -268,11 +277,10 @@ def inference_forward( # Extract conditioning (mirrors training_forward). c = None - if self.cf.get("fe_diffusion_model_conditioning", None) in ["date_time", "date", "time"]: + if self.conditioning in ["date_time", "date", "time"]: c = meta_info["ERA5"].params["timestamp"] - elif self.cf.get("fe_diffusion_model_conditioning", None) == "forecast": - # cur_token = enc(X_t) stored in forward() before routing to inference_forward - c = self.cur_token + elif self.conditioning == "forecast": + c = meta_info["ERA5"].params["conditioning_tokens"] # Sample pure noise (assuming single batch element for now) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index dcf2c0992..6c8e847db 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -589,7 +589,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), - is_dit=self.cf.fe_diffusion_model, + is_dit=self.cf.get("fe_diffusion_model", False), dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) @@ -610,12 +610,12 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.get("rope_2D", False), - is_dit=self.cf.fe_diffusion_model, + is_dit=self.cf.get("fe_diffusion_model", False), 
dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) # Add cross-attention block (Q=noised tokens, KV=enc(X_t)) for cross_attn conditioning - if self.cf.get("fe_diffusion_model_conditioning_type") == "cross_attn": + if self.cf.get("fe_diffusion_model_conditioning_type", None) == "cross_attn": self.fe_blocks.append( MultiCrossAttentionHead( dim_embed_q=self.cf.ae_global_dim_embed, @@ -629,7 +629,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - is_dit=self.cf.fe_diffusion_model, + is_dit=self.cf.get("fe_diffusion_model", False), ) ) # Add MLP block @@ -644,7 +644,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, - is_dit=self.cf.fe_diffusion_model, + is_dit=self.cf.get("fe_diffusion_model", False), dit_is_cond=self.cf.get("fe_diffusion_model_conditioning_type", None) == "ada_ln", ) ) @@ -685,7 +685,7 @@ def forward( if forecast_residual: tokens_in = tokens - if self.cf.fe_diffusion_model: + if self.cf.get("fe_diffusion_model", False): assert noise_emb is not None, ( "noise_emb must be provided for diffusion model conditioning" )