
8. From Zarr to PyTorch Training Loop

Learning objectives

By the end of this notebook you will be able to:

  • Build a custom PyTorch Dataset that reads lazily from a Zarr store

  • Wrap it in a DataLoader and connect chunking choices to real throughput

  • Construct a minimal CNN and run a short training loop

  • Denormalise model predictions back to physical units

  • Visualise a prediction against ground truth


The payoff

Every previous notebook prepared the data. This one consumes it.

The goal is not to train a good model. The goal is to show that the Zarr store produced by earthkit’s preprocessing pipeline plugs directly into a standard PyTorch training loop with minimal glue code.

The entire journey:

GRIB (ERA5)  →  earthkit-data  →  earthkit-geo  →  earthkit-transforms
     →  earthkit-meteo  →  Zarr store  →  xarray Dataset  →  PyTorch DataLoader  →  Training

Setup

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import xarray as xr
import numpy as np
import json
import os
import matplotlib.pyplot as plt

print(f"PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Prefer the multi-source store from notebook 7; fall back to era5.zarr from notebook 2
ZARR_PATH = "data/era5_multi_source.zarr" if os.path.exists("data/era5_multi_source.zarr") \
            else "data/era5.zarr"
STATS_PATH = "data/norm_stats.json"

print(f"Using Zarr store: {ZARR_PATH}")

Load the Zarr store

zarr_ds = xr.open_dataset(ZARR_PATH, engine="zarr")
print(zarr_ds)
# Inspect what we have to work with
vars_available = list(zarr_ds.data_vars)
print("Variables:", vars_available)

# Select the variable(s) to use — take all normalised fields
# With the multi-source store: 2t_norm, msl_norm, wspd_norm
# With era5.zarr fallback:    2t, msl (we normalise them below)
if "2t_norm" in vars_available:
    input_vars = [v for v in vars_available if v.endswith("_norm")]
else:
    # Normalise on the fly if not pre-normalised
    for v in vars_available:
        arr = zarr_ds[v].values.astype(np.float32)
        zarr_ds[v + "_norm"] = xr.DataArray(
            ((arr - arr.mean()) / arr.std()),
            dims=zarr_ds[v].dims,
            coords=zarr_ds[v].coords,
        )
    input_vars = [v + "_norm" for v in vars_available]

print("Input variables for training:", input_vars)

Build a temporal dataset with multiple timesteps

For a next-timestep prediction task we need a multi-timestep Zarr store. If era5_temporal.zarr from notebook 5 is available, use it. Otherwise we construct a synthetic multi-timestep array here.

if os.path.exists("data/era5_temporal.zarr"):
    temporal_ds = xr.open_dataset("data/era5_temporal.zarr", engine="zarr")
    data_var = "2t"
    raw = temporal_ds[data_var].values.astype(np.float32)
    arr = (raw - raw.mean()) / raw.std()
    H, W = arr.shape[1], arr.shape[2]
    print(f"Loaded era5_temporal.zarr: {arr.shape}")
else:
    # Synthetic fallback: 180 daily timesteps, 20 x 30 spatial grid
    T, H, W = 180, 20, 30
    t = np.arange(T)
    arr = (
        np.sin(2 * np.pi * t / 365)[:, None, None]
        + np.random.normal(0, 0.3, (T, H, W))
    ).astype(np.float32)
    print(f"Synthetic dataset: {arr.shape}")

Custom PyTorch Dataset

The Dataset class is the glue between a Zarr store and a DataLoader. It needs two methods:

  • __len__ — how many samples exist

  • __getitem__ — return sample i as a pair of tensors (input, target)

Here one sample = the full spatial field at timestep t (input) and timestep t+1 (target) — a next-step prediction task.

class ERA5Dataset(Dataset):
    """Single-variable next-timestep prediction dataset backed by a numpy array.
    
    In production, replace the numpy array with lazy xarray indexing:
        sample = xr.open_dataset(zarr_path, engine="zarr")[var].isel(time=idx).values
    This reads only the necessary Zarr chunks.
    """

    def __init__(self, data: np.ndarray):
        # data shape: (T, H, W), already normalised
        self.data = torch.from_numpy(data)   # (T, H, W)

    def __len__(self):
        # We can form T-1 input/target pairs
        return len(self.data) - 1

    def __getitem__(self, idx):
        x = self.data[idx].unsqueeze(0)       # (1, H, W) — add channel dim
        y = self.data[idx + 1].unsqueeze(0)   # (1, H, W)
        return x, y


dataset = ERA5Dataset(arr)
print(f"Dataset length: {len(dataset)} samples")
x0, y0 = dataset[0]
print(f"Sample shapes: x={x0.shape}, y={y0.shape}")

DataLoader

The DataLoader batches samples and optionally prefetches them in background workers. The num_workers parameter controls parallelism — more workers help when I/O is the bottleneck (large Zarr stores on disk or object storage).

# Split: 80% train, 20% validation
n_train = int(0.8 * len(dataset))
train_ds = torch.utils.data.Subset(dataset, range(n_train))
val_ds   = torch.utils.data.Subset(dataset, range(n_train, len(dataset)))

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=8, shuffle=False, num_workers=0)

print(f"Train batches: {len(train_loader)}  |  Val batches: {len(val_loader)}")

# Confirm a batch arrives with the right shape
xb, yb = next(iter(train_loader))
print(f"Batch shapes: x={xb.shape}, y={yb.shape}")

Chunking reminder

The num_workers argument launches separate worker processes that read Zarr chunks in parallel, each prefetching whole batches. If your Zarr store is time-chunked (one field per chunk), every sample maps to exactly one chunk read, so no bytes are fetched and thrown away. This is exactly why the chunking choices from notebook 6 matter here.

Minimal model

We define a small convolutional encoder-decoder: two conv layers in, two conv layers out, with a residual skip connection. The architecture is deliberately simple — the pipeline is the point, not the model.

class SmallConvNet(nn.Module):
    """Minimal next-timestep prediction CNN."""

    def __init__(self, in_channels: int = 1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, in_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x)) + x  # residual skip


model = SmallConvNet(in_channels=1).to(device)
params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {params:,}")

# Sanity-check the forward pass
with torch.no_grad():
    test_out = model(xb.to(device))
print(f"Forward pass output shape: {test_out.shape}")

Training loop

N_EPOCHS = 5
lr = 1e-3

optimiser = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()

train_losses, val_losses = [], []

for epoch in range(N_EPOCHS):
    # --- Training ---
    model.train()
    running_loss = 0.0
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimiser.zero_grad()
        pred = model(x_batch)
        loss = criterion(pred, y_batch)
        loss.backward()
        optimiser.step()

        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)

    # --- Validation ---
    model.eval()
    running_val = 0.0
    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            pred = model(x_batch)
            running_val += criterion(pred, y_batch).item()

    val_loss = running_val / len(val_loader)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{N_EPOCHS}  train_loss={train_loss:.4f}  val_loss={val_loss:.4f}")
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(range(1, N_EPOCHS+1), train_losses, label="Train")
ax.plot(range(1, N_EPOCHS+1), val_losses,   label="Validation")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE loss")
ax.set_title("Training curve")
ax.legend()
plt.tight_layout()
plt.show()

Denormalisation

The model predicts in normalised space. To interpret predictions in physical units we apply the inverse transform using the statistics stored in notebook 3.

if os.path.exists(STATS_PATH):
    with open(STATS_PATH) as f:
        norm_stats = json.load(f)
    t2m_mu    = norm_stats["2t"]["mean"]
    t2m_sigma = norm_stats["2t"]["std"]
    print(f"Loaded normalisation stats: mean={t2m_mu:.2f} K, std={t2m_sigma:.2f} K")
else:
    # Fallback: use stats from the array itself
    t2m_mu    = float(arr.mean())
    t2m_sigma = float(arr.std())
    print(f"No stats file found — using in-sample stats: mean={t2m_mu:.4f}, std={t2m_sigma:.4f}")

def denormalise(normalised_arr, mu, sigma):
    """Undo z-score normalisation."""
    return normalised_arr * sigma + mu
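A quick round-trip check makes the inverse transform concrete: normalise a few sample temperatures, denormalise, and recover the originals.

```python
import numpy as np


def denormalise(normalised_arr, mu, sigma):
    """Undo z-score normalisation."""
    return normalised_arr * sigma + mu


raw = np.array([270.0, 288.5, 301.2], dtype=np.float32)  # temperatures in K
mu, sigma = float(raw.mean()), float(raw.std())

normed = (raw - mu) / sigma          # forward: z-score
restored = denormalise(normed, mu, sigma)  # inverse: back to K

print(np.allclose(restored, raw, atol=1e-4))  # True
```

The same pair of operations, with the stats saved in notebook 3, is what moves model outputs back into kelvin below.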

Visualise prediction vs ground truth

# Pick a sample from the validation set
sample_idx = 0
x_sample, y_true = val_ds[sample_idx]

model.eval()
with torch.no_grad():
    y_pred = model(x_sample.unsqueeze(0).to(device)).squeeze().cpu().numpy()

y_true_np = y_true.squeeze().numpy()

# Denormalise to physical units (K)
y_true_k = denormalise(y_true_np, t2m_mu, t2m_sigma)
y_pred_k = denormalise(y_pred,    t2m_mu, t2m_sigma)
error     = y_pred_k - y_true_k

vmin = min(y_true_k.min(), y_pred_k.min())
vmax = max(y_true_k.max(), y_pred_k.max())

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

im0 = axes[0].imshow(y_true_k, cmap="RdBu_r", vmin=vmin, vmax=vmax, origin="upper")
axes[0].set_title("Ground truth (K)")
plt.colorbar(im0, ax=axes[0])

im1 = axes[1].imshow(y_pred_k, cmap="RdBu_r", vmin=vmin, vmax=vmax, origin="upper")
axes[1].set_title("Model prediction (K)")
plt.colorbar(im1, ax=axes[1])

elim = float(np.abs(error).max())  # symmetric limits so the diverging colormap is centred on zero
im2 = axes[2].imshow(error, cmap="bwr", vmin=-elim, vmax=elim, origin="upper")
axes[2].set_title("Error: prediction - truth (K)")
plt.colorbar(im2, ax=axes[2])

plt.suptitle("Next-timestep prediction after 5 epochs", fontsize=12)
plt.tight_layout()
plt.show()

print(f"MAE: {np.abs(error).mean():.4f} K")
print(f"RMSE: {np.sqrt((error**2).mean()):.4f} K")

Expected result: After only 5 epochs on a tiny synthetic/sample dataset the model has little skill. The prediction will largely follow the spatial mean — the error map will show the model struggling with fine structure. This is expected and intentional. The purpose is to demonstrate that the pipeline from Zarr to training loop is complete and working.

Where to go next

Scale up the data

  • Replace the sample data with real ERA5 via CDS (from_source("cds", ...))

  • Use multi-year training sets; leverage the chunking strategies from notebook 6
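For the CDS route, earthkit-data's from_source("cds", ...) takes the dataset name plus standard CDS request keys. A hedged sketch: the request values below are placeholders to adapt, and the call itself needs a configured CDS API key, so it is shown commented out.

```python
# Assumed CDS request parameters for ERA5 single-level fields — adjust
# the variables, years, and times to your training set.
request = {
    "product_type": "reanalysis",
    "variable": ["2m_temperature", "mean_sea_level_pressure"],
    "year": ["2019", "2020"],
    "month": [f"{m:02d}" for m in range(1, 13)],
    "day": [f"{d:02d}" for d in range(1, 32)],
    "time": ["00:00", "06:00", "12:00", "18:00"],
}

# Requires CDS credentials (~/.cdsapirc); uncomment to download:
# import earthkit.data as ekd
# fields = ekd.from_source("cds", "reanalysis-era5-single-levels", **request)

print(len(request["month"]))  # 12
```

From there the notebook 2 workflow (write to Zarr, choose chunks) applies unchanged, just at larger scale.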

Scale up the training

# Distributed Data Parallel (PyTorch native)
torch.distributed.init_process_group(backend="nccl")
model = nn.parallel.DistributedDataParallel(model)

# Lightning / Fabric (minimal boilerplate)
from lightning import Fabric
fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")
fabric.launch()
model, optimiser = fabric.setup(model, optimiser)

Use a real architecture

  • GraphCast — graph neural network, uses HEALPix grid from notebook 4

  • Pangu-Weather — transformer, input is a fixed-size lat-lon tensor

  • FourCastNet — vision transformer (ViT), regular lat-lon, multi-level inputs

  • AIFS — ECMWF’s own GNN model, trained on ERA5

Close the loop with earthkit.plots

At inference time, denormalise outputs and pass them back to earthkit.plots for publication-quality maps with proper metadata, coastlines, and colour scales.


Summary

You have closed the full loop:

  1. NB01 — Loaded ERA5 with from_source()

  2. NB02 — Wrote to Zarr

  3. NB03 — Converted units, normalised, saved statistics

  4. NB04 — Regridded to a common grid

  5. NB05 — Computed temporal features and anomalies

  6. NB06 — Chose chunk layout for your access pattern

  7. NB07 — Combined multiple sources into one pipeline

  8. NB08 — Fed the Zarr store into a PyTorch training loop and visualised predictions

earthkit provides the preprocessing layer. Everything between raw GRIB and loss.backward() is handled by the same unified pipeline regardless of where the data came from.


Activity

  1. Modify ERA5Dataset.__getitem__ to return a 3-timestep input window data[idx:idx+3] instead of a single timestep. Adjust the model input channels accordingly.

  2. Replace the synthetic data with data/era5_temporal.zarr from notebook 5. Does the loss converge differently?

  3. Save the trained model with torch.save(model.state_dict(), "data/model.pt") and write a short inference cell that loads it and runs prediction on a new sample.

# Save model
torch.save(model.state_dict(), "data/model.pt")

# Load and infer (map_location keeps this working on CPU-only machines)
model_loaded = SmallConvNet(in_channels=1)
model_loaded.load_state_dict(torch.load("data/model.pt", map_location="cpu"))
model_loaded.eval()
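One possible shape for activity 1, a sketch rather than the only answer: stack the input window along the channel dimension, then build the model with in_channels=3.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class WindowedERA5Dataset(Dataset):
    """Activity 1 sketch: a 3-timestep input window, next step as target."""

    def __init__(self, data: np.ndarray, window: int = 3):
        self.data = torch.from_numpy(data)  # (T, H, W), already normalised
        self.window = window

    def __len__(self):
        # The last window-sized slice still needs one following target step
        return len(self.data) - self.window

    def __getitem__(self, idx):
        x = self.data[idx : idx + self.window]         # (window, H, W) — timesteps as channels
        y = self.data[idx + self.window].unsqueeze(0)  # (1, H, W)
        return x, y


arr = np.random.rand(10, 4, 5).astype(np.float32)
ds = WindowedERA5Dataset(arr)
x, y = ds[0]
print(len(ds), x.shape, y.shape)  # 7 torch.Size([3, 4, 5]) torch.Size([1, 4, 5])
# To match: model = SmallConvNet(in_channels=3)
# (the residual skip in forward() would also need adjusting, since
#  input and output channel counts no longer agree)
```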