Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from functools import cached_property
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from cyclone_finder import (
Cyclone,
CycloneFinder,
cyclones_in_ds,
track2pandas,
track_cyclones,
wrap_lon,
)
from cyclone_plots import track_eval_plot, track_snapshots
from omegaconf import OmegaConf


class TcCaseStudy:
"""
Read the cyclone tracker settings, data paths and the target cyclone
from config, then find the matched tracks corresponding to that cyclones
in the prediction and target.
"""

def __init__(self, cfg: dict):
self.cfg = cfg
self.selected_storm = Cyclone(
wind=0,
pressure=0,
lon=cfg.selected_storm.lon,
lat=cfg.selected_storm.lat,
time=np.datetime64(cfg.selected_storm.time),
)
self.finder = CycloneFinder(
sigma=cfg.tracking_params.laplace_size,
th_laplace=cfg.tracking_params.laplace_threshold,
th_pressure=cfg.tracking_params.pressure_threshold,
th_wind=cfg.tracking_params.wind_threshold,
min_distance=cfg.tracking_params.peak_separation,
)
self.outpath = Path(cfg.outpath)

@cached_property
def datasets(self):
infiles = {
k: f"{self.cfg.inpath}{k}_{self.cfg.init_time}_{self.cfg.runid}_ERA5.nc"
for k in ("target", "prediction")
}
datasets = {
k: wrap_lon(xr.open_dataset(f)).sel(latitude=slice(self.cfg.latmin, self.cfg.latmax))
for k, f in infiles.items()
}
return datasets

@cached_property
def cyclones(self):
times = self.datasets["target"].valid_time.values
cyclones = {
k: [cyclones_in_ds(ds, self.finder, time=t) for t in times]
for k, ds in self.datasets.items()
}
return cyclones

@cached_property
def tracks(self):
tracks = {
k: track_cyclones(d, self.cfg.tracking_params.merge_distance)
for k, d in self.cyclones.items()
}
return tracks

@cached_property
def matched_tracks(self):
times = self.datasets["target"].valid_time.values
storm_index = np.argmin(np.abs(times - self.selected_storm.time))
matched_stroms = {
k: self.selected_storm.match(x[storm_index]) for k, x in self.cyclones.items()
}
matched_tracks = {
k: track2pandas(d.subset(matched_stroms[k])) for k, d in self.tracks.items()
}
return matched_tracks

def plot(self):
self.outpath.mkdir(exist_ok=True)
# evaluation plot
evalfile = f"{self.outpath}/{self.cfg.runid}_cyclone_{self.cfg.init_time}.png"
fig, axs = track_eval_plot(self.matched_tracks)
init_time = self.datasets["target"].forecast_reference_time.values
fig.suptitle(f"forecast initialized {init_time}")
plt.savefig(evalfile)

# example maps
snapshotfile = f"{self.outpath}/{self.cfg.runid}_cyclone_{self.cfg.init_time}_snapshots.png"
track_snapshots(self.matched_tracks, self.datasets)
plt.savefig(snapshotfile)


def main():
cfg = OmegaConf.load("TC_config.yml")
casestudy = TcCaseStudy(cfg)
casestudy.plot()


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
runid : "i0xr7z48"
init_time : "2023-10-07T00"
inpath : "/p/project1/weatherai/buschow1/wegen_export/cyclones/"
outpath : "./plots/"
latmin: -30 # TCs are only detected for latmin<=lat<=latmax
latmax: 30
selected_storm : # the storm you want to analyze
lon : 154.7
lat : 9.6
time : "2023-10-07T00:00"
tracking_params:
laplace_size : 2 # in units of gridboxes
laplace_threshold : 0 # should be >= 0 to fond low pressure systems
pressure_threshold : 103000 # in Pa
wind_threshold : 0 # in m/s
peak_separation : 5 # in units of gridboxes
merge_distance: 300 # in km
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from dataclasses import dataclass

import numpy as np
import pandas as pd
import xarray as xr
from scipy.cluster.hierarchy import DisjointSet
from scipy.ndimage import gaussian_laplace, maximum_filter
from skimage.feature import peak_local_max
from sklearn.metrics.pairwise import haversine_distances
from tqdm import tqdm


@dataclass(order=True, frozen=True)
class Cyclone:
wind: float
pressure: float
lon: float
lat: float
ID: str | None = None
time: np.datetime64 | None = None

