
6. Chunking Strategies for ML

Learning objectives

By the end of this notebook you will be able to:

  • Explain how Zarr chunk layout affects DataLoader read performance

  • Identify the access pattern your training loop uses

  • Write the same data with three different chunking strategies

  • Benchmark read times for common ML access patterns

  • Rechunk an existing Zarr store


Chunking is the biggest performance lever before you touch model code

A DataLoader reads batches from your dataset. Each batch triggers one or more Zarr chunk reads. If your chunks do not align with your access pattern, a single batch may force the DataLoader to read far more data than it needs.

Example:

  • You train on single timesteps: one sample = the full spatial field at time t

  • Your Zarr store is chunked with time=-1 (one chunk for the entire time series per grid point)

  • Reading one timestep now forces a read of the entire time axis — 10-100× more I/O than needed
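The amplification is easy to quantify with back-of-envelope arithmetic. A minimal sketch (the array dimensions here are illustrative, not the sample dataset's real sizes):

```python
# Illustrative dimensions (not the real sample dataset)
T, H, W = 365, 180, 360        # timesteps, latitude, longitude
bytes_per_value = 4            # float32

# Store chunked with time=-1: each chunk holds the full time series for one
# grid point, so reading a single timestep touches every chunk and
# decompresses the entire time axis.
needed = H * W * bytes_per_value             # one spatial field
actually_read = T * H * W * bytes_per_value  # the whole store
amplification = actually_read / needed

print(f"needed:   {needed / 1e6:.1f} MB")
print(f"read:     {actually_read / 1e6:.1f} MB")
print(f"overhead: {amplification:.0f}x")     # equals T for this layout
```

The overhead factor is exactly the number of timesteps, which is why the gap grows with the length of the archive.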

Three canonical strategies

  Strategy        Chunk shape   Best for
  Time-chunked    (1, H, W)     Spatial ML — one sample = one full field
  Space-chunked   (T, 1, 1)     Time-series forecasting at point locations
  Balanced        (t, h, w)     Mixed access, general-purpose

There is no universally best chunking. Match the chunk layout to the access pattern of your training loop.

Setup

import earthkit.data as ekd

import xarray as xr
import numpy as np
import zarr
import time
import os

os.makedirs("data", exist_ok=True)
print("zarr version:", zarr.__version__)

Get dataset

To showcase the impact of chunking we will use a daily sample of ERA5 2-metre temperature over Europe for 2020.

ds = ekd.from_source("sample", "era5_europe_20200101_20201231_2t_1deg.grib")
xr_daily = ds.to_xarray(add_earthkit_attrs=False)
xr_daily
H = len(xr_daily.latitude)
W = len(xr_daily.longitude)
T = len(xr_daily.forecast_reference_time)

Write three chunked variants

chunking_strategies = {
    "time_chunked":    {"forecast_reference_time": 1,   "latitude": H,   "longitude": W},    # one full field per chunk
    "space_chunked":   {"forecast_reference_time": T,   "latitude": 1,   "longitude": 1},    # one full time-series per chunk
    "balanced":        {"forecast_reference_time": 10,  "latitude": 15,  "longitude": 23},   # ~equal chunks
}

for name, chunks in chunking_strategies.items():
    path = f"data/era5_{name}.zarr"
    xr_daily.chunk(chunks).to_zarr(path, mode="w")
    store = zarr.open(path, mode="r")
    arr = store["2t"]
    print(f"{name:20s}: chunks={arr.chunks}  n_chunks={arr.nchunks}")  # nchunks is already the total count
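Chunk count is only half the picture: per-chunk size matters too, since object stores typically perform best with chunks in the low tens of megabytes. A quick estimate of the uncompressed chunk sizes (dimensions here are illustrative stand-ins; substitute the real T, H, W from above):

```python
import numpy as np

def chunk_size_mb(chunk_shape, dtype=np.float32):
    """Uncompressed size of one chunk in megabytes."""
    return float(np.prod(chunk_shape)) * np.dtype(dtype).itemsize / 1e6

# Illustrative dimensions for this sketch
T, H, W = 366, 30, 46
for name, shape in {
    "time_chunked":  (1, H, W),
    "space_chunked": (T, 1, 1),
    "balanced":      (10, 15, 23),
}.items():
    print(f"{name:15s} {chunk_size_mb(shape):10.4f} MB per chunk")
```

For a small sample like this, every layout produces tiny chunks; on production archives the same arithmetic tells you whether a layout lands in the sweet spot or produces millions of kilobyte-sized objects.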

Benchmark read times

We measure three access patterns representative of different ML use cases:

  1. Single timestep — the most common batch element in spatial ML

  2. Single-location time series — used in point-forecast and station-prediction models

  3. Spatial window — a crop of the field, used in local U-Net training

N_REPEATS = 20  # number of reads to average

def bench_read(path, access_fn, n=N_REPEATS):
    """Return mean and std of read latency in milliseconds."""
    ds_z = xr.open_dataset(path, engine="zarr")
    times = []
    for _ in range(n):
        t0 = time.perf_counter()
        _ = access_fn(ds_z).compute()  # .compute() forces the actual read
        times.append((time.perf_counter() - t0) * 1000)
    ds_z.close()
    return np.mean(times), np.std(times)

# Access functions
def read_single_timestep(ds):
    idx = np.random.randint(0, T)
    return ds["2t"].isel(forecast_reference_time=idx)

def read_location_series(ds):
    li = np.random.randint(0, H)
    lo = np.random.randint(0, W)
    return ds["2t"].isel(latitude=li, longitude=lo)

def read_spatial_window(ds):
    li = np.random.randint(0, H - 10)
    lo = np.random.randint(0, W - 10)
    ti = np.random.randint(0, T)
    return ds["2t"].isel(forecast_reference_time=ti, latitude=slice(li, li+10), longitude=slice(lo, lo+10))

access_patterns = {
    "Single timestep": read_single_timestep,
    "Location time series": read_location_series,
    "Spatial window (10x10)": read_spatial_window,
}

results = {}
for strategy in chunking_strategies:
    path = f"data/era5_{strategy}.zarr"
    results[strategy] = {}
    for pattern_name, fn in access_patterns.items():
        mean_ms, std_ms = bench_read(path, fn, n=N_REPEATS * 2)
        results[strategy][pattern_name] = (mean_ms, std_ms)
        print(f"  {strategy:20s} | {pattern_name:25s} | {mean_ms:6.1f} ± {std_ms:.1f} ms")
# Visualise benchmark results
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

strategy_names = list(chunking_strategies.keys())
colours = ["#2196F3", "#FF9800", "#4CAF50"]

for ax, pattern in zip(axes, access_patterns):
    means = [results[s][pattern][0] for s in strategy_names]
    stds  = [results[s][pattern][1] for s in strategy_names]
    bars = ax.bar(strategy_names, means, yerr=stds, color=colours, capsize=4)
    ax.set_title(pattern)
    ax.set_ylabel("Read latency (ms)")
    ax.set_xticks(range(len(strategy_names)))  # set positions before labels to avoid a matplotlib warning
    ax.set_xticklabels([s.replace("_", "\n") for s in strategy_names], fontsize=9)
    ax.set_ylim(bottom=0)

plt.suptitle("Zarr read latency by chunking strategy and access pattern", fontsize=12)
plt.tight_layout()
plt.show()

Discussion

Interpret your results:

  • time_chunked should be fastest for single-timestep reads: the entire field is in one chunk

  • space_chunked should be fastest for location time series: the full time series is in one chunk

  • balanced trades off both: rarely optimal for either pattern, but acceptable for both
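These rankings follow directly from how many chunks each access pattern must touch. A small helper makes this concrete (dimensions are illustrative; substitute the real T, H, W):

```python
def chunks_touched(chunk_shape, read_start, read_shape):
    """Number of chunks a contiguous read intersects, axis by axis."""
    n = 1
    for c, s, r in zip(chunk_shape, read_start, read_shape):
        first = s // c                 # first chunk index on this axis
        last = (s + r - 1) // c        # last chunk index on this axis
        n *= last - first + 1
    return n

# Illustrative dimensions for this sketch
T, H, W = 366, 30, 46
strategies = {"time_chunked":  (1, H, W),
              "space_chunked": (T, 1, 1),
              "balanced":      (10, 15, 23)}
reads = {"single timestep": ((0, 0, 0), (1, H, W)),
         "location series": ((0, 0, 0), (T, 1, 1)),
         "10x10 window":    ((0, 5, 5), (1, 10, 10))}

for sname, cshape in strategies.items():
    for rname, (start, shape) in reads.items():
        print(f"{sname:15s} | {rname:15s} | {chunks_touched(cshape, start, shape):5d} chunks")
```

A single timestep touches one chunk in the time-chunked store but every grid-point chunk (H × W of them) in the space-chunked store, which is the asymmetry the latency bars should reproduce.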

On this small local sample the measured differences may be modest, because the whole store fits in the operating system's file cache. With multi-GB Zarr stores on object storage (S3, GCS), where every chunk read is a network request, the wrong chunking strategy can make training 10-100× slower.


Rechunking an existing store

# Rechunking: open an existing store lazily and write it back with new chunks.
# Keep the data lazy (no .compute()) so the write streams chunk by chunk, and
# drop the stale chunk encoding — otherwise to_zarr tries to reuse the source
# layout instead of the new one.

source_path = "data/era5_space_chunked.zarr"
dest_path   = "data/era5_rechunked.zarr"

ds_old = xr.open_dataset(source_path, engine="zarr")
new_chunks = {"forecast_reference_time": 1, "latitude": -1, "longitude": -1}  # -1 = one chunk spanning the dimension

ds_new = ds_old.chunk(new_chunks)
for var in ds_new.variables:
    ds_new[var].encoding.pop("chunks", None)  # forget the source chunk layout
ds_new.to_zarr(dest_path, mode="w")

print(f"Rechunked store written to {dest_path}")
print(f"New chunk layout: {new_chunks}")

Chunking recommendations summary

  Training pattern                 Recommended chunking    Rationale
  Global/regional field forecast   (1, H, W)               One batch read = one chunk read
  Station/point forecast           (T, 1, 1)               Full time series in one chunk
  Patch-based CNN                  (1, patch_h, patch_w)   Match patch size to chunk size
  Transformer over time            (window, H, W)          Match temporal window to chunk
  Multi-worker DataLoader          Prefer larger chunks    Reduces per-read metadata overhead
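The last row of the table can be pushed one step further: if a custom sampler hands each DataLoader worker timestep indices grouped by chunk, no chunk is ever decompressed twice per epoch. A framework-free sketch of such a grouping (the function name and parameters are our own, not a library API):

```python
def chunk_aligned_batches(n_samples, chunk_t, batch_size):
    """Group timestep indices so each batch stays inside one time chunk.

    Assumes the store holds `chunk_t` timesteps per chunk and that
    batch_size <= chunk_t, so a batch never straddles a chunk boundary.
    """
    batches = []
    for chunk_start in range(0, n_samples, chunk_t):
        chunk_end = min(chunk_start + chunk_t, n_samples)
        for b in range(chunk_start, chunk_end, batch_size):
            batches.append(list(range(b, min(b + batch_size, chunk_end))))
    return batches

# 366 timesteps, 30 per chunk, batches of 8:
# full batches 0-7, 8-15, 16-23, then a short batch 24-29 closing the chunk
batches = chunk_aligned_batches(366, chunk_t=30, batch_size=8)
print(batches[3])
```

The trailing short batches are the price of alignment; whether that beats re-reading chunks depends on chunk size and worker count, so it is worth benchmarking with the harness above.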

Summary

You have:

  • Written the same dataset with time-chunked, space-chunked, and balanced strategies

  • Benchmarked read latency for three representative ML access patterns

  • Rechunked an existing Zarr store


Activity

  1. Add a fourth chunking strategy: {"forecast_reference_time": 30, "latitude": H, "longitude": W} (monthly spatial chunks). Where does it rank in the benchmark?

  2. Modify N_REPEATS = 100 and rerun. Do the rankings change? Why might variance be high with low repeat counts?

  3. Design the ideal chunking for a model that reads 7-day windows of the full spatial field at training time. Write that variant and benchmark it.

# Your chunking strategy here
window_chunks = {"forecast_reference_time": 7, "latitude": H, "longitude": W}