Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
366ae77
bug fix
moritzhauschulz Apr 22, 2026
dc6d82d
bug fix
moritzhauschulz Apr 22, 2026
aaeb073
config change
moritzhauschulz Apr 22, 2026
bce064d
plot config
moritzhauschulz Apr 23, 2026
77f6f43
plot config
moritzhauschulz May 4, 2026
140d1ca
update stage handling in diffusion
moritzhauschulz May 4, 2026
825f841
re-implement conditioning, update adalayernorm and embedding function
moritzhauschulz May 7, 2026
dd97830
remove debugging tool
moritzhauschulz May 7, 2026
77d5e0a
implement time only / day only conditioning
moritzhauschulz May 7, 2026
11c9e6a
date_time conditioning
moritzhauschulz May 13, 2026
b6ace25
activate swiglu, xsa
moritzhauschulz May 19, 2026
4dca300
initial commit with data flow for the forecast conditioning (conditio…
moritzhauschulz May 19, 2026
d2e7b5d
offset 1
moritzhauschulz May 19, 2026
fde1230
bug fix from merge
moritzhauschulz May 19, 2026
3f89769
change ada_ln argument passing
moritzhauschulz May 20, 2026
d74029d
naive implementation of conditioning via concatenation
moritzhauschulz May 20, 2026
8a1b698
remove CLAUDE.md
moritzhauschulz May 20, 2026
1844cbf
implemented cross-attn in fe engine
moritzhauschulz May 20, 2026
372bab4
removed concatenation option
moritzhauschulz May 20, 2026
13560fc
date in config
moritzhauschulz May 20, 2026
66ff754
comment in config
moritzhauschulz May 20, 2026
7369439
minor improvements
moritzhauschulz May 20, 2026
810987a
assert offset zero
moritzhauschulz May 20, 2026
58bf2d4
roll back data flow (not working)
moritzhauschulz May 21, 2026
9b652f4
cleanup rollback
moritzhauschulz May 21, 2026
27a3b1b
inter commit
moritzhauschulz May 21, 2026
8534dd2
fixes – forecast + cross_attn should run now
moritzhauschulz May 21, 2026
fb57fb9
apply PR review comments
moritzhauschulz May 21, 2026
5ec9446
additional check for num_input_steps
moritzhauschulz May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions config/config_diffusion.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,15 @@ fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_diffusion_model: True
fe_diffusion_model_conditioning: "forecast" # options: "date_time", "date", "time", "forecast"
fe_diffusion_model_conditioning_type: "cross_attn" # options: "cross_attn", "ada_ln"
fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0
with_step_conditioning: True # False
# Diffusion related parameters
diffusion_conditioning_embed_dim: 32
frequency_embedding_dim: 256
embedding_dim: 512
sigma_min: 0.002
Expand All @@ -81,10 +84,10 @@ healpix_level: 5
# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon)
# When False: uses traditional pe_global positional encoding
rope_2D: True
# mlp_type: swiglu
# use_xsa: True
mlp_type: mlp
use_xsa: False
mlp_type: swiglu
use_xsa: True
# mlp_type: mlp
# use_xsa: False

with_mixed_precision: True
with_flash_attention: True
Expand Down Expand Up @@ -180,7 +183,7 @@ data_loading :
training_config:

# training_mode: "masking", "student_teacher", "latent_loss"
training_mode: ["masking","student_teacher"] # ["student_teacher", "physical_loss"]
training_mode: ["masking","student_teacher"]

num_mini_epochs: 128
samples_per_mini_epoch: 4096
Expand Down Expand Up @@ -236,7 +239,17 @@ training_config:
# masking strategy: "random", "healpix", "forecast"
masking_strategy: "forecast",
masking_strategy_config: {diffusion_rn: True},
num_samples: 1
num_steps_input: 2,
num_samples: 1,
}
}

target_input: {
"forecasting" : {
masking_strategy: "forecast",
masking_strategy_config: {diffusion_rn: True},
num_steps_input: 1,
num_samples: 1,
}
}

Expand Down
37 changes: 29 additions & 8 deletions config/runs_plot_train.yml
Original file line number Diff line number Diff line change
@@ -1,18 +1,39 @@