def dist_to(self, other: "Cyclone") -> float:
r_earth = 6371.0
p1 = [np.deg2rad(deg) for deg in (self.lat, self.lon)]
p2 = [np.deg2rad(deg) for deg in (other.lat, other.lon)]
angle = haversine_distances(X=np.array(p1).reshape(1, -1), Y=np.array(p2).reshape(1, -1))
return r_earth * angle

def match(self, cyclones: list["Cyclone"], maxdist_km: float = 3000) -> "Cyclone":
"""
Select the closest from a set of other cyclones
"""
dists = [self.dist_to(other) for other in cyclones]
if min(dists) < maxdist_km:
return cyclones[np.argmin(dists)]
else:
return None


class CycloneFinder:
def __init__(
self,
sigma: float = 2,
th_laplace: float = 30,
th_pressure: float = 101000,
th_wind: float = 10,
min_distance: float = 5,
):
"""
Try finding cyclones with simple blob detection
plus some heuristic filter criteria
Attributes
----------
sigma: Gauss standard deviation. The zeros of the laplace filter
are at sqrt(2)*sigma distance from the center
th_laplace: minimum value of the filtered field
th_pressure: maxmimum pressure value
th_wind: minimum wind speed
min_distance: minimum distance between peaks in number of gridpoints
"""
self.sigma = sigma
self.th_laplace = th_laplace
self.th_pressure = th_pressure
self.th_wind = th_wind
self.min_distance = min_distance

def filter(self, image):
return gaussian_laplace(image, sigma=self.sigma)

def mask(self, pressure, windmax):
pressuremask = (pressure < self.th_pressure).values
windmask = windmax > self.th_wind
return pressuremask & windmask

def find_cyclones(self, pressure, wind, windmaxsize=5, timestamp=None) -> list["Cyclone"]:
# apply the LoG filter to pressure
filtered = self.filter(pressure)
# find candidate maxima
candidates = peak_local_max(
filtered, threshold_abs=self.th_laplace, min_distance=self.min_distance
)
# apply mask
windmax = maximum_filter(wind.values, size=windmaxsize)
mask = self.mask(pressure, windmax)[candidates[:, 0], candidates[:, 1]]
cyclones = candidates[mask, :]
res = [
Cyclone(
lon=pressure.longitude.values[y],
lat=pressure.latitude.values[x],
wind=windmax[x, y],
pressure=pressure.values[x, y],
time=timestamp,
)
for x, y in zip(cyclones[:, 0], cyclones[:, 1], strict=False)
]
return res


def track_cyclones(timesteps: list[list["Cyclone"]], merge_distance_km: float = 300) -> DisjointSet:
"""
Takes a list of lists of cyclones, each top level entry representing one timestep,
returns a DisjointSet where each entry represents a track.
"""
tracks = DisjointSet()
prev_step = []

for step in tqdm(timesteps):
# Add all storms from this timestep
for storm in step:
tracks.add(storm)

# Build all candidate matches (prev → curr)
candidates = []
for s_prev in prev_step:
for s_curr in step:
d = s_prev.dist_to(s_curr)
if d <= merge_distance_km:
candidates.append((d, s_prev, s_curr))

# Sort by distance (closest first)
candidates.sort(key=lambda x: x[0])

# Keep track of which storms have already been matched
used_prev = set()
used_curr = set()

# Greedy matching: closest pairs first
for _dist, s_prev, s_curr in candidates:
if s_prev not in used_prev and s_curr not in used_curr:
tracks.merge(s_prev, s_curr)
used_prev.add(s_prev)
used_curr.add(s_curr)

prev_step = step

return tracks


def track2pandas(track: list["Cyclone"]) -> pd.DataFrame:
return pd.DataFrame([storm.__dict__ for storm in track]).set_index("time").sort_index()


def cyclones_in_ds(ds: xr.Dataset, finder: "CycloneFinder", time: np.datetime64) -> list["Cyclone"]:
"""
Find cyclones in a dataset containing at least msl, u10, v10,
at a given timestep, using a given CycloneFinder.
"""
ds_t = ds.sel(valid_time=time)
msl = ds_t.msl
v = np.sqrt(ds_t.u10**2 + ds_t.v10**2)
return finder.find_cyclones(pressure=msl, wind=v, timestamp=time)


