Learning objectives¶
By the end of this notebook you will be able to:
- Explain how Zarr chunk layout affects DataLoader read performance
- Identify the access pattern your training loop uses
- Write the same data with three different chunking strategies
- Benchmark read times for common ML access patterns
- Rechunk an existing Zarr store
Chunking is the biggest performance lever before you touch model code¶
A DataLoader reads batches from your dataset. Each batch triggers one or more Zarr chunk reads. If your chunks do not align with your access pattern, a single batch may force the DataLoader to read far more data than it needs.
Example:
- You train on single timesteps: one sample = the full spatial field at time t
- Your Zarr store is chunked with time=-1 (one chunk for the entire time series per grid point)
- Reading one timestep now forces a read of the entire time axis: 10-100× more I/O than needed
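A back-of-envelope sketch of that amplification factor, assuming reads start on chunk boundaries and whole chunks are always fetched (the 366 below is the daily length of 2020):

```python
import math

def read_amplification(read_len, chunk_len):
    """Ratio of data fetched to data needed along one axis, assuming a
    chunk-aligned read that must fetch whole chunks."""
    n_chunks = math.ceil(read_len / chunk_len)
    return n_chunks * chunk_len / read_len

# One timestep (read_len=1) under the two layouts above:
print(read_amplification(1, 1))    # time chunk of 1: no waste
print(read_amplification(1, 366))  # time=-1: the whole series is fetched
```

The same arithmetic applies independently along each axis, which is why the wrong layout compounds quickly for multi-dimensional reads.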
Three canonical strategies¶
| Strategy | Chunk shape | Best for |
|---|---|---|
| Time-chunked | (1, H, W) | Spatial ML — one sample = one full field |
| Space-chunked | (T, 1, 1) | Time-series forecasting at point locations |
| Balanced | (t, h, w) | Mixed access, general-purpose |
There is no universally best chunking. Match the chunk layout to the access pattern of your training loop.
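Alongside the access pattern, keep chunk sizes in a sensible range: commonly cited guidance is on the order of 1-100 MB uncompressed per chunk, since very small chunks make per-request overhead dominate. A small helper to sanity-check a candidate shape, assuming float32 data; the H=60, W=90 dimensions are illustrative, not the real grid:

```python
import numpy as np

def chunk_mib(chunk_shape, dtype=np.float32):
    """Uncompressed size of a single chunk in MiB."""
    return int(np.prod(chunk_shape)) * np.dtype(dtype).itemsize / 2**20

# Illustrative shapes for a (T=366, H=60, W=90) float32 array:
print(f"{chunk_mib((1, 60, 90)):.4f} MiB")   # one full field per chunk
print(f"{chunk_mib((366, 1, 1)):.6f} MiB")   # one grid point's series per chunk
```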
Setup¶
import earthkit.data as ekd
import xarray as xr
import numpy as np
import zarr
import time
import os
os.makedirs("data", exist_ok=True)
print("zarr version:", zarr.__version__)
Get dataset¶
To showcase the impact of chunking we will use a daily sample of ERA5 2-metre temperature over Europe for 2020.
ds = ekd.from_source("sample", "era5_europe_20200101_20201231_2t_1deg.grib")
xr_daily = ds.to_xarray(add_earthkit_attrs=False)
xr_daily
H = len(xr_daily.latitude)
W = len(xr_daily.longitude)
T = len(xr_daily.forecast_reference_time)
Write three chunked variants¶
chunking_strategies = {
"time_chunked": {"forecast_reference_time": 1, "latitude": H, "longitude": W}, # one full field per chunk
"space_chunked": {"forecast_reference_time": T, "latitude": 1, "longitude": 1}, # one full time-series per chunk
"balanced": {"forecast_reference_time": 10, "latitude": 15, "longitude": 23}, # ~equal chunks
}
for name, chunks in chunking_strategies.items():
    path = f"data/era5_{name}.zarr"
    xr_daily.chunk(chunks).to_zarr(path, mode="w")
    store = zarr.open(path, mode="r")
    arr = store["2t"]
    print(f"{name:20s}: chunks={arr.chunks} n_chunks={arr.nchunks}")
Benchmark read times¶
We measure three access patterns representative of different ML use cases:
- Single timestep: the most common batch element in spatial ML
- Single-location time series: used in point-forecast and station-prediction models
- Spatial window: a crop of the field, used in local U-Net training
N_REPEATS = 20  # number of reads to average

def bench_read(path, access_fn, n=N_REPEATS):
    """Return mean and std of read latency in milliseconds."""
    ds_z = xr.open_dataset(path, engine="zarr")
    times = []
    for _ in range(n):
        t0 = time.perf_counter()
        _ = access_fn(ds_z).compute()  # .compute() forces the actual read
        times.append((time.perf_counter() - t0) * 1000)
    ds_z.close()
    return np.mean(times), np.std(times)
# Access functions
def read_single_timestep(ds):
    idx = np.random.randint(0, T)
    return ds["2t"].isel(forecast_reference_time=idx)

def read_location_series(ds):
    li = np.random.randint(0, H)
    lo = np.random.randint(0, W)
    return ds["2t"].isel(latitude=li, longitude=lo)

def read_spatial_window(ds):
    li = np.random.randint(0, H - 10)
    lo = np.random.randint(0, W - 10)
    ti = np.random.randint(0, T)
    return ds["2t"].isel(forecast_reference_time=ti, latitude=slice(li, li + 10), longitude=slice(lo, lo + 10))
access_patterns = {
"Single timestep": read_single_timestep,
"Location time series": read_location_series,
"Spatial window (10x10)": read_spatial_window,
}
results = {}
for strategy in chunking_strategies:
    path = f"data/era5_{strategy}.zarr"
    results[strategy] = {}
    for pattern_name, fn in access_patterns.items():
        mean_ms, std_ms = bench_read(path, fn, n=N_REPEATS * 2)
        results[strategy][pattern_name] = (mean_ms, std_ms)
        print(f"  {strategy:20s} | {pattern_name:25s} | {mean_ms:6.1f} ± {std_ms:.1f} ms")
# Visualise benchmark results
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
strategy_names = list(chunking_strategies.keys())
colours = ["#2196F3", "#FF9800", "#4CAF50"]
for ax, pattern in zip(axes, access_patterns):
    means = [results[s][pattern][0] for s in strategy_names]
    stds = [results[s][pattern][1] for s in strategy_names]
    bars = ax.bar(strategy_names, means, yerr=stds, color=colours, capsize=4)
    ax.set_title(pattern)
    ax.set_ylabel("Read latency (ms)")
    ax.set_xticks(range(len(strategy_names)))  # fix tick positions before relabelling
    ax.set_xticklabels([s.replace("_", "\n") for s in strategy_names], fontsize=9)
    ax.set_ylim(bottom=0)
plt.suptitle("Zarr read latency by chunking strategy and access pattern", fontsize=12)
plt.tight_layout()
plt.show()
Discussion¶
Interpret your results:
- time_chunked should be fastest for single-timestep reads: the entire field sits in one chunk
- space_chunked should be fastest for location time series: the full time series sits in one chunk
- balanced trades off both; it is rarely optimal for either pattern but acceptable for both
The differences measured on this small local sample may be modest compared with production datasets, where I/O latency dominates. With multi-GB Zarr stores on object storage (S3, GCS), the wrong chunking strategy can make training 10-100× slower.
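Why the gap widens on object storage: each chunk is a separate request, so per-request latency multiplies with the number of chunks touched. A toy cost model makes this concrete; the 50 ms latency and 100 MiB/s bandwidth are illustrative assumptions, not measurements:

```python
def est_read_ms(n_chunks, bytes_fetched, latency_ms=50.0, bandwidth_mib_s=100.0):
    """Toy model: sequential per-chunk request latency plus transfer time."""
    return n_chunks * latency_ms + bytes_fetched / (bandwidth_mib_s * 2**20) * 1000

# Reading one ~0.14 MiB field: one request if time-chunked,
# versus one request per grid point (say 5400 points) if space-chunked.
print(est_read_ms(1, 140_000))
print(est_read_ms(5400, 140_000))
```

Real stores parallelise requests and compress chunks, so the numbers are not predictive, but the scaling with chunk count is the point.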
Rechunking an existing store¶
# Rechunking: open an existing store and write it back with new chunks
# This is done via xarray: open lazily, rechunk, write to a new store
source_path = "data/era5_space_chunked.zarr"
dest_path = "data/era5_rechunked.zarr"
ds_old = xr.open_dataset(source_path, engine="zarr")
for var in ds_old.variables:
    ds_old[var].encoding.pop("chunks", None)  # drop the old chunk encoding so the new layout is used
new_chunks = {"forecast_reference_time": 1, "latitude": -1, "longitude": -1}  # -1 = whole dimension
ds_old.chunk(new_chunks).to_zarr(dest_path, mode="w")
print(f"Rechunked store written to {dest_path}")
print(f"New chunk layout: {new_chunks}")
Chunking recommendations summary¶
| Training pattern | Recommended chunking | Rationale |
|---|---|---|
| Global/regional field forecast | (1, H, W) | One batch read = one chunk read |
| Station/point forecast | (T, 1, 1) | Full time series in one chunk |
| Patch-based CNN | (1, patch_h, patch_w) | Match patch size to chunk size |
| Transformer over time | (window, H, W) | Match temporal window to chunk |
| Multi-worker DataLoader | Prefer larger chunks | Reduces per-read metadata overhead |
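To make the first row concrete, here is a minimal, torch-free sketch of a map-style dataset whose samples align one-to-one with (1, H, W) chunks. `TimestepDataset` is a hypothetical name for illustration; a real PyTorch dataset would subclass torch.utils.data.Dataset with the same three methods:

```python
import numpy as np

class TimestepDataset:
    """One sample = one full spatial field. With (1, H, W) chunks,
    each __getitem__ maps to exactly one chunk read."""
    def __init__(self, array):
        self.array = array  # any (T, H, W) array-like: zarr, dask, or numpy
    def __len__(self):
        return self.array.shape[0]
    def __getitem__(self, t):
        return np.asarray(self.array[t])

# numpy stand-in for a Zarr array, for illustration:
data = np.arange(24, dtype=np.float32).reshape(4, 2, 3)
dset = TimestepDataset(data)
print(len(dset), dset[0].shape)
```

The same wrapper works unchanged on the `time_chunked` store written above by passing `zarr.open("data/era5_time_chunked.zarr")["2t"]` instead of the numpy array.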
Summary¶
You have:
- Written the same dataset with time-chunked, space-chunked, and balanced strategies
- Benchmarked read latency for three representative ML access patterns
- Rechunked an existing Zarr store
Activity¶
1. Add a fourth chunking strategy: {"forecast_reference_time": 30, "latitude": H, "longitude": W} (monthly spatial chunks). Where does it rank in the benchmark?
2. Set N_REPEATS = 100 and rerun. Do the rankings change? Why might variance be high with low repeat counts?
3. Design the ideal chunking for a model that reads 7-day windows of the full spatial field at training time. Write that variant and benchmark it.
# Your chunking strategy here
window_chunks = {"forecast_reference_time": 7, "latitude": H, "longitude": W}