train:
plot:
r8vzykrm:
slurm_id: 387093
description: "bug-fix non-cond-branch, with sigma_data=0.63"
nol9pfdg:
slurm_id: 387094
description: "bug-fix cond-branch w/o cond, with sigma_data=0.63"
# imqzsbte:
# slurm_id: 387095
# description: "bug-fix non-cond"
f8nd1c60:
slurm_id: 387097
description: "bug-fix cond-branch w/ cond"
# ux8yjktb:
# slurm_id: 387095
# description: "bug-fix cond-branch w/ non-cond"
# xxkmgsne:
# slurm_id: 0
# description: "bug-fix cond-branch w/ non-cond, lr_max=5e-6"
# jwexz9y4:
# slurm_id: 0
# description: "bug-fix cond-branch w/ non-cond, lr_max=2.5e-6"
u7etjsm0:
slurm_id: 385058
description: "ERA5, lr_start=1e-6, lr_max=1e-5"
description: "old ERA5, lr_start=1e-6, lr_max=1e-5"
mot8sfay:
slurm_id: 385060
description: "ERA5, lr_start=1e-6, lr_max=7e-6"
zhon45xy:
slurm_id: 385064
description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=1e-5"
yimje7g3:
slurm_id: 385062
description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=7e-6"
description: "old ERA5, lr_start=1e-6, lr_max=7e-6"
# zhon45xy:
# slurm_id: 385064
# description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=1e-5"
# yimje7g3:
# slurm_id: 385062
# description: "conditioning w/ ERA5, lr_start=1e-6, lr_max=7e-6"
# bpeh160r:
# slurm_id: 381190
# description: "single samples, lr_start=1e-6, lr_max=1e-6"
Expand Down
26 changes: 25 additions & 1 deletion src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage):
self.streams = cf.streams
self.rank = cf.rank
self.world_size = cf.world_size
self.diffusion_model_conditioning = cf.fe_diffusion_model_conditioning
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cf.get("fe...", None)

self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False)

# initialise healpix
Expand Down Expand Up @@ -499,7 +500,6 @@ def _build_stream_data(
output_mask : mask for output/prediction/target
input_mask : mask for network input (can be source or target)


Returns:
StreamData with source and targets masked according to view_meta
"""
Expand Down Expand Up @@ -722,6 +722,13 @@ def _get_batch(self, idx: int, num_forecast_steps: int):
)
target_metadata = target_masks.metadata[tidx]

# Get first target step's times (using self.output_offset as the first output step index)
if self.diffusion_model_conditioning in ["date_time", "date", "time"]:
target_times_array = sdata.target_times_raw[self.output_offset]
target_metadata.add_params({'timestamp': (
target_times_array[0] if len(target_times_array) > 0 else None
)})

# also want to add the mask to the metadata
target_metadata.mask = target_mask
# Map target to all source students
Expand All @@ -735,6 +742,23 @@ def _get_batch(self, idx: int, num_forecast_steps: int):
target_in_steps = 1 if len(target_in_steps) == 0 else target_in_steps.max().item()
batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps)

#add target times in source for diffusion model date/time conditioning
if self.diffusion_model_conditioning in ["date_time", "date", "time"]:
#TODO: Might need upgrading for num_samples > 1

# Assert singular source and target samples
assert len(batch.source_samples.samples) == 1, "Only single source sample supported for diffusion model conditioning."
assert len(batch.target_samples.samples) == 1, "Only single target sample supported for diffusion model conditioning."

source_sample = batch.source_samples.samples[0]
target_sample = batch.target_samples.samples[0]

# Copy target timestamps to source metadata for all streams
for stream_name in [s["name"] for s in self.streams]:
if stream_name in target_sample.meta_info and stream_name in source_sample.meta_info:
target_timestamp = target_sample.meta_info[stream_name].params.get('timestamp')
source_sample.meta_info[stream_name].add_params({'timestamp': target_timestamp})

return batch

def __iter__(self) -> ModelBatch:
Expand Down
56 changes: 44 additions & 12 deletions src/weathergen/model/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from weathergen.model.layers import LinearNormConditioning
from weathergen.model.norms import AdaLayerNorm, RMSNorm
from weathergen.model.norms import AdaLNZero, AdaLayerNorm, RMSNorm
from weathergen.model.positional_encoding import rotary_pos_emb_2d

"""
Expand Down Expand Up @@ -248,6 +248,7 @@ def __init__(
attention_dtype=torch.bfloat16,
with_2d_rope=False,
is_dit=False,
dit_is_cond=False,
use_xsa=False,
):
super(MultiSelfAttentionHeadLocal, self).__init__()
Expand All @@ -269,10 +270,13 @@ def __init__(
norm = RMSNorm

self.is_dit = is_dit
self.dit_is_cond = dit_is_cond
if is_dit:
if dit_is_cond:
assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm"
assert dim_aux is None, "conditioning not yet implemented for DIT attention"
assert with_residual, "DIT attention should always have residual connection"
self.lnorm = norm(dim_embed, eps=norm_eps)
self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps)
self.noise_conditioning = LinearNormConditioning(
latent_space_dim=dim_embed, dtype=attention_dtype
)
Expand Down Expand Up @@ -317,8 +321,13 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None):