def track_error(track1: pd.DataFrame, track2: pd.DataFrame) -> pd.DataFrame:
"""
Given two tracks as pd.DataFrames, compute their distance in km.
At timesteps where one track is missing, the result is NaN.
"""
r_earth = 6371.0
coords = [np.deg2rad(x.loc[:, ["lat", "lon"]]) for x in track1.align(track2, join="inner")]
angle = haversine_distances(X=coords[0].values, Y=coords[1].values)
distance = pd.DataFrame({"distance": r_earth * np.diag(angle)}, index=coords[0].index)
all_idx = track1.index.union(track2.index)
distance = distance.reindex(all_idx)

return distance


def wrap_lon(ds: xr.Dataset) -> xr.Dataset:
"Convert longitude from 0...360 to -180...180"
ds["longitude"] = (ds["longitude"] + 180) % 360 - 180
ds = ds.sortby("longitude")
return ds
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from cyclone_finder import track_error


def track_eval_plot(matched_tracks):
"""
A four panel plot showing
* the target and predicted track on a map
* the track error in km
* the pressure at the cyclone core for target and prediction
* the maximum wind speed near the cyclone center
"""
fig, axs = plt.subplots(2, 2, sharex=True, figsize=(10, 6))
fig.delaxes(axs[0, 0])
axs[0, 0] = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree())
axs[0, 0].coastlines()
axs[0, 0].set_title("storm tracks")
axs[0, 1].set_title("track error in km")
axs[1, 0].set_title("core pressure in Pa")
axs[1, 1].set_title("max wind speed in m/s")
track_error(*matched_tracks.values()).plot(ax=axs[0, 1])
for lab, track in matched_tracks.items():
track.plot(x="lon", y="lat", ax=axs[0, 0], label=lab)
track.plot(y="pressure", ax=axs[1, 0], label=lab)
track.plot(y="wind", ax=axs[1, 1], label=lab)
return fig, axs


def bounding_box(matched_tracks, pad=2):
"""
Compute a lon/lat box containing the matched cyclone tracks.
"""
all_lons = pd.concat([matched_tracks["target"]["lon"], matched_tracks["prediction"]["lon"]])
all_lats = pd.concat([matched_tracks["target"]["lat"], matched_tracks["prediction"]["lat"]])
lon_min = all_lons.min() - pad
lon_max = all_lons.max() + pad
lat_min = all_lats.min() - pad
lat_max = all_lats.max() + pad
bbox = (lon_min, lon_max, lat_min, lat_max)
return bbox


def track_snapshots(matched_tracks, datasets, skip=5):
"""
A plot with two rows showing the spatial distribution of windspeeds
in prediction and target, with crosses marking the cyclone centers found
by the tracker. The time difference between snapshots is controlled by
skip.
"""
bbox = bounding_box(matched_tracks)
all_steps = matched_tracks["target"].index.union(matched_tracks["prediction"].index)
selsteps = np.arange(0, len(all_steps), skip)
plotdat = xr.concat(datasets.values(), dim=datasets.keys()).isel(valid_time=selsteps)
plotdat = plotdat.sel(longitude=slice(bbox[0], bbox[1]), latitude=slice(bbox[2], bbox[3]))
speed = np.sqrt(plotdat.u10**2 + plotdat.v10**2)
p = speed.plot(
row="concat_dim", col="valid_time", subplot_kws=dict(projection=ccrs.PlateCarree())
)
for ax in p.axs.flatten():
ax.coastlines()
ax.set_extent(bbox)
for i, s in enumerate(all_steps[selsteps]):
leadtime = plotdat.forecast_period[i].values / np.timedelta64(1, "h")
p.axs[0, i].set_title(s)
p.axs[1, i].set_title(f"{leadtime}h forecast")
for j, tr in enumerate(matched_tracks.values()):
if s in tr.index:
tr.loc[[s]].plot.scatter(
x="lon", y="lat", ax=p.axs[j, i], color="tab:red", marker="x", s=100
)
return p
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ dependencies = [
"anemoi-datasets",
"weathergen-common",
"weathergen-evaluate",
"weathergen-readers-extra"
"weathergen-readers-extra",
"scikit-image>=0.26.0",
"scikit-learn>=1.8.0",
]


Expand Down
Loading