Learning objectives
By the end of this notebook you will be able to:
Build a custom PyTorch Dataset that reads lazily from a Zarr store
Wrap it in a DataLoader and connect chunking choices to real throughput
Construct a minimal CNN and run a short training loop
Denormalise model predictions back to physical units
Visualise a prediction against ground truth
The payoff¶
Every previous notebook prepared the data. This one consumes it.
The goal is not to train a good model. The goal is to show that the Zarr store produced by earthkit’s preprocessing pipeline plugs directly into a standard PyTorch training loop with minimal glue code.
The entire journey:
GRIB (ERA5) → earthkit-data → earthkit-geo → earthkit-transforms
→ earthkit-meteo → Zarr store → xarray Dataset → PyTorch DataLoader → Training
Setup¶
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import xarray as xr
import numpy as np
import json
import os
import matplotlib.pyplot as plt
print(f"PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Prefer the multi-source store from notebook 7; fall back to era5.zarr from notebook 2
ZARR_PATH = "data/era5_multi_source.zarr" if os.path.exists("data/era5_multi_source.zarr") \
else "data/era5.zarr"
STATS_PATH = "data/norm_stats.json"
print(f"Using Zarr store: {ZARR_PATH}")Load the Zarr store¶
zarr_ds = xr.open_dataset(ZARR_PATH, engine="zarr")
print(zarr_ds)
# Inspect what we have to work with
vars_available = list(zarr_ds.data_vars)
print("Variables:", vars_available)
# Select the variable(s) to use — take all normalised fields
# With the multi-source store: 2t_norm, msl_norm, wspd_norm
# With era5.zarr fallback: 2t, msl (we normalise them below)
if "2t_norm" in vars_available:
input_vars = [v for v in vars_available if v.endswith("_norm")]
else:
# Normalise on the fly if not pre-normalised
for v in vars_available:
arr = zarr_ds[v].values.astype(np.float32)
zarr_ds[v + "_norm"] = xr.DataArray(
((arr - arr.mean()) / arr.std()),
dims=zarr_ds[v].dims,
coords=zarr_ds[v].coords,
)
input_vars = [v + "_norm" for v in vars_available]
print("Input variables for training:", input_vars)Build a temporal dataset with multiple timesteps¶
For a next-timestep prediction task we need a multi-timestep Zarr store. If era5_temporal.zarr from notebook 5 is available, use it. Otherwise we construct a synthetic multi-timestep array here.
if os.path.exists("data/era5_temporal.zarr"):
temporal_ds = xr.open_dataset("data/era5_temporal.zarr", engine="zarr")
data_var = "2t"
raw = temporal_ds[data_var].values.astype(np.float32)
arr = (raw - raw.mean()) / raw.std()
H, W = arr.shape[1], arr.shape[2]
print(f"Loaded era5_temporal.zarr: {arr.shape}")
else:
# Synthetic fallback: 180 daily timesteps, 20 x 30 spatial grid
T, H, W = 180, 20, 30
t = np.arange(T)
arr = (
np.sin(2 * np.pi * t / 365)[:, None, None]
+ np.random.normal(0, 0.3, (T, H, W))
).astype(np.float32)
print(f"Synthetic dataset: {arr.shape}")Custom PyTorch Dataset¶
The Dataset class is the glue between a Zarr store and a DataLoader. It needs two methods:
__len__ — how many samples exist
__getitem__ — return one sample as a pair of tensors (input, target)
Here one sample = the full spatial field at timestep t (input) and timestep t+1 (target) — a next-step prediction task.
class ERA5Dataset(Dataset):
    """Single-variable next-timestep prediction dataset backed by a numpy array.

    In production, replace the numpy array with lazy xarray indexing:
        sample = xr.open_dataset(zarr_path, engine="zarr")[var].isel(time=idx).values
    This reads only the necessary Zarr chunks.
    """

    def __init__(self, data: np.ndarray):
        # data shape: (T, H, W), already normalised
        self.data = torch.from_numpy(data)  # (T, H, W)

    def __len__(self):
        # We can form T-1 input/target pairs
        return len(self.data) - 1

    def __getitem__(self, idx):
        x = self.data[idx].unsqueeze(0)      # (1, H, W) — add channel dim
        y = self.data[idx + 1].unsqueeze(0)  # (1, H, W)
        return x, y
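The docstring above points to the production pattern: keep the Zarr-backed dataset open and index it lazily, so each sample touches only the chunks it needs. A minimal sketch of that variant follows; the variable name and the presence of a leading time dimension are assumptions to adapt. For the small arrays in this notebook the in-memory version below remains the simpler choice.
class LazyZarrDataset(Dataset):
    """Sketch of a lazily-indexed variant (illustrative, not used below)."""

    def __init__(self, zarr_path: str, var: str):
        # Opening with engine="zarr" is lazy: no data is read yet
        self.ds = xr.open_dataset(zarr_path, engine="zarr")
        self.var = var

    def __len__(self):
        return self.ds.sizes["time"] - 1

    def __getitem__(self, idx):
        # .isel(...).values triggers the actual chunk reads for this sample only
        x = self.ds[self.var].isel(time=idx).values.astype(np.float32)
        y = self.ds[self.var].isel(time=idx + 1).values.astype(np.float32)
        return torch.from_numpy(x).unsqueeze(0), torch.from_numpy(y).unsqueeze(0)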
dataset = ERA5Dataset(arr)
print(f"Dataset length: {len(dataset)} samples")
x0, y0 = dataset[0]
print(f"Sample shapes: x={x0.shape}, y={y0.shape}")DataLoader¶
The DataLoader batches samples and optionally prefetches them in background workers. The num_workers parameter controls parallelism — more workers help when I/O is the bottleneck (large Zarr stores on disk or object storage).
# Split: 80% train, 20% validation
n_train = int(0.8 * len(dataset))
train_ds = torch.utils.data.Subset(dataset, range(n_train))
val_ds = torch.utils.data.Subset(dataset, range(n_train, len(dataset)))
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=0)
print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")
# Confirm a batch arrives with the right shape
xb, yb = next(iter(train_loader))
print(f"Batch shapes: x={xb.shape}, y={yb.shape}")Chunking reminder
The num_workers argument launches separate worker processes that each read Zarr chunks in parallel. If your Zarr store is time-chunked (one field per chunk), every sample maps onto exactly one chunk read — no wasted I/O — and the workers can keep the model fed by prefetching batches in the background. This is exactly why the chunking choices from notebook 6 matter here.
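To connect this to actual throughput, you can time one full pass over the data for different num_workers settings. A rough sketch: the numbers depend on storage, chunk layout, and platform; worker start-up overhead dominates on tiny in-memory datasets, and spawning worker processes from a notebook does not work on every platform.
import time

def time_one_pass(num_workers: int) -> float:
    """Rough throughput probe: pull every batch once without any model work."""
    loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=num_workers)
    start = time.perf_counter()
    for _x, _y in loader:
        pass
    return time.perf_counter() - start

for nw in (0, 2):
    print(f"num_workers={nw}: {time_one_pass(nw):.3f} s per pass")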
Minimal model¶
We define a small convolutional encoder-decoder: two conv layers in, two conv layers out, with a residual skip connection. The architecture is deliberately simple — the pipeline is the point, not the model.
class SmallConvNet(nn.Module):
    """Minimal next-timestep prediction CNN."""

    def __init__(self, in_channels: int = 1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, in_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x)) + x  # residual skip
model = SmallConvNet(in_channels=1).to(device)
params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {params:,}")
# Sanity-check the forward pass
with torch.no_grad():
    test_out = model(xb.to(device))
print(f"Forward pass output shape: {test_out.shape}")
Training loop¶
N_EPOCHS = 5
lr = 1e-3
optimiser = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()
train_losses, val_losses = [], []
for epoch in range(N_EPOCHS):
    # --- Training ---
    model.train()
    running_loss = 0.0
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        optimiser.zero_grad()
        pred = model(x_batch)
        loss = criterion(pred, y_batch)
        loss.backward()
        optimiser.step()
        running_loss += loss.item()
    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)

    # --- Validation ---
    model.eval()
    running_val = 0.0
    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            pred = model(x_batch)
            running_val += criterion(pred, y_batch).item()
    val_loss = running_val / len(val_loader)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{N_EPOCHS} train_loss={train_loss:.4f} val_loss={val_loss:.4f}")
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(range(1, N_EPOCHS+1), train_losses, label="Train")
ax.plot(range(1, N_EPOCHS+1), val_losses, label="Validation")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE loss")
ax.set_title("Training curve")
ax.legend()
plt.tight_layout()
plt.show()
Denormalisation¶
The model predicts in normalised space. To interpret predictions in physical units we apply the inverse transform using the statistics stored in notebook 3.
if os.path.exists(STATS_PATH):
    with open(STATS_PATH) as f:
        norm_stats = json.load(f)
    t2m_mu = norm_stats["2t"]["mean"]
    t2m_sigma = norm_stats["2t"]["std"]
    print(f"Loaded normalisation stats: mean={t2m_mu:.2f} K, std={t2m_sigma:.2f} K")
else:
    # Fallback: use stats from the array itself
    t2m_mu = float(arr.mean())
    t2m_sigma = float(arr.std())
    print(f"No stats file found — using in-sample stats: mean={t2m_mu:.4f}, std={t2m_sigma:.4f}")

def denormalise(normalised_arr, mu, sigma):
    """Undo z-score normalisation."""
    return normalised_arr * sigma + mu
Visualise prediction vs ground truth¶
# Pick a sample from the validation set
sample_idx = 0
x_sample, y_true = val_ds[sample_idx]
model.eval()
with torch.no_grad():
    y_pred = model(x_sample.unsqueeze(0).to(device)).squeeze().cpu().numpy()
y_true_np = y_true.squeeze().numpy()
# Denormalise to physical units (K)
y_true_k = denormalise(y_true_np, t2m_mu, t2m_sigma)
y_pred_k = denormalise(y_pred, t2m_mu, t2m_sigma)
error = y_pred_k - y_true_k
vmin = min(y_true_k.min(), y_pred_k.min())
vmax = max(y_true_k.max(), y_pred_k.max())
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
im0 = axes[0].imshow(y_true_k, cmap="RdBu_r", vmin=vmin, vmax=vmax, origin="upper")
axes[0].set_title("Ground truth (K)")
plt.colorbar(im0, ax=axes[0])
im1 = axes[1].imshow(y_pred_k, cmap="RdBu_r", vmin=vmin, vmax=vmax, origin="upper")
axes[1].set_title("Model prediction (K)")
plt.colorbar(im1, ax=axes[1])
im2 = axes[2].imshow(error, cmap="bwr", origin="upper")
axes[2].set_title("Error: prediction - truth (K)")
plt.colorbar(im2, ax=axes[2])
plt.suptitle("Next-timestep prediction after 5 epochs", fontsize=12)
plt.tight_layout()
plt.show()
print(f"MAE: {np.abs(error).mean():.4f} K")
print(f"RMSE: {np.sqrt((error**2).mean()):.4f} K")Expected result: After only 5 epochs on a tiny synthetic/sample dataset the model has little skill. The prediction will largely follow the spatial mean — the error map will show the model struggling with fine structure. This is expected and intentional. The purpose is to demonstrate that the pipeline from Zarr to training loop is complete and working.
Where to go next¶
Scale up the data¶
Replace the sample data with real ERA5 via CDS (from_source("cds", ...)); see the sketch below
Use multi-year training sets; leverage the chunking strategies from notebook 6
Scale up the training¶
# Distributed Data Parallel (PyTorch native)
torch.distributed.init_process_group(backend="nccl")
model = nn.parallel.DistributedDataParallel(model)
# Lightning / Fabric (minimal boilerplate)
from lightning import Fabric
fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")
fabric.launch()
model, optimiser = fabric.setup(model, optimiser)
Use a real architecture¶
GraphCast — graph neural network operating on a multi-resolution icosahedral mesh
Pangu-Weather — transformer, input is a fixed-size lat-lon tensor
FourCastNet — vision transformer (ViT), regular lat-lon, multi-level inputs
AIFS — ECMWF’s own GNN model, trained on ERA5
Close the loop with earthkit.plots¶
At inference time, denormalise outputs and pass them back to earthkit.plots for publication-quality maps with proper metadata, coastlines, and colour scales.
Summary¶
You have closed the full loop:
NB01 — Loaded ERA5 with from_source()
NB02 — Wrote to Zarr
NB03 — Converted units, normalised, saved statistics
NB04 — Regridded to a common grid
NB05 — Computed temporal features and anomalies
NB06 — Chose chunk layout for your access pattern
NB07 — Combined multiple sources into one pipeline
NB08 — Fed the Zarr store into a PyTorch training loop and visualised predictions
earthkit provides the preprocessing layer. Everything between raw GRIB and loss.backward() is handled by the same unified pipeline regardless of where the data came from.
Activity
Modify ERA5Dataset.__getitem__ to return a 3-timestep input window data[idx:idx+3] instead of a single timestep, and adjust the model input channels accordingly (a sketch follows below).
Replace the synthetic data with data/era5_temporal.zarr from notebook 5. Does the loss converge differently?
Save the trained model with torch.save(model.state_dict(), "data/model.pt") and write a short inference cell that loads it and runs prediction on a new sample.
# Save model
torch.save(model.state_dict(), "data/model.pt")

# Load and infer
model_loaded = SmallConvNet(in_channels=1)
model_loaded.load_state_dict(torch.load("data/model.pt"))
model_loaded.eval()
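One possible shape for the first activity, as a sketch: it stacks the three timesteps along the channel dimension, but other layouts are equally valid.
class ERA5WindowDataset(Dataset):
    """Activity 1 sketch: three consecutive timesteps in, the next timestep out."""

    def __init__(self, data: np.ndarray, window: int = 3):
        self.data = torch.from_numpy(data)  # (T, H, W)
        self.window = window

    def __len__(self):
        return len(self.data) - self.window

    def __getitem__(self, idx):
        x = self.data[idx : idx + self.window]         # (window, H, W), channels first
        y = self.data[idx + self.window].unsqueeze(0)  # (1, H, W)
        return x, y

# The first conv then needs in_channels=3; note that the final conv and the
# residual skip in SmallConvNet also need adjusting so the output stays (1, H, W).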