-
Notifications
You must be signed in to change notification settings - Fork 0
16 v2 global embedding #49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cd6f099
31b2d3c
ceacbd4
dd05e43
431df27
0f92c94
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,8 @@ | |||||
|
|
||||||
| import numpy as np | ||||||
| from .utils import add_month_day_dims, calc_stats | ||||||
| from .geo_embedding_utils import calculate_SH_geo_pos_embeddings, compute_patch_geo_pos_embedding | ||||||
| from .geo_embedding_utils import compute_patch_scale_features | ||||||
| import xarray as xr | ||||||
| import torch | ||||||
| from torch.utils.data import Dataset | ||||||
|
|
@@ -20,13 +22,21 @@ def __init__( | |||||
| spatial_dims: Tuple[str, str] = ("lat", "lon"), | ||||||
| patch_size: Tuple[int, int] = (16, 16), # (lat, lon) | ||||||
| stride: Tuple[int, int] = None, | ||||||
| sh_pos_table: str = None, | ||||||
| sh_embed_dim: int = 96, # sh_embed_dim should <= (sh_order_L + 1)**2 | ||||||
| sh_order_L: int = 10, | ||||||
| ): | ||||||
| self.spatial_dims = spatial_dims | ||||||
| self.patch_size = patch_size | ||||||
| self.daily_da = daily_da | ||||||
| self.monthly_da = monthly_da | ||||||
| self.stride = stride if stride is not None else patch_size | ||||||
|
|
||||||
| self.sh_embed_dim = sh_embed_dim | ||||||
| self.sh_order_L = sh_order_L | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| # Check that the input data has the expected dimensions | ||||||
| if time_dim not in daily_da.dims or time_dim not in monthly_da.dims: | ||||||
| raise ValueError(f"Time dimension '{time_dim}' not found in input data") | ||||||
|
|
@@ -84,6 +94,21 @@ def __init__( | |||||
| H, W = self.daily_np.shape[2], self.daily_np.shape[3] | ||||||
| self.patch_indices = self._compute_patch_indices(H, W) | ||||||
|
|
||||||
| # Precompute geoposition and scale embeddings for patches | ||||||
| self.geo_pos_table = self._set_geo_pos_table(sh_pos_table) | ||||||
| self.patch_geo_embeddings, self.patch_scale_features = self._compute_geoscalepatch_embeddings() | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Comment on lines
+100
to
+102
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| def _set_geo_pos_table(self, sh_pos_table: str): | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this function doesnot |
||||||
| """ Calculate or retrieve spherical harmonics based geo position embeddings.""" | ||||||
| if sh_pos_table is None: | ||||||
| self.sh_geo_pos = calculate_SH_geo_pos_embeddings(self.lat_coords, | ||||||
| self.lon_coords, self.sh_order_L, self.sh_embed_dim) | ||||||
| else: | ||||||
| #load then set embed dim and sh order L from here | ||||||
| raise(RuntimeError('load method not implemented')) | ||||||
|
|
||||||
| def _compute_patch_indices(self, H: int, W: int) -> list: | ||||||
| """Generate patch start indices with coverage warning (overlap support).""" | ||||||
| ph, pw = self.patch_size | ||||||
|
|
@@ -126,6 +151,27 @@ def _compute_patch_indices(self, H: int, W: int) -> list: | |||||
| print(f"Overlap: {overlap_h} pixels (height), {overlap_w} pixels (width)") | ||||||
|
|
||||||
| return [(i, j) for i in i_starts for j in j_starts] | ||||||
|
|
||||||
def _compute_geoscalepatch_embeddings(self):
    """Precompute one geo-position embedding and one scale feature per patch.

    Walks ``self.patch_indices`` in order. For each ``(row, col)`` start it
    slices the matching window out of ``self.sh_geo_pos``,
    ``self.lat_coords`` and ``self.lon_coords``, then passes those slices to
    ``compute_patch_geo_pos_embedding`` and ``compute_patch_scale_features``.

    Returns:
        A ``(patch_geo_embeddings, patch_scale_features)`` tuple. Each is
        the ``torch.stack`` of the per-patch results, so index ``k`` along
        the first axis corresponds to ``self.patch_indices[k]``.
        (Per-patch shapes depend on the helper functions -- presumably
        ``(sh_embed_dim,)`` and ``(10,)``; TODO confirm.)
    """
    patch_h, patch_w = self.patch_size

    geo_per_patch = []
    scale_per_patch = []
    for row0, col0 in self.patch_indices:
        row1 = row0 + patch_h
        col1 = col0 + patch_w

        sh_window = self.sh_geo_pos[row0:row1, col0:col1]
        lats = self.lat_coords[row0:row1]
        lons = self.lon_coords[col0:col1]

        geo_per_patch.append(compute_patch_geo_pos_embedding(sh_window, lats))
        scale_per_patch.append(compute_patch_scale_features(lats, lons))

    return torch.stack(geo_per_patch), torch.stack(scale_per_patch)
|
|
||||||
| def __len__(self): | ||||||
| return len(self.patch_indices) | ||||||
|
|
@@ -140,49 +186,71 @@ def __getitem__(self, idx): | |||||
| ph, pw = self.patch_size | ||||||
|
|
||||||
| # Extract spatial patch via numpy slicing — faster than xarray indexing | ||||||
| daily_patch = self.daily_np[:, :, i : i + ph, j : j + pw] # (M, T, H, W) | ||||||
| monthly_patch = self.monthly_np[:, i : i + ph, j : j + pw] # (M, H, W) | ||||||
| daily_patch = self.daily_np[:, :, i : i + ph, j : j + pw] # (M, T, H, W) -> (M,T,pH, pW) | ||||||
| monthly_patch = self.monthly_np[:, i : i + ph, j : j + pw] # (M, H, W) -> (M, pH, pW) | ||||||
| daily_nan_mask = self.daily_nan_mask[ | ||||||
| :, :, i : i + ph, j : j + pw | ||||||
| ] # (M, T, H, W) | ||||||
| ] # (M, T, H, W) -> (M, T, pH, pW) | ||||||
|
|
||||||
| if self.land_mask_np is not None: | ||||||
| land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W) | ||||||
| land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W) -> (pH,pW) | ||||||
| land_tensor = torch.from_numpy(land_patch.copy()).bool() | ||||||
| else: | ||||||
| land_tensor = torch.zeros(ph, pw, dtype=torch.bool) | ||||||
|
|
||||||
| #geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim) | ||||||
|
|
||||||
|
|
||||||
|
Comment on lines
+201
to
+203
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| # Convert to tensors (from_numpy is zero-copy on contiguous arrays) | ||||||
| # (1, M, T, H, W) | ||||||
| # (1, M, T, pH, pW) | ||||||
| daily_tensor = torch.from_numpy(daily_patch).float().unsqueeze(0) | ||||||
| # (M, H, W) | ||||||
| # (M, pH, pW) | ||||||
| monthly_tensor = torch.from_numpy(monthly_patch).float() | ||||||
| # (1, M, T, H, W) | ||||||
| # (1, M, T, pH, pW) | ||||||
| daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0) | ||||||
| # ( M, T, 2) | ||||||
| daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float() | ||||||
|
|
||||||
| # daily_mask: NaN locations that are NOT land | ||||||
| # Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W) | ||||||
| # Reshape land_tensor for broadcasting: (pH, pW) → (1, 1, 1, pH, pW) | ||||||
| daily_mask_tensor = daily_nan_mask & ( | ||||||
| ~land_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0) | ||||||
| ) | ||||||
|
|
||||||
| # Extract lat/lon coordinates for this patch | ||||||
| lat_patch = self.lat_coords[i : i + ph] | ||||||
| lon_patch = self.lon_coords[j : j + pw] | ||||||
| lat_patch = self.lat_coords[i : i + ph] # (H,) -> (pH,) | ||||||
| lon_patch = self.lon_coords[j : j + pw] # (W,) -> (pW,) | ||||||
|
|
||||||
| #get patch geo pos embedding | ||||||
| #geo_pos_embedding_tensor = compute_patch_geo_pos_embedding(geo_pos_tensor, lat_patch) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| geo_pos_embedding_tensor = self.patch_geo_embeddings[idx] | ||||||
|
|
||||||
| #get scale feature for patch | ||||||
| #scale_feature_tensor = compute_patch_scale_features(lat_patch, lon_patch) # -> (10,) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| scale_feature_tensor = self.patch_scale_features[idx] | ||||||
|
|
||||||
| #create tensors to pass sh embedding dimension, harmonic order, and scale feature dim | ||||||
| sh_embed_dim = torch.tensor(self.sh_embed_dim) | ||||||
| harmonic_order = torch.tensor(self.sh_order_L) | ||||||
| scale_f_dim = torch.tensor(len(scale_feature_tensor)) | ||||||
|
|
||||||
| # Convert to tensors | ||||||
| return { | ||||||
| "daily_patch": daily_tensor, # (C=1, M, T=31, H, W) | ||||||
| "monthly_patch": monthly_tensor, # (M, H, W) | ||||||
| "daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W) | ||||||
| "land_mask_patch": land_tensor, # (H,W) True=Land | ||||||
| "daily_patch": daily_tensor, # (C=1, M, T=31, pH, pW) | ||||||
| "monthly_patch": monthly_tensor, # (M, pH, pW) | ||||||
| "daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, pH, pW) | ||||||
| "land_mask_patch": land_tensor, # (pH,pW) True=Land | ||||||
| "daily_timef_patch": daily_timef_tensor, #(M, T=31, 2) | ||||||
| "padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded | ||||||
| #"sh_geo_pos_patch": geo_pos_tensor, # (pH, pW, sh_embed_dim) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| "scale_feature_patch": scale_feature_tensor, #(10,) | ||||||
| "geo_pos_embedding_patch": geo_pos_embedding_tensor, #(sh_embed_dim,) | ||||||
| "sh_embed_dim": sh_embed_dim, | ||||||
| "harmonic_order": harmonic_order, | ||||||
| "scale_f_dim":scale_f_dim, | ||||||
| "coords": (i, j), | ||||||
| "lat_patch": lat_patch, # (H,) | ||||||
| "lon_patch": lon_patch, # (W,) | ||||||
| "lat_patch": lat_patch, # (pH,) | ||||||
| "lon_patch": lon_patch, # (pW,) | ||||||
| } | ||||||
|
|
||||||
| def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]: | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.