
Preprocessing API

The pyrox.preprocessing subpackage holds the only pandas-touching code in pyrox. Layers, models, and inference runners stay pandas-free; this module is the bridge between user-supplied DataFrames and the JAX-only pyrox.nn layers.

SpatiotemporalFit

pyrox.preprocessing.SpatiotemporalFit

Bases: Module

Immutable bundle of fitted feature layers + time-encoding scalars.

Replaces bayesnf's mutable SpatiotemporalDataHandler with a pure PyTree so the whole bundle is JIT-friendly and picklable.

Attributes:

Name Type Description
standardize_layer Standardization

Standardization layer applied to the feature columns at predict time.

fourier_layer FourierFeatures

FourierFeatures layer (may have all-zero degrees if the user opted out of Fourier features).

seasonal_layer SeasonalFeatures

SeasonalFeatures layer (zero-period if the user opted out of seasonal features).

interaction_layer InteractionFeatures

InteractionFeatures layer (zero-pair if the user opted out).

time_min float

Minimum time value across the training set, used as an offset by encode_time_column.

time_scale float

Multiplicative factor applied to (t - time_min) to produce the array passed to the seasonal layer (typically 1.0 for int time columns and a unit-conversion factor for datetime).

feature_cols tuple[str, ...]

Names of the columns the standardization/feature layers expect, in order.

target_col str

Name of the target column.

Source code in src/pyrox/preprocessing/_pandas.py
class SpatiotemporalFit(eqx.Module):
    """Immutable bundle of fitted feature layers + time-encoding scalars.

    Replaces bayesnf's mutable ``SpatiotemporalDataHandler`` with a pure
    PyTree so the whole bundle is JIT-friendly and picklable.

    Attributes:
        standardize_layer: :class:`Standardization` layer applied to the
            feature columns at predict time.
        fourier_layer: :class:`FourierFeatures` layer (may have all-zero
            degrees if the user opted out of Fourier features).
        seasonal_layer: :class:`SeasonalFeatures` layer (zero-period if
            the user opted out of seasonal features).
        interaction_layer: :class:`InteractionFeatures` layer
            (zero-pair if the user opted out).
        time_min: Minimum time value across the training set, used as
            an offset by :func:`encode_time_column`.
        time_scale: Multiplicative factor applied to ``(t - time_min)``
            to produce the array passed to the seasonal layer (typically
            ``1.0`` for ``int`` time columns and a unit-conversion
            factor for ``datetime``).
        feature_cols: Names of the columns the standardization/feature
            layers expect, in order.
        target_col: Name of the target column.
    """

    standardize_layer: Standardization
    fourier_layer: FourierFeatures
    seasonal_layer: SeasonalFeatures
    interaction_layer: InteractionFeatures
    time_min: float = eqx.field(static=True)
    time_scale: float = eqx.field(static=True)
    feature_cols: tuple[str, ...] = eqx.field(static=True)
    target_col: str = eqx.field(static=True)
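Because the bundle is a pure, immutable PyTree rather than a stateful handler, it pickles and round-trips losslessly. A minimal sketch of that property using a plain frozen dataclass (stand-in names; the real class is an equinox Module with layer fields as well):

```python
import pickle
from dataclasses import dataclass


@dataclass(frozen=True)
class TinyFit:
    # Stand-ins for the fitted scalars stored on SpatiotemporalFit.
    time_min: float
    time_scale: float
    feature_cols: tuple[str, ...]


fit = TinyFit(time_min=10.0, time_scale=1.0, feature_cols=("t", "lat"))

# Immutable + plain fields means pickle round-trips to an equal object.
restored = pickle.loads(pickle.dumps(fit))
assert restored == fit
```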

fit_spatiotemporal

pyrox.preprocessing.fit_spatiotemporal(df, *, feature_cols, target_col, timetype='int', freq=None, seasonality_periods=(), num_seasonal_harmonics=(), fourier_degrees=(), interactions=(), standardize=None, time_col=0)

Build a complete SpatiotemporalFit from a DataFrame.

The training-side workflow is:

fit = fit_spatiotemporal(df, feature_cols=..., target_col=...)

and the predict-side workflow reuses the same fit to encode new data: it calls encode_time_column with the stored time_min and applies the layers stored on the bundle.
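Why reusing the stored time_min matters can be shown with plain NumPy (a sketch of the encoding contract, not pyrox's actual code): test rows must be offset by the training minimum, not their own.

```python
import numpy as np

train_t = np.array([100.0, 101.0, 102.0])
test_t = np.array([103.0, 104.0])

# Fit side: record the training-set minimum.
time_min = float(train_t.min())

# Predict side: reuse the stored offset so train and test share one axis.
train_enc = train_t - time_min   # [0., 1., 2.]
test_enc = test_t - time_min     # [3., 4.], not [0., 1.]
```

Offsetting the test rows by their own minimum would silently restart the time axis at zero and misalign every seasonal feature.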

Parameters:

Name Type Description Default
df DataFrame

Training DataFrame.

required
feature_cols Sequence[str]

Names of the input columns, in order. The first column (time_col=0) is treated as the time axis for seasonal features.

required
target_col str

Name of the target column.

required
timetype Literal['int', 'datetime']

"int" or "datetime" — see :func:encode_time_column.

'int'
freq str | None

Optional unit string for datetime time columns.

None
seasonality_periods Sequence[float]

Periods (in time-unit) for seasonal features. Empty ⇒ no seasonal features.

()
num_seasonal_harmonics Sequence[int]

Harmonics per period; same length as seasonality_periods.

()
fourier_degrees Sequence[int]

Per-input dyadic Fourier degrees. Length must match feature_cols; use 0 to skip a column.

()
interactions Sequence[tuple[int, int]]

Pairs of input-column indices for interaction features. Empty ⇒ no interactions.

()
standardize Sequence[str] | None

Optional subset of feature columns to standardize. None ⇒ standardize them all.

None
time_col int

Index of the time column inside feature_cols; defaults to 0.

0

Returns:

Type Description
SpatiotemporalFit

Fitted SpatiotemporalFit bundle.

Source code in src/pyrox/preprocessing/_pandas.py
def fit_spatiotemporal(
    df: pd.DataFrame,
    *,
    feature_cols: Sequence[str],
    target_col: str,
    timetype: Literal["int", "datetime"] = "int",
    freq: str | None = None,
    seasonality_periods: Sequence[float] = (),
    num_seasonal_harmonics: Sequence[int] = (),
    fourier_degrees: Sequence[int] = (),
    interactions: Sequence[tuple[int, int]] = (),
    standardize: Sequence[str] | None = None,
    time_col: int = 0,
) -> SpatiotemporalFit:
    """Build a complete :class:`SpatiotemporalFit` from a DataFrame.

    The training-side workflow is::

        fit = fit_spatiotemporal(df, feature_cols=..., target_col=...)

    and the predict-side workflow re-uses the *same* ``fit`` to encode
    new data — concretely, by calling :func:`encode_time_column` with
    the stored ``time_min`` and applying the layers stored on the
    bundle.

    Args:
        df: Training DataFrame.
        feature_cols: Names of the input columns, in order. The first
            column (``time_col=0``) is treated as the time axis for
            seasonal features.
        target_col: Name of the target column.
        timetype: ``"int"`` or ``"datetime"`` — see
            :func:`encode_time_column`.
        freq: Optional unit string for ``datetime`` time columns.
        seasonality_periods: Periods (in time-unit) for seasonal
            features. Empty ⇒ no seasonal features.
        num_seasonal_harmonics: Harmonics per period; same length as
            ``seasonality_periods``.
        fourier_degrees: Per-input dyadic Fourier degrees. Length must
            match ``feature_cols``; use ``0`` to skip a column.
        interactions: Pairs of input-column indices for interaction
            features. Empty ⇒ no interactions.
        standardize: Optional subset of feature columns to standardize.
            ``None`` ⇒ standardize them all.
        time_col: Index of the time column inside ``feature_cols``;
            defaults to 0.

    Returns:
        Fitted :class:`SpatiotemporalFit` bundle.
    """
    feature_cols_list = list(feature_cols)
    time_col_name = feature_cols_list[time_col]
    if standardize is None:
        # Default: standardize every column except the time column. The
        # time axis is shifted by `time_min` (and scaled for datetime) in
        # `encode_time_column`, and downstream blocks — especially the
        # seasonal features — interpret `seasonality_periods` in the
        # *original* time units. Z-scoring time would rescale those
        # periods implicitly and miscalibrate the seasonal basis.
        standardize = [col for i, col in enumerate(feature_cols_list) if i != time_col]
    # Build a full-size Standardization layer aligned with `feature_cols`,
    # with identity (mu=0, std=1) on columns not in `standardize`. That
    # lets downstream code apply the layer to the entire design matrix
    # without re-indexing.
    standardize_set = set(standardize)
    missing = standardize_set - set(feature_cols_list)
    if missing:
        raise ValueError(
            "standardize must be a subset of feature_cols; "
            f"missing from feature_cols: {sorted(missing)}"
        )
    if time_col_name in standardize_set:
        # The time column gets shifted (and possibly scaled) by
        # `encode_time_column` before standardization is applied. Fitting
        # mu/std on the *raw* time values would then mis-shift the
        # encoded column by an extra constant — invariably catastrophic
        # for Unix-like timestamps. Force the user to opt out instead.
        raise ValueError(
            f"Cannot standardize the time column ({time_col_name!r}); "
            "the time axis is already centered by `time_min` in "
            "`encode_time_column`. Drop it from `standardize`."
        )
    if standardize_set:
        sub_layer = fit_standardization(df, list(standardize))
        sub_mu = {col: sub_layer.mu[i] for i, col in enumerate(standardize)}
        sub_std = {col: sub_layer.std[i] for i, col in enumerate(standardize)}
    else:
        sub_mu, sub_std = {}, {}
    mu_full = jnp.stack(
        [
            sub_mu[col] if col in standardize_set else jnp.float32(0.0)
            for col in feature_cols_list
        ]
    )
    std_full = jnp.stack(
        [
            sub_std[col] if col in standardize_set else jnp.float32(1.0)
            for col in feature_cols_list
        ]
    )
    standardize_layer = Standardization(mu=mu_full, std=std_full)

    # If fourier_degrees was not provided, default to all-zero (no Fourier).
    if not fourier_degrees:
        fourier_degrees = [0] * len(feature_cols)
    if len(fourier_degrees) != len(feature_cols):
        raise ValueError(
            f"fourier_degrees must have length {len(feature_cols)}, "
            f"got {len(fourier_degrees)}"
        )
    # Seasonal periods + harmonics must be the same length: the BNF's
    # seasonal block guards on `any(harmonics)`, so a misaligned
    # `harmonics=()` would silently drop *all* requested seasonality.
    # Fail fast here.
    if len(seasonality_periods) != len(num_seasonal_harmonics):
        raise ValueError(
            "seasonality_periods and num_seasonal_harmonics must have the same "
            f"length; got {len(seasonality_periods)} periods and "
            f"{len(num_seasonal_harmonics)} harmonics"
        )
    fourier_layer = FourierFeatures(
        degrees=tuple(int(d) for d in fourier_degrees),
        rescale=True,
    )
    seasonal_layer = SeasonalFeatures(
        periods=tuple(float(p) for p in seasonality_periods),
        harmonics=tuple(int(h) for h in num_seasonal_harmonics),
        rescale=True,
    )
    interaction_layer = InteractionFeatures(
        pairs=tuple((int(a), int(b)) for a, b in interactions),
    )

    # Encode the time column once to capture (time_min, time_scale).
    _, time_min, time_scale = encode_time_column(
        df[feature_cols[time_col]], timetype=timetype, freq=freq
    )

    return SpatiotemporalFit(  # ty: ignore[invalid-return-type]
        standardize_layer=standardize_layer,
        fourier_layer=fourier_layer,
        seasonal_layer=seasonal_layer,
        interaction_layer=interaction_layer,
        time_min=time_min,
        time_scale=time_scale,
        feature_cols=tuple(feature_cols),
        target_col=target_col,
    )
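The full-width standardization trick above (identity mu=0, std=1 on columns not selected for standardization) can be sketched with NumPy alone; the column names and statistics here are illustrative:

```python
import numpy as np

feature_cols = ["t", "lat", "lon"]

# Per-column stats fitted only on the standardized subset; the time
# column "t" is excluded, as in the default above.
sub_mu = {"lat": 40.0, "lon": -74.0}
sub_std = {"lat": 2.0, "lon": 3.0}

# Expand to full width, aligned with feature_cols, identity elsewhere.
mu_full = np.array([sub_mu.get(c, 0.0) for c in feature_cols])
std_full = np.array([sub_std.get(c, 1.0) for c in feature_cols])

# The layer can now be applied to the whole design matrix at once;
# the time column passes through unchanged ((x - 0) / 1 == x).
X = np.array([[0.0, 42.0, -71.0]])
Z = (X - mu_full) / std_full   # [[0.0, 1.0, 1.0]]
```

This is what lets downstream code apply one Standardization layer to the entire design matrix without re-indexing columns.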

fit_standardization

pyrox.preprocessing.fit_standardization(df, columns, *, eps=1e-12)

Build a Standardization layer from per-column mean / std.

Parameters:

Name Type Description Default
df DataFrame

Source DataFrame.

required
columns Sequence[str]

Columns to standardize, in the order they will appear in the array passed to Standardization.__call__.

required
eps float

Floor for the standard deviation; protects against division by zero on constant columns.

1e-12

Returns:

Type Description
Standardization

Fitted Standardization layer.

Source code in src/pyrox/preprocessing/_pandas.py
def fit_standardization(
    df: pd.DataFrame,
    columns: Sequence[str],
    *,
    eps: float = 1e-12,
) -> Standardization:
    """Build a :class:`Standardization` layer from per-column mean / std.

    Args:
        df: Source DataFrame.
        columns: Columns to standardize, in the order they will appear
            in the array passed to :meth:`Standardization.__call__`.
        eps: Floor for the standard deviation; protects against
            division by zero on constant columns.

    Returns:
        :class:`Standardization` layer.
    """
    sub = df[list(columns)]
    mu = jnp.asarray(sub.mean().to_numpy(), dtype=jnp.float32)
    std = jnp.asarray(sub.std(ddof=0).to_numpy(), dtype=jnp.float32)
    std = jnp.maximum(std, eps)
    return Standardization(mu=mu, std=std)  # ty: ignore[invalid-return-type]
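A minimal check of the eps floor on a constant column, mirroring the mean / std(ddof=0) logic above with plain pandas and NumPy:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [1.0, 2.0, 3.0], "c": [5.0, 5.0, 5.0]})
eps = 1e-12

mu = df[["x", "c"]].mean().to_numpy()
# Population std (ddof=0); the constant column "c" would yield 0.0,
# so floor it at eps to avoid a 0/0 division downstream.
std = np.maximum(df[["x", "c"]].std(ddof=0).to_numpy(), eps)

z = (df[["x", "c"]].to_numpy() - mu) / std
# Column "c" standardizes to exactly 0.0 instead of NaN.
```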

encode_time_column

pyrox.preprocessing.encode_time_column(series, *, timetype='int', freq=None, time_min=None)

Convert a pandas time column into a unit-scale JAX float array.

For timetype="int", the series is cast directly to float32 and offset by its minimum (or by time_min if supplied).

For timetype="datetime", the series is converted to integer nanoseconds, offset by its minimum, and divided by a unit factor derived from freq ("D" ⇒ days, "H" ⇒ hours, "W" ⇒ weeks). When freq is None, the unit is days.

Parameters:

Name Type Description Default
series Series

1D time column.

required
timetype Literal['int', 'datetime']

"int" for already-numeric series, "datetime" for pd.Timestamp-valued series.

'int'
freq str | None

Optional unit string for the datetime path.

None
time_min float | None

Optional fixed offset (use the value from a previous fit to align test data with training).

None

Returns:

Type Description
Type Description
tuple[Float[Array, ' N'], float, float]

(t, time_min, time_scale): the encoded array, the offset used, and the multiplicative scale (1 for the int path, 1 / nanoseconds-per-unit for datetime).

Source code in src/pyrox/preprocessing/_pandas.py
def encode_time_column(
    series: pd.Series,
    *,
    timetype: Literal["int", "datetime"] = "int",
    freq: str | None = None,
    time_min: float | None = None,
) -> tuple[Float[Array, " N"], float, float]:
    """Convert a pandas time column into a unit-scale JAX float array.

    For ``timetype="int"``, the series is cast directly to ``float32``
    and offset by its minimum (or by ``time_min`` if supplied).

    For ``timetype="datetime"``, the series is converted to integer
    nanoseconds, offset by its minimum, and divided by a unit factor
    derived from ``freq`` (``"D"`` ⇒ days, ``"H"`` ⇒ hours, ``"W"`` ⇒
    weeks, ``"min"`` ⇒ minutes). When ``freq`` is ``None``, the unit is
    days.

    Args:
        series: 1D time column.
        timetype: ``"int"`` for already-numeric series, ``"datetime"``
            for ``pd.Timestamp``-valued series.
        freq: Optional unit string for the datetime path.
        time_min: Optional fixed offset (use the value from a previous
            ``fit`` to align test data with training).

    Returns:
        ``(t, time_min, time_scale)`` — the encoded array, the offset
        used, and the multiplicative scale (``1`` for the ``int`` path,
        ``1 / nanoseconds-per-unit`` for ``datetime``).
    """
    if timetype == "int":
        # Do the offset subtraction in numpy float64 before casting to
        # JAX float32: common "integer" time columns are Unix
        # seconds/milliseconds, which exceed float32's ~7 decimal digits.
        # A direct float32 cast collapses neighboring timestamps to the
        # same value, silently destroying unit-level time deltas. (We use
        # numpy here rather than `jnp.float64` because JAX defaults to
        # float32 and would silently downcast unless `jax_enable_x64` is
        # set, defeating the purpose.)
        import numpy as np

        arr64 = np.asarray(series.to_numpy(), dtype=np.float64)
        if time_min is None:
            time_min = float(arr64.min())
        centered = arr64 - time_min
        return jnp.asarray(centered, dtype=jnp.float32), float(time_min), 1.0
    if timetype == "datetime":
        # Keep everything in int64 (numpy-side) to avoid float32 precision
        # loss — nanoseconds since epoch overflow float32 badly.
        #
        # Cast explicitly to `datetime64[ns]` because pandas 2.x may pick
        # microsecond resolution by default, which would make the int64
        # representation off by a factor of 1000 from what `unit_ns`
        # below assumes.
        import numpy as np

        ts = pd.to_datetime(series)
        # tz-aware columns can't be cast straight to `datetime64[ns]`
        # (pandas raises). Normalize to UTC and drop the tz so the int64
        # epoch representation is well-defined and consistent across
        # different input timezones.
        if getattr(ts.dt, "tz", None) is not None:
            ts = ts.dt.tz_convert("UTC").dt.tz_localize(None)
        ns_np = ts.astype("datetime64[ns]").astype("int64").to_numpy()
        unit_ns_table = {
            None: 24 * 3600 * 1_000_000_000,
            "D": 24 * 3600 * 1_000_000_000,
            "H": 3600 * 1_000_000_000,
            "W": 7 * 24 * 3600 * 1_000_000_000,
            "min": 60 * 1_000_000_000,
        }
        if freq not in unit_ns_table:
            raise ValueError(
                f"Unsupported freq {freq!r}; expected one of "
                f"{sorted(k for k in unit_ns_table if k is not None)} or None"
            )
        unit_ns = unit_ns_table[freq]
        scale = 1.0 / float(unit_ns)
        if time_min is None:
            time_min = float(ns_np.min()) * scale
        # Subtract the offset in ns (int64), *then* scale; keeps the
        # result well within float32 range.
        origin_ns = int(time_min / scale)
        centered = (ns_np - origin_ns).astype(np.float64)
        return jnp.asarray(centered * scale, dtype=jnp.float32), float(time_min), scale
    raise ValueError(f"Unsupported timetype {timetype!r}")