16 v2 global embedding #49
…predict. Added use_device to train
SarahAlidoost left a comment
@meiertgrootes thanks, the implementation looks good 👍 I left some comments. Most of them are related to code style. If something is not clear, let me know.
Also, please consider running ruff, as it can fix things automatically and helps save time in reviewing:
pip install ruff
ruff check --fix your_script.py # this fixes/shows errors
ruff format --check your_script.py --diff # this shows formatting issues
#geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim)
Suggested change: delete this commented-out line.
lon_patch = self.lon_coords[j : j + pw] # (W,) -> (pW,)
#get patch geo pos embedding
#geo_pos_embedding_tensor = compute_patch_geo_pos_embedding(geo_pos_tensor, lat_patch)
Suggested change: delete this commented-out line.
geo_pos_embedding_tensor = self.patch_geo_embeddings[idx]
#get scale feature for patch
#scale_feature_tensor = compute_patch_scale_features(lat_patch, lon_patch) # -> (10,)
Suggested change: delete this commented-out line.
#self.spatial_pe = SpatialPositionalEncoding2D(
# embed_dim=embed_dim, max_H=max_H, max_W=max_W
#)
Suggested change: delete these commented-out lines.
Can you please remove "SpatialPositionalEncoding2D" from the script, since we are not using it anymore?
# east-west extent
dx = earth_radius * cos_lat_c * dlon
dx_pix = dx/max(lon_ext -1, 1)
Suggested change:
dx_pix = dx / max(lon_ext -1, eps)
def _set_geo_pos_table(self, sh_pos_table: str):
This function does not return anything, but in the __init__ method it is called as self.geo_pos_table = self._set_geo_pos_table(sh_pos_table). As a result, self.geo_pos_table will be None.
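A minimal sketch of the problem described above, together with two possible fixes. The class names and the dictionary standing in for the table are hypothetical placeholders, not the PR's code:

```python
class Broken:
    def __init__(self, sh_pos_table: str):
        # The setter has no return statement, so this assignment stores None.
        self.geo_pos_table = self._set_geo_pos_table(sh_pos_table)

    def _set_geo_pos_table(self, sh_pos_table: str):
        table = {"source": sh_pos_table}  # placeholder for loading the real table
        # nothing is returned and nothing is stored on self


class FixedByReturning:
    def __init__(self, sh_pos_table: str):
        self.geo_pos_table = self._load_geo_pos_table(sh_pos_table)

    def _load_geo_pos_table(self, sh_pos_table: str) -> dict:
        # Option 1: return the table and let __init__ assign it.
        return {"source": sh_pos_table}


class FixedBySetting:
    def __init__(self, sh_pos_table: str):
        # Option 2: keep the setter, but do not assign its (None) result.
        self._set_geo_pos_table(sh_pos_table)

    def _set_geo_pos_table(self, sh_pos_table: str) -> None:
        self.geo_pos_table = {"source": sh_pos_table}
```

Either option works; the second keeps the `_set_` prefix consistent with what the method actually does.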
self.sh_order_L = sh_order_L
# Check that the input data has the expected dimensions
Suggested change:
self.sh_order_L = sh_order_L
# Check that the input data has the expected dimensions
"land_mask_patch": land_tensor, # (pH,pW) True=Land
"daily_timef_patch": daily_timef_tensor, #(M, T=31, 2)
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
#"sh_geo_pos_patch": geo_pos_tensor, # (pH, pW, sh_embed_dim)
Suggested change: delete this commented-out line.
geo_emb = geo_emb[:, None, None, :] # (B,1,1,E)
x = agg_latent + geo_emb # (B, M, Hp*Wp, E)
Suggested change:
# Broadcasting: same geo embedding for all M months and all Hp*Wp locations
# we use weighted mean patch embedding, see `geo_embedding_utils.py`
geo_emb = geo_emb[:, None, None, :] # (B,1,1,E)
x = agg_latent + geo_emb # (B, M, Hp*Wp, E)
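To illustrate the broadcasting in this snippet, here is a small sketch with NumPy standing in for PyTorch (the `[:, None, None, :]` indexing and the broadcasting rules behave the same way in both). The tensor sizes are arbitrary illustrative values:

```python
import numpy as np

B, M, HpWp, E = 2, 3, 4, 5  # illustrative sizes only
agg_latent = np.zeros((B, M, HpWp, E))
geo_emb = np.arange(B * E, dtype=float).reshape(B, E)  # one vector per patch

geo_emb_b = geo_emb[:, None, None, :]  # (B, 1, 1, E)
x = agg_latent + geo_emb_b             # broadcasts to (B, M, Hp*Wp, E)

# Every month and every location of a patch receives the same geo embedding.
print(x.shape)
print(np.array_equal(x[1, 0, 0], x[1, 2, 3]))
```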
rogerkuou left a comment
Hi @meiertgrootes, besides Sarah's comments, I only have two very minor comments from my side, so I will already approve this PR. Please go ahead and merge after addressing Sarah's comments.
batch["land_mask_patch"].to(device, non_blocking=use_cuda),
batch["geo_pos_embedding_patch"].to(device, non_blocking=use_cuda),
batch["scale_feature_patch"].to(device, non_blocking=use_cuda),
batch["padded_days_mask"].to(device, non_blocking=use_cuda) ,
Suggested change:
batch["padded_days_mask"].to(device, non_blocking=use_cuda),
@@ -35,7 +35,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"data_folder = Path(\"./eso4clima\")\n",
+"data_folder = Path(\"/Users/mwgrootes/Projects/REPOS/ESO4CLIMA/data/output\") #(\"./eso4clima\")\n",
For the generality of the notebook, maybe let's keep the original relative path ./eso4clima?
This pull request replaces #48, which was corrupted by a misconfigured local GitHub setup.
This pull request adds fully sphere-aware geo position and scale encoding for patches.
Geo position is encoded using real-valued spherical harmonics as a basis.
To create the embeddings, real-valued spherical harmonics are calculated at the lat/lon positions of the pixels of the input data (i.e. at native resolution) up to a user-defined order L. This results in an (L+1)^2-dimensional embedding vector. L ~ 10 should be fine.
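The (L+1)^2 count follows from having 2l+1 harmonics for each degree l = 0..L. A quick sketch that enumerates the (l, m) index pairs (the helper name is hypothetical):

```python
def sh_index_pairs(L: int) -> list[tuple[int, int]]:
    """All (degree l, order m) pairs with 0 <= l <= L and -l <= m <= l."""
    return [(l, m) for l in range(L + 1) for m in range(-l, l + 1)]


pairs = sh_index_pairs(10)
print(len(pairs))  # 121, i.e. (10 + 1) ** 2
```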
Subsequently, a sphere-aware, area-weighted PCA is performed on the native-resolution SH embedding grid, with the requirement that the target dimension of the PCA (sh_embed_dim) be smaller than (L+1)^2. The ranked PCA components up to sh_embed_dim are retained as basis functions. The SH embedding vectors are then reprojected onto this basis and scaled to zero mean and unit variance, with tanh-based soft-clipping at ~3 sigma to suppress pathological outliers in the high-order harmonics.
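A rough NumPy sketch of what such an area-weighted PCA could look like. It makes several assumptions that are not confirmed by the PR: pixel area weights proportional to cos(lat) on a regular lat/lon grid, weighted statistics for the standardization, and a soft clip of the form `c * tanh(z / c)` with `c = 3`. The function name is hypothetical.

```python
import numpy as np


def area_weighted_pca(sh, lat_deg, n_components, clip_sigma=3.0):
    """sh: (N, D) SH embeddings per pixel; lat_deg: (N,) pixel latitudes in degrees."""
    n, d = sh.shape
    if not n_components < d:
        raise ValueError("n_components must be smaller than the SH embedding dimension")

    # Assumption: pixel area is proportional to cos(latitude).
    w = np.cos(np.deg2rad(lat_deg))
    w = w / w.sum()

    mean = w @ sh                          # (D,) area-weighted mean
    xc = sh - mean
    cov = (xc * w[:, None]).T @ xc         # (D, D) area-weighted covariance

    evals, evecs = np.linalg.eigh(cov)     # eigenvalues in ascending order
    order = np.argsort(evals)[::-1][:n_components]
    basis = evecs[:, order]                # (D, k) leading components

    z = xc @ basis                         # (N, k) reprojected embeddings
    mu = w @ z
    sigma = np.sqrt(w @ (z - mu) ** 2)
    z = (z - mu) / np.maximum(sigma, 1e-12)

    # Soft clip: close to identity for small |z|, saturating at +/- clip_sigma.
    return clip_sigma * np.tanh(z / clip_sigma), basis
```

The soft clip leaves typical values almost unchanged while bounding outliers, at the cost of a slight shift away from exactly unit variance.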
For each patch, a patch position embedding is then constructed as the area-weighted mean of the token/pixel embeddings in the patch.
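As a sketch of the area-weighted mean, assuming cos(lat) weights per patch row on a regular lat/lon grid (the function name and the weighting are illustrative assumptions, not necessarily what `geo_embedding_utils.py` does):

```python
import numpy as np


def patch_geo_embedding(pixel_emb, lat_deg):
    """pixel_emb: (pH, pW, D) pixel embeddings; lat_deg: (pH,) latitude of each row."""
    w = np.cos(np.deg2rad(lat_deg))[:, None, None]   # (pH, 1, 1) row weights
    total_weight = w.sum() * pixel_emb.shape[1]      # every row weight applies to pW pixels
    return (pixel_emb * w).sum(axis=(0, 1)) / total_weight  # (D,)


# Two rows at 0 and 60 degrees latitude: weights are 1 and 0.5.
emb = np.stack([np.zeros((3, 2)), np.full((3, 2), 3.0)])  # (pH=2, pW=3, D=2)
print(patch_geo_embedding(emb, np.array([0.0, 60.0])))
```

In this example the row closer to the equator counts twice as much, so the result is pulled towards its value.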
In addition, a scale embedding is constructed for each patch, consisting of: the physical extent of the patch in the lat and lon directions, the patch area, the anisotropy of the patch extent, the pixel scale in physical units [m] in lat/lon, the pixel anisotropy and isotropized linear scale, and finally the effective harmonic order cutoff in lat/lon.
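The ten features listed above could be computed along the following lines. This is only a sketch: the exact definitions (ratio-based anisotropy, geometric-mean isotropized scale, cutoff estimated as pi * R / pixel scale, no handling of longitude wrap-around, no log scaling or normalization) are assumptions, and `scale_features_sketch` is a hypothetical name rather than the PR's `compute_patch_scale_features`.

```python
import math

EARTH_RADIUS_M = 6_371_000.0  # mean Earth radius in metres (assumed constant)
EPS = 1e-12                   # guard against division by zero


def scale_features_sketch(lat_patch, lon_patch):
    """lat_patch, lon_patch: 1-D sequences of pixel-centre coordinates in degrees."""
    lat = [math.radians(v) for v in lat_patch]
    lon = [math.radians(v) for v in lon_patch]
    dlat = max(lat) - min(lat)
    dlon = max(lon) - min(lon)
    lat_c = 0.5 * (max(lat) + min(lat))

    # Physical extent of the patch [m], its area [m^2] and its anisotropy.
    dy = EARTH_RADIUS_M * dlat
    dx = EARTH_RADIUS_M * math.cos(lat_c) * dlon
    area = EARTH_RADIUS_M**2 * dlon * (math.sin(max(lat)) - math.sin(min(lat)))
    aniso = dx / max(dy, EPS)

    # Pixel scale [m], pixel anisotropy and isotropized linear scale.
    dy_pix = dy / max(len(lat) - 1, 1)
    dx_pix = dx / max(len(lon) - 1, 1)
    pix_aniso = dx_pix / max(dy_pix, EPS)
    pix_iso = math.sqrt(dx_pix * dy_pix)

    # Effective harmonic order cutoff: degree l has wavelength ~ 2*pi*R / l,
    # and two pixels per wavelength gives l ~ pi * R / pixel scale.
    l_lat = math.pi * EARTH_RADIUS_M / max(dy_pix, EPS)
    l_lon = math.pi * EARTH_RADIUS_M / max(dx_pix, EPS)

    return [dy, dx, area, aniso, dy_pix, dx_pix, pix_aniso, pix_iso, l_lat, l_lon]


features = scale_features_sketch([0.0, 1.0], [0.0, 1.0])
```

For a 1 degree pixel spacing this yields a latitudinal cutoff of about 180, matching the usual rule of thumb relating grid spacing to harmonic degree.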
Both embeddings are precalculated once for all patches.
The 10-dimensional scale embedding is concatenated with the geo position embedding for each patch. During training, a trainable MLP is used to project the concatenated embeddings into the desired embedding dimension for additive incorporation.
Note: it makes sense to choose the hidden dimension of the projection larger than both the dimension of the concatenated embeddings and the target embedding dimension.
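The projection step can be sketched as a two-layer MLP forward pass. NumPy is used here only to keep the example self-contained; in the model the weights would be trainable parameters. The layer count, the ReLU activation, the initialization, and all sizes are illustrative assumptions.

```python
import numpy as np


def init_mlp(in_dim, hidden_dim, out_dim, seed=0):
    """Random weights for a two-layer MLP (stand-in for trainable parameters)."""
    rng = np.random.default_rng(seed)
    return {
        "w1": rng.normal(0.0, in_dim ** -0.5, (in_dim, hidden_dim)),
        "b1": np.zeros(hidden_dim),
        "w2": rng.normal(0.0, hidden_dim ** -0.5, (hidden_dim, out_dim)),
        "b2": np.zeros(out_dim),
    }


def project(params, geo_pos_emb, scale_feat):
    """geo_pos_emb: (B, sh_embed_dim); scale_feat: (B, 10) -> (B, E)."""
    x = np.concatenate([geo_pos_emb, scale_feat], axis=-1)  # (B, sh_embed_dim + 10)
    h = np.maximum(x @ params["w1"] + params["b1"], 0.0)    # ReLU hidden layer
    return h @ params["w2"] + params["b2"]                  # (B, E)


sh_embed_dim, n_scale, embed_dim = 16, 10, 64
# Hidden dimension larger than both the concatenated input (26) and the target (64).
hidden_dim = 128
params = init_mlp(sh_embed_dim + n_scale, hidden_dim, embed_dim)
geo_emb = project(params, np.zeros((4, sh_embed_dim)), np.ones((4, n_scale)))
print(geo_emb.shape)
```

The resulting `(B, E)` array is what would then be broadcast and added to the latent tokens.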