
06 — ML patches + augmentations

ML patches — augmentations + tiled inference

The companion to Patching: grid → process → stitch. Same patcher / Field / gz.patch_ops machinery, but now the per-chip work is the ML inference path: load → augment (training-time only) → model → stitch into a per-pixel classification map.

We demonstrate four stages of the pipeline on a real Sentinel-2 scene:

| Stage | Module | Ops |
|---|---|---|
| Training-chip extraction (irregular sampling) | geopatcher | SpatialJitteredStride |
| Per-chip augmentation | gz.augment | Compose([RandomFlip, RandomRotate90, BrightnessJitter, GaussianNoise]) |
| Inference + carrier wrap | gz.ModelOp | ModelOp(model, method="predict") |
| Vote-stitching of classification labels | geopatcher | SpatialHardVote(n_classes=3) |

The “model” is a rule-based three-class classifier driven by NDVI + NDWI thresholds — water (0), bare/scrub (1), vegetation (2). It plays the role of any real sklearn / PyTorch estimator: anything with a predict method (or __call__) plugs into ModelOp interchangeably.

import geopatcher as gp
import geotoolz as gz
import matplotlib.pyplot as plt
import numpy as np
import planetary_computer
import pystac_client
import rioxarray
from geopatcher.fields import RasterField
from georeader.geotensor import GeoTensor
from geotoolz.patch_ops import ApplyToChips, GridSampler, Stitch
from matplotlib.colors import ListedColormap
from rasterio.enums import Resampling

1. Load one Sentinel-2 scene

Lake Tahoe, the same scene as the companion patching notebook. We load four 10 m bands, B02/B03/B04/B08 (blue, green, red, NIR): green + NIR for NDWI and red + NIR for NDVI.

BBOX = (-120.10, 38.92, -119.93, 39.27)

catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)
items = sorted(
    catalog.search(
        collections=["sentinel-2-l2a"],
        bbox=BBOX,
        datetime="2024-06-01/2024-07-15",
        query={
            "eo:cloud_cover": {"lt": 5},
            "s2:mgrs_tile": {"eq": "10SGJ"},
        },
    ).items(),
    key=lambda x: x.properties["eo:cloud_cover"],
)
item = items[0]
print(f"using {item.id}")
using S2B_MSIL2A_20240614T183919_R070_T10SGJ_20240615T003207
def _load(key, *, ref=None, resampling=Resampling.bilinear):
    da = rioxarray.open_rasterio(item.assets[key].href, masked=False)
    da = da.squeeze("band", drop=True).rio.clip_box(*BBOX, crs="EPSG:4326")
    if ref is not None:
        da = da.rio.reproject_match(ref, resampling=resampling)
    return da


red = _load("B04")
green = _load("B03", ref=red)
blue = _load("B02", ref=red)
nir = _load("B08", ref=red)

# Reflectance up front so augmentations and the model see physical units.
dn_stack = np.stack([blue.values, green.values, red.values, nir.values], axis=0).astype(
    "uint16"
)
refl = gz.radiometry.DNToReflectance(scale=1e-4)(
    GeoTensor(
        values=dn_stack,
        transform=red.rio.transform(),
        crs=red.rio.crs,
        fill_value_default=0,
    )
)
print(f"reflectance carrier: shape={refl.shape}  dtype={refl.dtype}")
reflectance carrier: shape=(4, 3935, 1599)  dtype=float64
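The cell above applies only the 1e-4 scale factor. For Sentinel-2 L2A products generated with processing baseline 04.00 or later (introduced in January 2022, so this June 2024 scene should be affected), ESA also defines a BOA_ADD_OFFSET of -1000 DN that is added before scaling. Skipping it raises every band by 0.1 reflectance, which shifts NDVI and NDWI and therefore the class thresholds used later. Whether gz.radiometry.DNToReflectance accepts an offset is not shown in this notebook, so here is a plain NumPy sketch instead. The helper name and its `processing_baseline` string argument are hypothetical; the baseline itself should be read from the item's metadata.

```python
import numpy as np


def s2_l2a_dn_to_reflectance(dn, processing_baseline, nodata=0):
    """Convert Sentinel-2 L2A digital numbers to surface reflectance (sketch).

    Baselines >= 04.00 carry BOA_ADD_OFFSET = -1000, added to the DN before
    dividing by the 10000 quantification value. Pixels equal to `nodata`
    become NaN.
    """
    dn = np.asarray(dn)
    major, minor = (int(part) for part in processing_baseline.split("."))
    offset = -1000.0 if (major, minor) >= (4, 0) else 0.0
    refl = (dn.astype(np.float64) + offset) / 10_000.0
    return np.where(dn == nodata, np.nan, refl)


dn = np.array([0, 1000, 3000], dtype=np.uint16)
print(s2_l2a_dn_to_reflectance(dn, "05.10"))  # offset applied
print(s2_l2a_dn_to_reflectance(dn, "03.01"))  # pre-2022 baseline: scale only
```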

2. Training-chip sampling — SpatialJitteredStride

For a training dataset we don’t want a regular grid: every chip lands at the same fixed offsets, which invites the model to overfit to position. We want jittered anchors instead, so the model sees translated variants of similar terrain. SpatialJitteredStride(step, jitter) perturbs each regular-grid anchor by a uniform random offset bounded by jitter × step pixels, since jitter is expressed as a fraction of the step.
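The jittered-stride idea can be sketched in a few lines of NumPy. This is not geopatcher's implementation: the function name is hypothetical, and the edge handling (regular anchors only where a whole chip fits, jittered anchors clipped back inside the raster) is an assumption. With this notebook's scene, chip size and step it happens to yield 160 anchors (20 rows × 8 columns), the same count train_patcher reports.

```python
import numpy as np


def jittered_stride_anchors(shape, size, step, jitter, seed=None):
    """Top-left (row, col) chip anchors on a regular grid, each nudged by a
    uniform offset of up to +/- jitter * step pixels, then clipped so every
    chip stays inside the raster. `jitter` is a fraction of the step."""
    rng = np.random.default_rng(seed)
    H, W = shape
    (ch, cw), (sh, sw) = size, step
    rows = np.arange(0, H - ch + 1, sh)  # regular anchors where a chip fits
    cols = np.arange(0, W - cw + 1, sw)
    rr, cc = np.meshgrid(rows, cols, indexing="ij")
    anchors = np.stack([rr.ravel(), cc.ravel()], axis=1).astype(np.float64)
    max_offset = jitter * np.array([sh, sw], dtype=np.float64)
    anchors += rng.uniform(-max_offset, max_offset, size=anchors.shape)
    anchors[:, 0] = np.clip(anchors[:, 0], 0, H - ch)
    anchors[:, 1] = np.clip(anchors[:, 1], 0, W - cw)
    return np.rint(anchors).astype(np.int64)


anchors = jittered_stride_anchors((3935, 1599), (128, 128), (192, 192), 0.25, seed=0)
print(anchors.shape)
```

Clipping makes anchors near the border pile up against it; drawing each offset only from its in-bounds range would be an alternative design.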

Compare to SpatialPoissonDisk: Poisson-disk sampling draws random anchors subject to a minimum-distance constraint, so no two chips start closer than a chosen radius. The result is well-spread, spatially diverse training chips with little redundancy. Use it when training samples are expensive (annotated labels) and you want maximum coverage per chip.
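A minimal way to realize the minimum-distance constraint is dart throwing: propose random anchors and accept one only if it is at least `min_dist` away from every anchor accepted so far. The sketch below uses that simple variant (not Bridson's faster algorithm, and not necessarily what SpatialPoissonDisk does internally); the function name and the candidate budget are hypothetical.

```python
import numpy as np


def poisson_disk_anchors(shape, size, min_dist, n_candidates=5000, seed=None):
    """Dart-throwing Poisson-disk sampling of top-left chip anchors (sketch).

    Random candidates are accepted in order only if they lie at least
    `min_dist` pixels from every previously accepted anchor.
    """
    rng = np.random.default_rng(seed)
    H, W = shape
    ch, cw = size
    candidates = np.column_stack(
        [rng.uniform(0, H - ch, n_candidates), rng.uniform(0, W - cw, n_candidates)]
    )
    kept = np.empty((0, 2))
    for p in candidates:
        if len(kept) == 0 or np.linalg.norm(kept - p, axis=1).min() >= min_dist:
            kept = np.vstack([kept, p])
    # Float anchors; np.floor them for pixel offsets (spacing may shrink by < 1.5 px).
    return kept


anchors = poisson_disk_anchors((3935, 1599), (128, 128), min_dist=192.0, seed=0)
print(len(anchors))
```

Dart throwing costs O(candidates × accepted) and may stop short of a maximal packing; Bridson's algorithm enforces the same constraint in linear time.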

field = RasterField(reader=refl)
train_patcher = gp.SpatialPatcher(
    geometry=gp.SpatialRectangular(size=(128, 128)),
    # jitter is in step-units (0.0 = none, 0.5 = ± half a step), so
    # 0.25 of a 192-px step gives ±48 px of jitter per anchor.
    sampler=gp.SpatialJitteredStride(step=(192, 192), jitter=0.25, seed=0),
    window=gp.SpatialBoxcar(),
    aggregation=gp.SpatialOverlapAdd(),
)
train_chips = list(train_patcher.split(field))
print(f"sampled {len(train_chips)} training chips of size 128×128")
sampled 160 training chips of size 128×128
def _rgb(arr):
    """(C, H, W) reflectance → (H, W, 3) display range."""
    rgb = np.transpose(arr[[2, 1, 0]], (1, 2, 0))  # R, G, B
    lo, hi = np.nanpercentile(rgb, [2, 98])
    return np.clip((rgb - lo) / (hi - lo + 1e-9), 0, 1)


fig, axes = plt.subplots(2, 4, figsize=(14, 7))
for ax, chip in zip(axes.flat, train_chips[:8], strict=False):
    ax.imshow(_rgb(np.asarray(chip.data)))
    ax.set_title(
        f"chip @ row={chip.indices.row_off}\n col={chip.indices.col_off}",
        fontsize=9,
    )
    ax.axis("off")
fig.suptitle("8 jittered training chips (128 × 128, ±48 px jitter)", y=1.0)
plt.tight_layout()
plt.show()
[Figure: 8 jittered training chips (128 × 128, ±48 px jitter)]

3. Augment one chip with gz.augment.Compose

A typical RS-safe augmentation pipeline: random flips, 90° rotations, brightness jitter (multiplicative gain), and Gaussian sensor noise. Each gz.augment.* op:

  • preserves CRS and updates the affine transform (so flips and rotations still map to the right world coordinates),
  • is reproducible on demand: pass seed=... for a deterministic draw, or omit it to draw fresh entropy from the OS,
  • composes via gz.augment.Compose(ops, p=...) so the whole chain has a single overall-probability gate.
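The first property can be made concrete for a horizontal flip. This NumPy sketch stores the geotransform as a 3×3 homogeneous matrix with the same (a, b, c, d, e, f) coefficient order as rasterio's Affine; all helper names are hypothetical and this is not the gz.augment code. Composing the original transform with a "column → width − column" map keeps every flipped pixel at its original world coordinate. With the affine package, the same composition would read `transform * Affine.translation(width, 0) * Affine.scale(-1, 1)`.

```python
import numpy as np


def affine_matrix(a, b, c, d, e, f):
    """3x3 homogeneous matrix for x = a*col + b*row + c, y = d*col + e*row + f
    (the coefficient order used by rasterio's Affine)."""
    return np.array([[a, b, c], [d, e, f], [0.0, 0.0, 1.0]], dtype=np.float64)


def pixel_to_world(transform, col, row):
    """Map continuous pixel coordinates (col, row) to world (x, y)."""
    x, y, _ = transform @ np.array([col, row, 1.0])
    return float(x), float(y)


def hflip_with_transform(arr, transform):
    """Flip a (C, H, W) array left-right and return a transform under which
    every pixel keeps the world coordinates it had before the flip."""
    width = arr.shape[-1]
    flipped = arr[..., ::-1]
    # New column c reads old column (width - c) in continuous pixel coordinates.
    col_flip = affine_matrix(-1.0, 0.0, float(width), 0.0, 1.0, 0.0)
    return flipped, transform @ col_flip


transform = affine_matrix(10.0, 0.0, 500_000.0, 0.0, -10.0, 4_300_000.0)
chip = np.arange(2 * 3 * 4).reshape(2, 3, 4)
flipped, new_transform = hflip_with_transform(chip, transform)
# Centre of new pixel (row 0, col 0) sits where old pixel (row 0, col 3) did.
print(pixel_to_world(new_transform, 0.5, 0.5), pixel_to_world(transform, 3.5, 0.5))
```

The x pixel size of the new transform is negative (a mirrored grid), which some tools that assume a positive x resolution will not accept.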
augment = gz.augment.Compose(
    [
        gz.augment.RandomFlip(p_horizontal=0.5, p_vertical=0.5),
        gz.augment.RandomRotate90(p=0.75),
        gz.augment.BrightnessJitter(factor=(0.85, 1.15), per_band=True),
        gz.augment.GaussianNoise(sigma=0.01),
    ],
    p=1.0,
)
print(augment)

base = train_chips[0].data
fig, axes = plt.subplots(1, 5, figsize=(17, 4))
axes[0].imshow(_rgb(np.asarray(base)))
axes[0].set_title("original")
axes[0].axis("off")
for ax, seed in zip(axes[1:], range(4), strict=True):
    aug = augment(base, seed=seed)
    ax.imshow(_rgb(np.asarray(aug)))
    ax.set_title(f"augmented (seed={seed})")
    ax.axis("off")
fig.suptitle("Same chip, 4 augmented draws", y=1.02)
plt.tight_layout()
plt.show()
Compose(augmentations=[{'class': 'RandomFlip', 'config': {'p_horizontal': 0.5, 'p_vertical': 0.5, 'seed': None}}, {'class': 'RandomRotate90', 'config': {'p': 0.75, 'seed': None}}, {'class': 'BrightnessJitter', 'config': {'factor': (0.85, 1.15), 'per_band': True, 'seed': None}}, {'class': 'GaussianNoise', 'config': {'sigma': 0.01, 'per_band': True, 'seed': None}}], p=1.0, seed=None)
[Figure: Same chip, 4 augmented draws (original, then augmented with seed=0 to seed=3)]

Each draw differs (flips, rotations, and brightness shifts are all sampled independently for each seed), yet every output keeps the CRS and spatial-extent semantics of the input chip. That is the reproducible RS-augmentation contract that gz.augment enforces.
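To make the seed contract concrete, here is a hypothetical NumPy stand-in for the chain above. augment_chip is not part of gz.augment, and its internals are assumptions: flips are drawn independently per axis, RandomRotate90 is taken to pick 1 to 3 quarter turns, and the gain and noise ranges are copied from the Compose config. Drawing every random number from one seeded generator is what makes a given seed reproduce the same chip. Unlike gz.augment, the sketch only transforms pixel values and does no CRS/transform bookkeeping.

```python
import numpy as np


def augment_chip(x, seed=None, p=1.0):
    """Hypothetical stand-in for the Compose chain above, for a square
    (C, H, W) chip. All randomness comes from one seeded generator."""
    rng = np.random.default_rng(seed)
    if rng.random() >= p:  # one probability gate for the whole chain
        return x.copy()
    y = x
    if rng.random() < 0.5:
        y = y[:, :, ::-1]  # horizontal flip
    if rng.random() < 0.5:
        y = y[:, ::-1, :]  # vertical flip
    if rng.random() < 0.75:
        y = np.rot90(y, k=int(rng.integers(1, 4)), axes=(1, 2))  # 90/180/270 deg
    y = y * rng.uniform(0.85, 1.15, size=(y.shape[0], 1, 1))  # per-band gain
    return y + rng.normal(0.0, 0.01, size=y.shape)  # Gaussian sensor noise


chip = np.random.default_rng(0).random((4, 16, 16))
same = np.array_equal(augment_chip(chip, seed=3), augment_chip(chip, seed=3))
print(same)
```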

4. A three-class rule-based “model”

To keep this notebook self-contained we use a small NDVI / NDWI threshold rule as our classifier. Any object with a predict method (sklearn) or a __call__ (PyTorch / JAX) drops in via gz.ModelOp the same way: ModelOp is the substrate adapter.

Classes:

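| id | name | rule |
|---|---|---|
| 0 | Water | NDWI > 0.1 and NDVI < 0.25 |
| 1 | Bare / scrub | NDVI < 0.25 and not water |
| 2 | Vegetation | NDVI ≥ 0.25 (takes precedence over the water rule) |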
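For reference, the two indices and the decision rule as NDVIClassifier implements them, with ρ the reflectance in B03 (green), B04 (red), and B08 (NIR). NDWI here is the green/NIR (McFeeters) form, and the code adds 1e-10 to each denominator to avoid division by zero.

$$
\mathrm{NDVI} = \frac{\rho_{\mathrm{NIR}} - \rho_{\mathrm{red}}}{\rho_{\mathrm{NIR}} + \rho_{\mathrm{red}}},
\qquad
\mathrm{NDWI} = \frac{\rho_{\mathrm{green}} - \rho_{\mathrm{NIR}}}{\rho_{\mathrm{green}} + \rho_{\mathrm{NIR}}}
$$

$$
\text{label} =
\begin{cases}
2 & \text{if } \mathrm{NDVI} \ge 0.25,\\
0 & \text{if } \mathrm{NDVI} < 0.25 \text{ and } \mathrm{NDWI} > 0.1,\\
1 & \text{otherwise.}
\end{cases}
$$

As an illustrative (made-up) water-like pixel, ρ_green = 0.05, ρ_red = 0.03, ρ_NIR = 0.02 gives NDVI = −0.01 / 0.05 = −0.2 and NDWI = 0.03 / 0.07 ≈ 0.43, so it is labeled 0.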
class NDVIClassifier:
    """Rule-based 3-class classifier with the sklearn-style `predict` API."""

    def predict(self, x: np.ndarray) -> np.ndarray:
        # x is (4, H, W) reflectance — B(0) G(1) R(2) NIR(3).
        green, red, nir_ = x[1], x[2], x[3]
        ndvi = (nir_ - red) / (nir_ + red + 1e-10)
        ndwi = (green - nir_) / (green + nir_ + 1e-10)
        labels = np.full(ndvi.shape, 1, dtype=np.int64)  # default: bare/scrub
        labels[ndwi > 0.1] = 0  # water
        labels[ndvi >= 0.25] = 2  # vegetation
        return labels


model_op = gz.ModelOp(NDVIClassifier(), method="predict")
print(model_op)

# Smoke-test on a single chip.
sample_out = model_op(base)
print(f"single-chip output: shape={sample_out.shape}  dtype={sample_out.dtype}")
print(f"class counts:      {np.bincount(np.asarray(sample_out).ravel(), minlength=3)}")
ModelOp(model_type='NDVIClassifier', method='predict', batch_size=None)
single-chip output: shape=(128, 128)  dtype=int64
class counts:      [    0   926 15458]
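NDVIClassifier consumes a whole (C, H, W) chip, but fitted sklearn classifiers such as RandomForestClassifier expect one row per sample in predict, i.e. an (n_pixels, n_bands) matrix. The sketch below shows the reshape a per-pixel adapter has to perform. PixelwiseAdapter and NirThreshold are hypothetical names; this only illustrates the idea and is not how gz.ModelOp or gz.learn.SklearnOp are implemented.

```python
import numpy as np


class PixelwiseAdapter:
    """Hypothetical adapter: run an estimator that expects (n_samples, n_features)
    on a (C, H, W) chip and return an (H, W) label map."""

    def __init__(self, estimator, method="predict"):
        self.estimator = estimator
        self.method = method

    def __call__(self, chip):
        chip = np.asarray(chip)
        c, h, w = chip.shape
        features = chip.reshape(c, h * w).T  # (H*W, C): one row per pixel
        out = getattr(self.estimator, self.method)(features)
        return np.asarray(out).reshape(h, w)


class NirThreshold:
    """Toy per-pixel estimator with an sklearn-style predict: class 1 where
    the last band (NIR in B, G, R, NIR order) exceeds 0.3."""

    def predict(self, X):
        return (X[:, -1] > 0.3).astype(np.int64)


chip = np.zeros((4, 2, 3))
chip[3, 0, 1] = 0.5
chip[3, 1, 2] = 0.4
print(PixelwiseAdapter(NirThreshold())(chip))
```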

5. Tiled inference — Sequential([GridSampler, ApplyToChips, Stitch])

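Now the full pipeline. The per-chip op is ModelOp(...) instead of the NDVI op from the previous notebook, and the aggregation is SpatialHardVote(n_classes=3) instead of SpatialOverlapAdd because we are stitching integer class labels, not continuous values: overlap-adding (averaging) label 0 (water) and label 2 (vegetation) would yield 1 (bare/scrub), a class neither chip predicted.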
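A minimal NumPy sketch of majority-vote stitching, assuming chips arrive with their top-left (row, col) offsets. hard_vote_stitch is a hypothetical helper, not the SpatialHardVote implementation; its tie-breaking (lowest class id, as np.argmax does) and its -1 for uncovered pixels are choices made here.

```python
import numpy as np


def hard_vote_stitch(label_chips, offsets, shape, n_classes):
    """Stitch integer label chips into an (H, W) map by per-pixel majority vote.

    Each chip adds one vote for its label at every pixel it covers; the class
    with the most votes wins. Pixels no chip covers are set to -1.
    """
    votes = np.zeros((n_classes, *shape), dtype=np.int64)
    for labels, (r0, c0) in zip(label_chips, offsets):
        h, w = labels.shape
        rows, cols = np.indices((h, w))
        np.add.at(votes, (labels, rows + r0, cols + c0), 1)
    out = votes.argmax(axis=0)
    out[votes.sum(axis=0) == 0] = -1
    return out


# One chip says "vegetation" (2); two overlapping chips say "bare" (1).
chips = [np.full((2, 2), 2), np.full((2, 2), 1), np.full((2, 2), 1)]
offsets = [(0, 0), (0, 1), (0, 1)]
print(hard_vote_stitch(chips, offsets, shape=(2, 4), n_classes=3))
```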

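For inference we switch to a regular, non-jittered sampler (SpatialRegularStride), since jitter is a training-time trick, not an inference-time one. Its 256-px step equals the chip size, so chips tile the scene without gaps or overlap and each covered pixel receives a single vote; a step smaller than the chip would turn the vote into a genuine majority over overlapping chips.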

infer_patcher = gp.SpatialPatcher(
    geometry=gp.SpatialRectangular(size=(256, 256)),
    sampler=gp.SpatialRegularStride(step=(256, 256)),
    window=gp.SpatialBoxcar(),
    aggregation=gp.SpatialHardVote(n_classes=3),
)
inference = gz.Sequential(
    [
        GridSampler(patcher=infer_patcher),
        ApplyToChips(operator=model_op),
        Stitch(aggregation=infer_patcher.aggregation, domain=field.domain),
    ]
)
print(inference)

class_map = np.asarray(inference(field))
# Domain has channel axis (4, H, W); SpatialHardVote produces (4, H, W)
# with each channel a duplicate of the (H, W) vote map. Collapse:
if class_map.ndim == 3:
    class_map = class_map[0]
print(f"classification map: shape={class_map.shape}  dtype={class_map.dtype}")
print(f"class counts:      {np.bincount(class_map.ravel(), minlength=3)}")
Sequential([GridSampler(patcher={'geometry': {'class': 'SpatialRectangular', 'config': {'size': [256, 256]}}, 'sampler': {'class': 'SpatialRegularStride', 'config': {'step': [256, 256]}}, 'window': {'class': 'SpatialBoxcar', 'config': {}}, 'aggregation': {'class': 'SpatialHardVote', 'config': {'n_classes': 3}}}), ApplyToChips(operator={'class': 'ModelOp', 'config': {'model_type': 'NDVIClassifier', 'method': 'predict', 'batch_size': None}}), Stitch(aggregation={'class': 'SpatialHardVote', 'config': {'n_classes': 3}})])
classification map: shape=(3935, 1599)  dtype=int64
class counts:      [ 546953 4559300 1185812]

6. Compare against full-scene inference

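Run the same model on the whole scene without patching as a reference. Because the rule-based classifier is purely per-pixel (no spatial context), tiled and full-scene predictions should agree exactly on every pixel covered by at least one chip; any disagreement points at pixels the 256-px grid does not reach, since neither side of the 3935 × 1599 scene is a multiple of 256.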

full_class_map = np.asarray(model_op(refl))
agree = float(np.mean(class_map == full_class_map))
print(f"per-pixel agreement: {agree * 100:.2f}%")

fig, axes = plt.subplots(1, 3, figsize=(18, 8))

labels = ["Water", "Bare/scrub", "Vegetation"]
class_cmap = ListedColormap(["#2c7bb6", "#fdae61", "#1a9850"])

axes[0].imshow(_rgb(np.asarray(refl)))
axes[0].set_title("True-color reference")
axes[0].axis("off")

im1 = axes[1].imshow(full_class_map, cmap=class_cmap, vmin=-0.5, vmax=2.5)
axes[1].set_title("Full-scene inference")
axes[1].axis("off")
cbar = fig.colorbar(im1, ax=axes[1], shrink=0.7, ticks=range(3))
cbar.ax.set_yticklabels(labels)

im2 = axes[2].imshow(class_map, cmap=class_cmap, vmin=-0.5, vmax=2.5)
axes[2].set_title("Tile-and-vote inference")
axes[2].axis("off")
cbar = fig.colorbar(im2, ax=axes[2], shrink=0.7, ticks=range(3))
cbar.ax.set_yticklabels(labels)
plt.show()
per-pixel agreement: 93.76%
[Figure: True-color reference | Full-scene inference | Tile-and-vote inference (classes: Water, Bare/scrub, Vegetation)]
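The disagreement can be checked against chip coverage. This sketch assumes SpatialRegularStride only places chips that lie entirely inside the scene, which the source does not confirm. Under that assumption, 256-px chips at a 256-px stride fit 15 × 6 times (3840 × 1536 px), leaving about 6.3% of the 3935 × 1599 pixels without a vote, close to the 6.24% disagreement printed above. grid_coverage_mask and split_agreement are hypothetical helpers.

```python
import numpy as np


def grid_coverage_mask(shape, size, step):
    """(H, W) boolean mask of pixels covered by at least one chip, assuming a
    regular grid that only places chips lying entirely inside the scene."""
    H, W = shape
    mask = np.zeros((H, W), dtype=bool)
    for r0 in range(0, H - size + 1, step):
        for c0 in range(0, W - size + 1, step):
            mask[r0:r0 + size, c0:c0 + size] = True
    return mask


def split_agreement(a, b, mask):
    """Fraction of equal pixels inside and outside `mask` (NaN if a side is empty)."""
    same = np.asarray(a) == np.asarray(b)
    inside = float(same[mask].mean()) if mask.any() else float("nan")
    outside = float(same[~mask].mean()) if (~mask).any() else float("nan")
    return inside, outside


mask = grid_coverage_mask((3935, 1599), size=256, step=256)
print(f"uncovered fraction: {1 - mask.mean():.4f}")
```

In the notebook, `split_agreement(class_map, full_class_map, mask)` would separate missing coverage from genuine model disagreement; for a purely per-pixel model the inside value should be 1.0 if the edge-handling assumption holds.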

7. Recap

The same three-op patching pattern from 05 (Patching: grid → process → stitch) hosts the ML inference path with no structural change; only the parts that need to vary change:

| What | Patching demo | ML demo |
|---|---|---|
| Sampler | SpatialRegularStride (every pixel covered) | SpatialJitteredStride (training) or SpatialRegularStride (inference) |
| Window | SpatialHann (smooth blending) | SpatialBoxcar (votes don’t need blending) |
| Per-chip op | NDVI | gz.ModelOp(model, method="predict") |
| Aggregation | SpatialOverlapAdd (continuous map) | SpatialHardVote(n_classes=K) (class labels) |

Other building blocks worth pairing in:

  • gz.augment: Compose([RandomFlip, RandomRotate90, BrightnessJitter, GaussianNoise, AtmosphericHaze, SimulatedClouds, CutMix, ...]) for training-chip augmentation. Each op preserves CRS / transform so the augmented carrier is still geographically valid.
  • SklearnOp: gz.learn.SklearnOp(estimator, mode="pixel") if you want a fitted sklearn estimator (RandomForestClassifier, GBM, kNN, …) instead of a hand-coded rule.
  • SpatialSoftVote: like HardVote, but stitches per-class probabilities (chip outputs of shape (K, H, W)), letting you threshold or compute class margins at the field level.
  • Profile + ShapeTrace (from the observability section of operators_lake_tahoe): wrap ModelOp to time per-chip inference cost across the run.
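Soft voting is only described above, so here is a NumPy sketch of the general idea under stated assumptions: chips carry (K, h, w) class probabilities plus top-left offsets, overlapping chips are averaged with equal weight, and soft_vote_stitch is a hypothetical helper rather than the SpatialSoftVote implementation. The top-1 minus top-2 margin it returns is one common way to flag uncertain pixels.

```python
import numpy as np


def soft_vote_stitch(prob_chips, offsets, shape):
    """Average per-class probability chips (K, h, w) into a (K, H, W) field.

    Returns the averaged probabilities (zeros where no chip landed), the
    argmax label map, the top-1 minus top-2 margin, and the coverage mask.
    Assumes K >= 2.
    """
    n_classes = prob_chips[0].shape[0]
    acc = np.zeros((n_classes, *shape))
    count = np.zeros(shape)
    for p, (r0, c0) in zip(prob_chips, offsets):
        _, h, w = p.shape
        acc[:, r0:r0 + h, c0:c0 + w] += p
        count[r0:r0 + h, c0:c0 + w] += 1
    covered = count > 0
    probs = np.divide(acc, count, out=np.zeros_like(acc), where=covered)
    ranked = np.sort(probs, axis=0)
    margin = ranked[-1] - ranked[-2]  # small margin = uncertain pixel
    return probs, probs.argmax(axis=0), margin, covered


# Two overlapping 2-class chips that disagree on their shared column.
a = np.stack([np.full((1, 2), 0.8), np.full((1, 2), 0.2)])
b = np.stack([np.full((1, 2), 0.2), np.full((1, 2), 0.8)])
probs, labels, margin, covered = soft_vote_stitch([a, b], [(0, 0), (0, 1)], (1, 4))
print(labels, margin.round(2))
```

Masking pixels whose margin falls below some threshold gives a simple field-level uncertainty layer; the threshold itself is a tuning choice.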

For the cross-package walk-through (STAC → catalog → patcher → operators), see geocatalog/docs/notebooks/end_to_end_lake_tahoe.ipynb.