From 35adb0b821376fdd15c278458b803aeb2851620b Mon Sep 17 00:00:00 2001 From: buschow1 Date: Wed, 20 May 2026 13:24:15 +0200 Subject: [PATCH 1/3] initial casestudy implementation --- .../tropical_cyclones/TC_casestudy_main.py | 94 ++++++++++ .../tropical_cyclones/TC_config.yml | 17 ++ .../tropical_cyclones/cyclone_finder.py | 161 ++++++++++++++++++ .../tropical_cyclones/cyclone_plots.py | 58 +++++++ pyproject.toml | 4 +- 5 files changed, 333 insertions(+), 1 deletion(-) create mode 100644 packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py create mode 100644 packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml create mode 100644 packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py create mode 100644 packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py new file mode 100644 index 000000000..8996701d0 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py @@ -0,0 +1,94 @@ +from cyclone_finder import cyclone_finder, cyclone,track_error, track_cyclones, track2pandas, cyclones_in_ds, wrap_lon +import xarray as xr +import numpy as np +from functools import cached_property +from cyclone_plots import track_eval_plot, track_snapshots +import pandas as pd +import matplotlib.pyplot as plt +from pathlib import Path +from omegaconf import OmegaConf + +class TC_casestudy(): + + def __init__(self, cfg: dict): + self.cfg = cfg + self.selected_storm = cyclone( + wind=0, + pressure=0, + lon=cfg.selected_storm.lon, + lat=cfg.selected_storm.lat, + time=np.datetime64(cfg.selected_storm.time) + ) + self.finder = cyclone_finder( + sigma = cfg.tracking_params.laplace_size, + th_LoG= cfg.tracking_params.laplace_threshold, + th_pressure=cfg.tracking_params.pressure_threshold, + th_wind=cfg.tracking_params.wind_threshold, + min_distance=cfg.tracking_params.peak_separation + ) + self.outpath = Path(cfg.outpath) + + @cached_property + def datasets(self): + infiles = { + k: f"{self.cfg.inpath}{k}_{self.cfg.init_time}_{self.cfg.runid}_ERA5.nc" + for k in ("target","prediction") + } + datasets = { + k: wrap_lon(xr.open_dataset(f)).sel(latitude=slice(self.cfg.latmin,self.cfg.latmax)) + for k,f in infiles.items() + } + return datasets + + @cached_property + def cyclones(self): + times = self.datasets["target"].valid_time.values + cyclones = { + k: [ cyclones_in_ds(ds, self.finder, time=t) for t in times ] + for k,ds in self.datasets.items() + } + return cyclones + + @cached_property + def tracks(self): + tracks = { + k: track_cyclones(d, self.cfg.tracking_params.merge_distance) + for k,d in self.cyclones.items() + } + return tracks + + @cached_property + def matched_tracks(self): + times = self.datasets["target"].valid_time.values + storm_index = np.argmin(np.abs(times - self.selected_storm.time)) + matched_stroms = { + k: self.selected_storm.match(x[storm_index]) + for k,x in self.cyclones.items() + } + matched_tracks = { + k: track2pandas(d.subset(matched_stroms[k])) + for k,d in self.tracks.items() + } + return matched_tracks + + def plot(self): + self.outpath.mkdir(exist_ok=True) + # evaluation plot + evalfile = f"{self.outpath}/{self.cfg.runid}_cyclone_{self.cfg.init_time}.png" + fig, axs = track_eval_plot(self.matched_tracks) + init_time = self.datasets["target"].forecast_reference_time.values + fig.suptitle(f"forecast initialized {init_time}") + plt.savefig(evalfile) + + # example maps + snapshotfile = f"{self.outpath}/{self.cfg.runid}_cyclone_{self.cfg.init_time}_snapshots.png" + track_snapshots(self.matched_tracks, self.datasets) + plt.savefig(snapshotfile) + +def main(): + cfg = OmegaConf.load("TC_config.yml") + casestudy = TC_casestudy(cfg) + casestudy.plot() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml new file mode 100644 index 000000000..9e756b814 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml @@ -0,0 +1,17 @@ +runid : "i0xr7z48" +init_time : "2023-10-07T00" +inpath : "/p/project1/weatherai/buschow1/wegen_export/cyclones/" +outpath : "./plots/" +latmin: -30 +latmax: 30 +selected_storm : + lon : 154.7 + lat : 9.6 + time : "2023-10-07T00:00" +tracking_params: + laplace_size : 2 + laplace_threshold : 0 + pressure_threshold : 103000 + wind_threshold : 0 + peak_separation : 5 + merge_distance: 300 \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py new file mode 100644 index 000000000..94969175f --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py @@ -0,0 +1,161 @@ +import pandas as pd +import numpy as np +import xarray as xr +from typing import List +from tqdm import tqdm +from scipy.ndimage import gaussian_laplace, maximum_filter +from scipy.cluster.hierarchy import DisjointSet +from skimage.feature import peak_local_max +from dataclasses import dataclass +from sklearn.metrics.pairwise import haversine_distances + +@dataclass(order=True, frozen=True) +class cyclone: + wind: float + pressure: float + lon: float + lat: float + ID: str | None = None + time: np.datetime64 | None = None + + def dist_to(self, other: "cyclone") -> float: + R = 6371.0 + p1 = [ np.deg2rad(deg) for deg in (self.lat, self.lon) ] + p2 = [ np.deg2rad(deg) for deg in (other.lat, other.lon) ] + angle = haversine_distances( + X = np.array(p1).reshape(1,-1), + Y = np.array(p2).reshape(1,-1) + ) + return R*angle + + def match(self, cyclones, maxdist=3000) -> "cyclone": + dists = [self.dist_to(other) for other in cyclones] + if min(dists) < maxdist: + return cyclones[np.argmin(dists)] + else: + return None + +class cyclone_finder(): + def __init__(self, sigma: float = 2, th_LoG: float = 30, th_pressure: float = 101000, th_wind: float = 10, min_distance: float = 5): + ''' + Try finding cyclones with simple blob detection + plus some heuristic filter criteria + Attributes + ---------- + sigma: Gauss standard deviation. The zeros of the laplace filter + are at sqrt(2)*sigma distance from the center + th_LoG: minimum value of the filtered field + th_pressure: maxmimum pressure value + th_wind: minimum wind speed + min_distance: minimum distance between peaks in number of gridpoints + ''' + self.sigma = sigma + self.th_LoG = th_LoG + self.th_pressure = th_pressure + self.th_wind = th_wind + self.min_distance = min_distance + + def filter(self, image): + return gaussian_laplace(image, sigma=self.sigma) + + def mask(self, pressure, windmax): + pressuremask = (pressure self.th_wind + return pressuremask & windmask + + def find_cyclones(self,pressure, wind, windmaxsize=5, timestamp=None): + # apply the LoG filter to pressure + filtered = self.filter(pressure) + # find candidate maxima + candidates = peak_local_max( + filtered, + threshold_abs=self.th_LoG, + min_distance=self.min_distance + ) + # apply mask + windmax = maximum_filter(wind.values, size=windmaxsize) + mask = self.mask(pressure, windmax)[candidates[:,0],candidates[:,1]] + cyclones = candidates[mask,:] + res = [ + cyclone( + lon = pressure.longitude.values[y], + lat = pressure.latitude.values[x], + wind = windmax[x,y], + pressure= pressure.values[x,y], + time = timestamp + ) + for x,y in zip(cyclones[:,0],cyclones[:,1]) + ] + return res + +def track_cyclones(timesteps: List[List["cyclone"]], merge_distance_km: float = 300): + ''' + Takes a list of lists of cyclones, each top level entry representing one timestep, + returns a DisjointSet where each entry represents a track. + ''' + tracks = DisjointSet() + prev_step = [] + + for step in tqdm(timesteps): + # Add all storms from this timestep + for storm in step: + tracks.add(storm) + + # Build all candidate matches (prev → curr) + candidates = [] + for s_prev in prev_step: + for s_curr in step: + d = s_prev.dist_to(s_curr) + if d <= merge_distance_km: + candidates.append((d, s_prev, s_curr)) + + # Sort by distance (closest first) + candidates.sort(key=lambda x: x[0]) + + # Keep track of which storms have already been matched + used_prev = set() + used_curr = set() + + # Greedy matching: closest pairs first + for dist, s_prev, s_curr in candidates: + if s_prev not in used_prev and s_curr not in used_curr: + tracks.merge(s_prev, s_curr) + used_prev.add(s_prev) + used_curr.add(s_curr) + + prev_step = step + + return tracks + +def track2pandas(track: List["cyclone"]): + return pd.DataFrame( + [ storm.__dict__ for storm in track ] + ).set_index("time").sort_index() + +def cyclones_in_ds(ds, finder, time): + ds_t = ds.sel(valid_time=time) + msl = ds_t.msl + V = np.sqrt(ds_t.u10**2 + ds_t.v10**2) + return finder.find_cyclones(pressure=msl, wind=V, timestamp=time) + +def track_error(tr1, tr2): + R = 6371.0 + coords = [ + np.deg2rad(x.loc[:,["lat","lon"]]) + for x in tr1.align(tr2, join="inner") + ] + angle = haversine_distances( + X = coords[0].values, Y=coords[1].values + ) + distance = pd.DataFrame({"distance":R*np.diag(angle)}, + index = coords[0].index) + all_idx = tr1.index.union(tr2.index) + distance = distance.reindex(all_idx) + + return distance + + +def wrap_lon(ds): + ds["longitude"] = (ds["longitude"] +180) % 360 - 180 + ds = ds.sortby("longitude") + return ds diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py new file mode 100644 index 000000000..f8086ca7b --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py @@ -0,0 +1,58 @@ +import matplotlib.pyplot as plt +import cartopy.crs as ccrs +import xarray as xr +import numpy as np +import pandas as pd +from cyclone_finder import track_error + +def track_eval_plot(matched_tracks): + fig, axs = plt.subplots(2,2,sharex=True, figsize=(10,6)) + fig.delaxes(axs[0,0]) + axs[0,0] = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree()) + axs[0,0].coastlines() + axs[0,0].set_title("storm tracks") + axs[0,1].set_title("track error in km") + axs[1,0].set_title("core pressure in Pa") + axs[1,1].set_title("max wind speed in m/s") + track_error( *matched_tracks.values() ).plot(ax=axs[0,1]) + for lab, track in matched_tracks.items(): + track.plot(x="lon",y="lat", ax=axs[0,0], label=lab) + track.plot(y="pressure", ax=axs[1,0], label=lab) + track.plot(y="wind", ax=axs[1,1], label=lab) + return fig, axs + +def bounding_box(matched_tracks, pad=2): + all_lons = pd.concat([matched_tracks["target"]["lon"], matched_tracks["prediction"]["lon"]]) + all_lats = pd.concat([matched_tracks["target"]["lat"], matched_tracks["prediction"]["lat"]]) + lon_min = all_lons.min() - pad + lon_max = all_lons.max() + pad + lat_min = all_lats.min() - pad + lat_max = all_lats.max() + pad + bbox = (lon_min, lon_max, lat_min, lat_max) + return bbox + +def track_snapshots(matched_tracks, datasets, skip=5): + bbox = bounding_box(matched_tracks) + all_steps = matched_tracks["target"].index.union(matched_tracks["prediction"].index) + selsteps = np.arange(0,len(all_steps),skip) + plotdat = xr.concat( datasets.values(), dim=datasets.keys() ).isel(valid_time=selsteps) + plotdat = plotdat.sel(longitude=slice(bbox[0], bbox[1]), latitude=slice(bbox[2], bbox[3])) + speed = np.sqrt(plotdat.u10**2+plotdat.v10**2) + p = speed.plot( + row="concat_dim", col="valid_time", + subplot_kws=dict(projection=ccrs.PlateCarree()) + ) + for ax in p.axs.flatten(): + ax.coastlines() + ax.set_extent(bbox) + for i,s in enumerate(all_steps[selsteps]): + leadtime = plotdat.forecast_period[i].values / np.timedelta64(1, "h") + p.axs[0,i].set_title(s) + p.axs[1,i].set_title(f"{leadtime}h forecast") + for j,tr in enumerate(matched_tracks.values()): + if s in tr.index: + tr.loc[[s]].plot.scatter( + x="lon",y="lat", ax=p.axs[j,i], + color="tab:red", marker="x",s=100 + ) + return p \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 00103cb8c..499829189 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,9 @@ dependencies = [ "anemoi-datasets", "weathergen-common", "weathergen-evaluate", - "weathergen-readers-extra" + "weathergen-readers-extra", + "scikit-image>=0.26.0", + "scikit-learn>=1.8.0", ] From 0c21d58934c2068f720aa3c13e3ab92f6a889c92 Mon Sep 17 00:00:00 2001 From: buschow1 Date: Wed, 20 May 2026 13:39:11 +0200 Subject: [PATCH 2/3] initial linting --- .../tropical_cyclones/TC_casestudy_main.py | 81 ++++++------ .../tropical_cyclones/cyclone_finder.py | 125 +++++++++--------- .../tropical_cyclones/cyclone_plots.py | 53 ++++---- 3 files changed, 133 insertions(+), 126 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py index 8996701d0..c86c5efcb 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py @@ -1,59 +1,66 @@ -from cyclone_finder import cyclone_finder, cyclone,track_error, track_cyclones, track2pandas, cyclones_in_ds, wrap_lon -import xarray as xr -import numpy as np from functools import cached_property -from cyclone_plots import track_eval_plot, track_snapshots -import pandas as pd -import matplotlib.pyplot as plt from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from cyclone_finder import ( + Cyclone, + CycloneFinder, + cyclones_in_ds, + track2pandas, + track_cyclones, + wrap_lon, +) +from cyclone_plots import track_eval_plot, track_snapshots from omegaconf import OmegaConf -class TC_casestudy(): +class TcCaseStudy: def __init__(self, cfg: dict): self.cfg = cfg - self.selected_storm = cyclone( - wind=0, - pressure=0, - lon=cfg.selected_storm.lon, - lat=cfg.selected_storm.lat, - time=np.datetime64(cfg.selected_storm.time) + self.selected_storm = Cyclone( + wind=0, + pressure=0, + lon=cfg.selected_storm.lon, + lat=cfg.selected_storm.lat, + time=np.datetime64(cfg.selected_storm.time), ) - self.finder = cyclone_finder( - sigma = cfg.tracking_params.laplace_size, - th_LoG= cfg.tracking_params.laplace_threshold, + self.finder = CycloneFinder( + sigma=cfg.tracking_params.laplace_size, + th_laplace=cfg.tracking_params.laplace_threshold, th_pressure=cfg.tracking_params.pressure_threshold, - th_wind=cfg.tracking_params.wind_threshold, - min_distance=cfg.tracking_params.peak_separation + th_wind=cfg.tracking_params.wind_threshold, + min_distance=cfg.tracking_params.peak_separation, ) self.outpath = Path(cfg.outpath) - + @cached_property def datasets(self): - infiles = { - k: f"{self.cfg.inpath}{k}_{self.cfg.init_time}_{self.cfg.runid}_ERA5.nc" - for k in ("target","prediction") + infiles = { + k: f"{self.cfg.inpath}{k}_{self.cfg.init_time}_{self.cfg.runid}_ERA5.nc" + for k in ("target", "prediction") } - datasets = { - k: wrap_lon(xr.open_dataset(f)).sel(latitude=slice(self.cfg.latmin,self.cfg.latmax)) - for k,f in infiles.items() + datasets = { + k: wrap_lon(xr.open_dataset(f)).sel(latitude=slice(self.cfg.latmin, self.cfg.latmax)) + for k, f in infiles.items() } return datasets @cached_property def cyclones(self): times = self.datasets["target"].valid_time.values - cyclones = { - k: [ cyclones_in_ds(ds, self.finder, time=t) for t in times ] - for k,ds in self.datasets.items() + cyclones = { + k: [cyclones_in_ds(ds, self.finder, time=t) for t in times] + for k, ds in self.datasets.items() } return cyclones - + @cached_property def tracks(self): tracks = { - k: track_cyclones(d, self.cfg.tracking_params.merge_distance) - for k,d in self.cyclones.items() + k: track_cyclones(d, self.cfg.tracking_params.merge_distance) + for k, d in self.cyclones.items() } return tracks @@ -62,12 +69,10 @@ def matched_tracks(self): times = self.datasets["target"].valid_time.values storm_index = np.argmin(np.abs(times - self.selected_storm.time)) matched_stroms = { - k: self.selected_storm.match(x[storm_index]) - for k,x in self.cyclones.items() + k: self.selected_storm.match(x[storm_index]) for k, x in self.cyclones.items() } matched_tracks = { - k: track2pandas(d.subset(matched_stroms[k])) - for k,d in self.tracks.items() + k: track2pandas(d.subset(matched_stroms[k])) for k, d in self.tracks.items() } return matched_tracks @@ -85,10 +90,12 @@ def plot(self): track_snapshots(self.matched_tracks, self.datasets) plt.savefig(snapshotfile) + def main(): cfg = OmegaConf.load("TC_config.yml") - casestudy = TC_casestudy(cfg) + casestudy = TcCaseStudy(cfg) casestudy.plot() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py index 94969175f..b4223a2d8 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py @@ -1,98 +1,102 @@ -import pandas as pd +from dataclasses import dataclass + import numpy as np -import xarray as xr -from typing import List -from tqdm import tqdm -from scipy.ndimage import gaussian_laplace, maximum_filter +import pandas as pd from scipy.cluster.hierarchy import DisjointSet +from scipy.ndimage import gaussian_laplace, maximum_filter from skimage.feature import peak_local_max -from dataclasses import dataclass from sklearn.metrics.pairwise import haversine_distances +from tqdm import tqdm + @dataclass(order=True, frozen=True) -class cyclone: +class Cyclone: wind: float pressure: float lon: float lat: float ID: str | None = None time: np.datetime64 | None = None - - def dist_to(self, other: "cyclone") -> float: - R = 6371.0 - p1 = [ np.deg2rad(deg) for deg in (self.lat, self.lon) ] - p2 = [ np.deg2rad(deg) for deg in (other.lat, other.lon) ] - angle = haversine_distances( - X = np.array(p1).reshape(1,-1), - Y = np.array(p2).reshape(1,-1) - ) - return R*angle - def match(self, cyclones, maxdist=3000) -> "cyclone": + def dist_to(self, other: "Cyclone") -> float: + r_earth = 6371.0 + p1 = [np.deg2rad(deg) for deg in (self.lat, self.lon)] + p2 = [np.deg2rad(deg) for deg in (other.lat, other.lon)] + angle = haversine_distances(X=np.array(p1).reshape(1, -1), Y=np.array(p2).reshape(1, -1)) + return r_earth * angle + + def match(self, cyclones, maxdist=3000) -> "Cyclone": dists = [self.dist_to(other) for other in cyclones] if min(dists) < maxdist: return cyclones[np.argmin(dists)] else: return None -class cyclone_finder(): - def __init__(self, sigma: float = 2, th_LoG: float = 30, th_pressure: float = 101000, th_wind: float = 10, min_distance: float = 5): - ''' - Try finding cyclones with simple blob detection + +class CycloneFinder: + def __init__( + self, + sigma: float = 2, + th_laplace: float = 30, + th_pressure: float = 101000, + th_wind: float = 10, + min_distance: float = 5, + ): + """ + Try finding cyclones with simple blob detection plus some heuristic filter criteria Attributes ---------- sigma: Gauss standard deviation. The zeros of the laplace filter are at sqrt(2)*sigma distance from the center - th_LoG: minimum value of the filtered field + th_laplace: minimum value of the filtered field th_pressure: maxmimum pressure value th_wind: minimum wind speed min_distance: minimum distance between peaks in number of gridpoints - ''' + """ self.sigma = sigma - self.th_LoG = th_LoG + self.th_laplace = th_laplace self.th_pressure = th_pressure self.th_wind = th_wind self.min_distance = min_distance - + def filter(self, image): return gaussian_laplace(image, sigma=self.sigma) def mask(self, pressure, windmax): - pressuremask = (pressure self.th_wind return pressuremask & windmask - - def find_cyclones(self,pressure, wind, windmaxsize=5, timestamp=None): + + def find_cyclones(self, pressure, wind, windmaxsize=5, timestamp=None): # apply the LoG filter to pressure filtered = self.filter(pressure) # find candidate maxima candidates = peak_local_max( - filtered, - threshold_abs=self.th_LoG, - min_distance=self.min_distance + filtered, threshold_abs=self.th_laplace, min_distance=self.min_distance ) # apply mask windmax = maximum_filter(wind.values, size=windmaxsize) - mask = self.mask(pressure, windmax)[candidates[:,0],candidates[:,1]] - cyclones = candidates[mask,:] + mask = self.mask(pressure, windmax)[candidates[:, 0], candidates[:, 1]] + cyclones = candidates[mask, :] res = [ - cyclone( - lon = pressure.longitude.values[y], - lat = pressure.latitude.values[x], - wind = windmax[x,y], - pressure= pressure.values[x,y], - time = timestamp + Cyclone( + lon=pressure.longitude.values[y], + lat=pressure.latitude.values[x], + wind=windmax[x, y], + pressure=pressure.values[x, y], + time=timestamp, ) - for x,y in zip(cyclones[:,0],cyclones[:,1]) + for x, y in zip(cyclones[:, 0], cyclones[:, 1], strict=False) ] return res -def track_cyclones(timesteps: List[List["cyclone"]], merge_distance_km: float = 300): - ''' + +def track_cyclones(timesteps: list[list["Cyclone"]], merge_distance_km: float = 300): + """ Takes a list of lists of cyclones, each top level entry representing one timestep, - returns a DisjointSet where each entry represents a track. - ''' + returns a DisjointSet where each entry represents a track. + """ tracks = DisjointSet() prev_step = [] @@ -117,7 +121,7 @@ def track_cyclones(timesteps: List[List["cyclone"]], merge_distance_km: float = used_curr = set() # Greedy matching: closest pairs first - for dist, s_prev, s_curr in candidates: + for _dist, s_prev, s_curr in candidates: if s_prev not in used_prev and s_curr not in used_curr: tracks.merge(s_prev, s_curr) used_prev.add(s_prev) @@ -127,28 +131,23 @@ def track_cyclones(timesteps: List[List["cyclone"]], merge_distance_km: float = return tracks -def track2pandas(track: List["cyclone"]): - return pd.DataFrame( - [ storm.__dict__ for storm in track ] - ).set_index("time").sort_index() + +def track2pandas(track: list["Cyclone"]): + return pd.DataFrame([storm.__dict__ for storm in track]).set_index("time").sort_index() + def cyclones_in_ds(ds, finder, time): ds_t = ds.sel(valid_time=time) msl = ds_t.msl - V = np.sqrt(ds_t.u10**2 + ds_t.v10**2) - return finder.find_cyclones(pressure=msl, wind=V, timestamp=time) + v = np.sqrt(ds_t.u10**2 + ds_t.v10**2) + return finder.find_cyclones(pressure=msl, wind=v, timestamp=time) + def track_error(tr1, tr2): - R = 6371.0 - coords = [ - np.deg2rad(x.loc[:,["lat","lon"]]) - for x in tr1.align(tr2, join="inner") - ] - angle = haversine_distances( - X = coords[0].values, Y=coords[1].values - ) - distance = pd.DataFrame({"distance":R*np.diag(angle)}, - index = coords[0].index) + r_earth = 6371.0 + coords = [np.deg2rad(x.loc[:, ["lat", "lon"]]) for x in tr1.align(tr2, join="inner")] + angle = haversine_distances(X=coords[0].values, Y=coords[1].values) + distance = pd.DataFrame({"distance": r_earth * np.diag(angle)}, index=coords[0].index) all_idx = tr1.index.union(tr2.index) distance = distance.reindex(all_idx) @@ -156,6 +155,6 @@ def track_error(tr1, tr2): def wrap_lon(ds): - ds["longitude"] = (ds["longitude"] +180) % 360 - 180 + ds["longitude"] = (ds["longitude"] + 180) % 360 - 180 ds = ds.sortby("longitude") return ds diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py index f8086ca7b..5e77252fc 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py @@ -1,26 +1,28 @@ -import matplotlib.pyplot as plt import cartopy.crs as ccrs -import xarray as xr +import matplotlib.pyplot as plt import numpy as np import pandas as pd +import xarray as xr from cyclone_finder import track_error + def track_eval_plot(matched_tracks): - fig, axs = plt.subplots(2,2,sharex=True, figsize=(10,6)) - fig.delaxes(axs[0,0]) - axs[0,0] = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree()) - axs[0,0].coastlines() - axs[0,0].set_title("storm tracks") - axs[0,1].set_title("track error in km") - axs[1,0].set_title("core pressure in Pa") - axs[1,1].set_title("max wind speed in m/s") - track_error( *matched_tracks.values() ).plot(ax=axs[0,1]) + fig, axs = plt.subplots(2, 2, sharex=True, figsize=(10, 6)) + fig.delaxes(axs[0, 0]) + axs[0, 0] = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree()) + axs[0, 0].coastlines() + axs[0, 0].set_title("storm tracks") + axs[0, 1].set_title("track error in km") + axs[1, 0].set_title("core pressure in Pa") + axs[1, 1].set_title("max wind speed in m/s") + track_error(*matched_tracks.values()).plot(ax=axs[0, 1]) for lab, track in matched_tracks.items(): - track.plot(x="lon",y="lat", ax=axs[0,0], label=lab) - track.plot(y="pressure", ax=axs[1,0], label=lab) - track.plot(y="wind", ax=axs[1,1], label=lab) + track.plot(x="lon", y="lat", ax=axs[0, 0], label=lab) + track.plot(y="pressure", ax=axs[1, 0], label=lab) + track.plot(y="wind", ax=axs[1, 1], label=lab) return fig, axs + def bounding_box(matched_tracks, pad=2): all_lons = pd.concat([matched_tracks["target"]["lon"], matched_tracks["prediction"]["lon"]]) all_lats = pd.concat([matched_tracks["target"]["lat"], matched_tracks["prediction"]["lat"]]) @@ -31,28 +33,27 @@ def bounding_box(matched_tracks, pad=2): bbox = (lon_min, lon_max, lat_min, lat_max) return bbox + def track_snapshots(matched_tracks, datasets, skip=5): bbox = bounding_box(matched_tracks) all_steps = matched_tracks["target"].index.union(matched_tracks["prediction"].index) - selsteps = np.arange(0,len(all_steps),skip) - plotdat = xr.concat( datasets.values(), dim=datasets.keys() ).isel(valid_time=selsteps) + selsteps = np.arange(0, len(all_steps), skip) + plotdat = xr.concat(datasets.values(), dim=datasets.keys()).isel(valid_time=selsteps) plotdat = plotdat.sel(longitude=slice(bbox[0], bbox[1]), latitude=slice(bbox[2], bbox[3])) - speed = np.sqrt(plotdat.u10**2+plotdat.v10**2) + speed = np.sqrt(plotdat.u10**2 + plotdat.v10**2) p = speed.plot( - row="concat_dim", col="valid_time", - subplot_kws=dict(projection=ccrs.PlateCarree()) + row="concat_dim", col="valid_time", subplot_kws=dict(projection=ccrs.PlateCarree()) ) for ax in p.axs.flatten(): ax.coastlines() ax.set_extent(bbox) - for i,s in enumerate(all_steps[selsteps]): + for i, s in enumerate(all_steps[selsteps]): leadtime = plotdat.forecast_period[i].values / np.timedelta64(1, "h") - p.axs[0,i].set_title(s) - p.axs[1,i].set_title(f"{leadtime}h forecast") - for j,tr in enumerate(matched_tracks.values()): + p.axs[0, i].set_title(s) + p.axs[1, i].set_title(f"{leadtime}h forecast") + for j, tr in enumerate(matched_tracks.values()): if s in tr.index: tr.loc[[s]].plot.scatter( - x="lon",y="lat", ax=p.axs[j,i], - color="tab:red", marker="x",s=100 + x="lon", y="lat", ax=p.axs[j, i], color="tab:red", marker="x", s=100 ) - return p \ No newline at end of file + return p From 377e54e5b8548dcb20f91ac9728d1c6d19f6840c Mon Sep 17 00:00:00 2001 From: buschow1 Date: Wed, 20 May 2026 14:15:23 +0200 Subject: [PATCH 3/3] added some docstrings and comments --- .../tropical_cyclones/TC_casestudy_main.py | 6 ++++ .../tropical_cyclones/TC_config.yml | 18 +++++----- .../tropical_cyclones/cyclone_finder.py | 33 +++++++++++++------ .../tropical_cyclones/cyclone_plots.py | 16 +++++++++ 4 files changed, 54 insertions(+), 19 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py index c86c5efcb..8dad2b9bf 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py @@ -17,6 +17,12 @@ class TcCaseStudy: + """ + Read the cyclone tracker settings, data paths and the target cyclone + from config, then find the matched tracks corresponding to that cyclones + in the prediction and target. + """ + def __init__(self, cfg: dict): self.cfg = cfg self.selected_storm = Cyclone( diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml index 9e756b814..17a02da61 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml @@ -1,17 +1,17 @@ runid : "i0xr7z48" init_time : "2023-10-07T00" -inpath : "/p/project1/weatherai/buschow1/wegen_export/cyclones/" +inpath : "/p/project1/weatherai/buschow1/wegen_export/cyclones/" outpath : "./plots/" -latmin: -30 +latmin: -30 # TCs are only detected for latmin<=lat<=latmax latmax: 30 -selected_storm : +selected_storm : # the storm you want to analyze lon : 154.7 lat : 9.6 time : "2023-10-07T00:00" tracking_params: - laplace_size : 2 - laplace_threshold : 0 - pressure_threshold : 103000 - wind_threshold : 0 - peak_separation : 5 - merge_distance: 300 \ No newline at end of file + laplace_size : 2 # in units of gridboxes + laplace_threshold : 0 # should be >= 0 to fond low pressure systems + pressure_threshold : 103000 # in Pa + wind_threshold : 0 # in m/s + peak_separation : 5 # in units of gridboxes + merge_distance: 300 # in km \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py index b4223a2d8..8e7641d4a 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd +import xarray as xr from scipy.cluster.hierarchy import DisjointSet from scipy.ndimage import gaussian_laplace, maximum_filter from skimage.feature import peak_local_max @@ -25,9 +26,12 @@ def dist_to(self, other: "Cyclone") -> float: angle = haversine_distances(X=np.array(p1).reshape(1, -1), Y=np.array(p2).reshape(1, -1)) return r_earth * angle - def match(self, cyclones, maxdist=3000) -> "Cyclone": + def match(self, cyclones: list["Cyclone"], maxdist_km: float = 3000) -> "Cyclone": + """ + Select the closest from a set of other cyclones + """ dists = [self.dist_to(other) for other in cyclones] - if min(dists) < maxdist: + if min(dists) < maxdist_km: return cyclones[np.argmin(dists)] else: return None @@ -68,7 +72,7 @@ def mask(self, pressure, windmax): windmask = windmax > self.th_wind return pressuremask & windmask - def find_cyclones(self, pressure, wind, windmaxsize=5, timestamp=None): + def find_cyclones(self, pressure, wind, windmaxsize=5, timestamp=None) -> list["Cyclone"]: # apply the LoG filter to pressure filtered = self.filter(pressure) # find candidate maxima @@ -92,7 +96,7 @@ def find_cyclones(self, pressure, wind, windmaxsize=5, timestamp=None): return res -def track_cyclones(timesteps: list[list["Cyclone"]], merge_distance_km: float = 300): +def track_cyclones(timesteps: list[list["Cyclone"]], merge_distance_km: float = 300) -> DisjointSet: """ Takes a list of lists of cyclones, each top level entry representing one timestep, returns a DisjointSet where each entry represents a track. @@ -132,29 +136,38 @@ def track_cyclones(timesteps: list[list["Cyclone"]], merge_distance_km: float = return tracks -def track2pandas(track: list["Cyclone"]): +def track2pandas(track: list["Cyclone"]) -> pd.DataFrame: return pd.DataFrame([storm.__dict__ for storm in track]).set_index("time").sort_index() -def cyclones_in_ds(ds, finder, time): +def cyclones_in_ds(ds: xr.Dataset, finder: "CycloneFinder", time: np.datetime64) -> list["Cyclone"]: + """ + Find cyclones in a dataset containing at least msl, u10, v10, + at a given timestep, using a given CycloneFinder. + """ ds_t = ds.sel(valid_time=time) msl = ds_t.msl v = np.sqrt(ds_t.u10**2 + ds_t.v10**2) return finder.find_cyclones(pressure=msl, wind=v, timestamp=time) -def track_error(tr1, tr2): +def track_error(track1: pd.DataFrame, track2: pd.DataFrame) -> pd.DataFrame: + """ + Given two tracks as pd.DataFrames, compute their distance in km. + At timesteps where one track is missing, the result is NaN. + """ r_earth = 6371.0 - coords = [np.deg2rad(x.loc[:, ["lat", "lon"]]) for x in tr1.align(tr2, join="inner")] + coords = [np.deg2rad(x.loc[:, ["lat", "lon"]]) for x in track1.align(track2, join="inner")] angle = haversine_distances(X=coords[0].values, Y=coords[1].values) distance = pd.DataFrame({"distance": r_earth * np.diag(angle)}, index=coords[0].index) - all_idx = tr1.index.union(tr2.index) + all_idx = track1.index.union(track2.index) distance = distance.reindex(all_idx) return distance -def wrap_lon(ds): +def wrap_lon(ds: xr.Dataset) -> xr.Dataset: + "Convert longitude from 0...360 to -180...180" ds["longitude"] = (ds["longitude"] + 180) % 360 - 180 ds = ds.sortby("longitude") return ds diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py index 5e77252fc..caaaa8c5f 100644 --- a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py @@ -7,6 +7,13 @@ def track_eval_plot(matched_tracks): + """ + A four panel plot showing + * the target and predicted track on a map + * the track error in km + * the pressure at the cyclone core for target and prediction + * the maximum wind speed near the cyclone center + """ fig, axs = plt.subplots(2, 2, sharex=True, figsize=(10, 6)) fig.delaxes(axs[0, 0]) axs[0, 0] = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree()) @@ -24,6 +31,9 @@ def track_eval_plot(matched_tracks): def bounding_box(matched_tracks, pad=2): + """ + Compute a lon/lat box containing the matched cyclone tracks. + """ all_lons = pd.concat([matched_tracks["target"]["lon"], matched_tracks["prediction"]["lon"]]) all_lats = pd.concat([matched_tracks["target"]["lat"], matched_tracks["prediction"]["lat"]]) lon_min = all_lons.min() - pad @@ -35,6 +45,12 @@ def bounding_box(matched_tracks, pad=2): def track_snapshots(matched_tracks, datasets, skip=5): + """ + A plot with two rows showing the spatial distribution of windspeeds + in prediction and target, with crosses marking the cyclone centers found + by the tracker. The time difference between snapshots is controlled by + skip. + """ bbox = bounding_box(matched_tracks) all_steps = matched_tracks["target"].index.union(matched_tracks["prediction"].index) selsteps = np.arange(0, len(all_steps), skip)