# Handle ada_ln_aux conditioning
if self.is_dit:
x = self.lnorm(x)
x, gate = self.noise_conditioning(x, emb)
if self.dit_is_cond:
x, cond_gate = self.lnorm(x, ada_ln_aux)
else:
x = self.lnorm(x)
cond_gate = 1
x, noise_gate = self.noise_conditioning(x, emb)
gate = cond_gate * noise_gate
else:
x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x)

Expand Down Expand Up @@ -586,6 +595,7 @@ def __init__(
attention_dtype=torch.bfloat16,
with_2d_rope=False,
is_dit=False, # should only be True for diffusion model
dit_is_cond = False, # whether the attention is used for conditioning in the diffusion model (as opposed to denoising). Should only be True for cross attention layers in the diffusion model, and will control whether ada_ln_aux is applied to the input or output of the attention layer
use_xsa=False,
):
super(MultiSelfAttentionHead, self).__init__()
Expand All @@ -607,11 +617,13 @@ def __init__(
norm = RMSNorm

self.is_dit = is_dit
self.dit_is_cond = dit_is_cond

if is_dit:
assert dim_aux is None, "conditioning not yet implemented for DIT attention"
if dit_is_cond:
assert dim_aux is not None, "For DIT, need to provide dim_aux for ada layer norm"
assert with_residual, "DIT attention should always have residual connection"
self.lnorm = norm(dim_embed, eps=norm_eps)
self.lnorm = AdaLNZero(dim_embed, dim_aux, norm_eps=norm_eps) if dim_aux is not None else norm(dim_embed, eps=norm_eps)
self.noise_conditioning = LinearNormConditioning(
latent_space_dim=dim_embed, dtype=attention_dtype
) # TODO: Do I need to pass dtype?
Expand Down Expand Up @@ -650,8 +662,13 @@ def forward(self, x, coords=None, emb=None, ada_ln_aux=None):

# Handle ada_ln_aux conditioning
if self.is_dit:
x = self.lnorm(x)
x, gate = self.noise_conditioning(x, emb)
if self.dit_is_cond:
x, cond_gate = self.lnorm(x, ada_ln_aux)
else:
x = self.lnorm(x)
cond_gate = 1
x, noise_gate = self.noise_conditioning(x, emb)
gate = cond_gate * noise_gate
else:
x = self.lnorm(x, ada_ln_aux) if ada_ln_aux is not None else self.lnorm(x)

Expand Down Expand Up @@ -699,12 +716,14 @@ def __init__(
qk_norm_type=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
is_dit=False,
):
super(MultiCrossAttentionHead, self).__init__()

self.num_heads = num_heads
self.with_residual = with_residual
self.with_flash = with_flash
self.is_dit = is_dit

if norm_type == "LayerNorm":
norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps)
Expand All @@ -714,7 +733,14 @@ def __init__(
assert dim_embed_q % num_heads == 0
self.dim_head_proj = dim_embed_q // num_heads if dim_head_proj is None else dim_head_proj

self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps)
if is_dit:
assert with_residual
self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps)
self.noise_conditioning = LinearNormConditioning(
latent_space_dim=dim_embed_q, dtype=attention_dtype
)
else:
self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps)
self.lnorm_in_kv = norm(dim_embed_kv, eps=norm_eps)

self.proj_heads_q = torch.nn.Linear(dim_embed_q, num_heads * self.dim_head_proj, bias=False)
Expand Down Expand Up @@ -744,10 +770,16 @@ def __init__(
self.softmax = torch.nn.Softmax(dim=-1)

#########################################
def forward(self, x_q, x_kv):
def forward(self, x_q, x_kv, emb=None):
if self.with_residual:
x_q_in = x_q
x_q, x_kv = self.lnorm_in_q(x_q), self.lnorm_in_kv(x_kv)

if self.is_dit:
x_q = self.lnorm_in_q(x_q)
x_q, gate = self.noise_conditioning(x_q, emb)
else:
x_q = self.lnorm_in_q(x_q)
x_kv = self.lnorm_in_kv(x_kv)

# project onto heads and q,k,v and
# ensure these are 4D tensors as required for flash attention
Expand All @@ -763,6 +795,6 @@ def forward(self, x_q, x_kv):

outs = self.dropout(self.proj_out(outs.flatten(-2, -1)))
if self.with_residual:
outs = x_q_in + outs
outs = x_q_in + outs * gate if self.is_dit else x_q_in + outs

return outs
Loading
Loading