# Source code for wetlandmapper.dynamics

"""Wetland dynamics classification and temporal aggregation utilities.

Copyright (c) 2026, Manudeo Singh
Author: Manudeo Singh, March 2026
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    import xarray as xr

try:
    import xarray as xr
except ImportError as e:
    raise ImportError("xarray is required. Install: pip install wetlandmapper") from e

try:
    from rioxarray import exceptions as _rio_exc  # noqa: F401
    _HAS_RIO = True
except ImportError:
    _HAS_RIO = False

# ---------------------------------------------------------------------------
# Class definitions
# ---------------------------------------------------------------------------

# Integer class code -> human-readable dynamics label, in ascending code
# order. The codes are not contiguous (there is no 1 and no 7-9).
DYNAMICS_CLASSES: dict[int, str] = dict(
    [
        (0, "Non-wetland"),
        (2, "New"),
        (3, "Lost"),
        (4, "Diminishing"),
        (5, "Intensifying"),
        (6, "Intermittent"),
        (10, "Persistent"),
    ]
)

# Hex colour string for each code in DYNAMICS_CLASSES (same keys, same
# order); presumably used for maps/legends — verify against callers.
DYNAMICS_COLORS: dict[int, str] = dict(
    [
        (0, "#d5d8dc"),
        (2, "#8e44ad"),
        (3, "#c0392b"),
        (4, "#e67e22"),
        (5, "#2ecc71"),
        (6, "#76d7c4"),
        (10, "#1a5276"),
    ]
)

# Accepted values of the ``nan_policy`` argument (see classify_dynamics).
_VALID_NAN_POLICIES = ("total", "valid")


# ---------------------------------------------------------------------------
# Main classification function
# ---------------------------------------------------------------------------

def classify_dynamics(
    water_index: xr.DataArray,
    nYear: int = 3,
    thresholdWet: float = 25.0,
    thresholdPersis: float = 75.0,
    water_threshold: float = 0.0,
    nan_policy: str = "total",
    min_valid_obs: int | None = None,
    # backward-compatibility alias only — no mndwi= alias (function is index-agnostic)
    mndwi_threshold: float | None = None,
) -> xr.DataArray:
    """Classify pixels into six wetland dynamics classes plus Non-wetland.

    Each pixel is assigned to exactly one class based on three temporal
    summary statistics derived from a multi-year water index stack:

    * Overall wet frequency *W%* (% of years above ``water_threshold``)
    * Historic wet count *W_historic* (first ``nYear`` years)
    * Recent wet count *W_recent* (last ``nYear`` years)

    and ``delta = W_recent - W_historic``.

    **Class priority (strictly exclusive — a pixel receives exactly one class):**

    ============ ==== ===================================================
    Class        Code Primary condition
    ============ ==== ===================================================
    Persistent   10   W% ≥ thresholdPersis
    New          2    W_historic = 0 AND W_recent > 0 (newly appeared)
    Lost         3    W_historic > 0 AND W_recent = 0 (fully gone)
    Intensifying 5    W% ≥ thresholdWet AND delta > 0 (not New)
    Diminishing  4    W% ≥ thresholdWet AND delta < 0 (not Lost)
    Intermittent 6    W% ≥ thresholdWet AND no directional signal
    Non-wetland  0    none of the above
    ============ ==== ===================================================

    Each rule is applied only to still-unclassified pixels (code 0), so no
    pixel can receive more than one code even when several conditions are
    true at once (e.g. a pixel that is both Persistent *and* Intensifying
    is classified as Persistent only). Note that New and Lost do not
    require ``W% ≥ thresholdWet``.

    Parameters
    ----------
    water_index : xr.DataArray
        Multi-temporal water index time series with a ``time`` dimension.
        Accepts MNDWI, AWEIsh, AWEInsh, NDWI, or any index where values
        above ``water_threshold`` indicate surface water.
    nYear : int
        Length of the historic and recent windows in time steps (years).
        Must be a positive integer. Default 3.
    thresholdWet : float
        Minimum wet frequency (%) for the Intensifying, Diminishing and
        Intermittent classes. Default 25.
    thresholdPersis : float
        Wet frequency (%) at or above which a pixel is Persistent. Must be
        greater than ``thresholdWet``. Default 75.
    water_threshold : float
        Index value above which (strictly) a pixel is counted as wet in a
        given time step. Default 0.0 (positive MNDWI = water-dominated;
        Xu 2006).
    nan_policy : {"total", "valid"}
        Denominator used when computing wet frequency.

        ``"total"`` *(default)*
            Denominator = total number of time steps. NaN pixels count as
            dry. Reproduces the original Singh & Sinha (2022) method.
            Appropriate when NaN values are rare or randomly distributed.
        ``"valid"``
            Denominator = per-pixel count of non-NaN observations. Wet
            frequency is the fraction of *cloud-free* years that were wet.
            The historic/recent windows are also normalised by their own
            valid counts, so ``delta`` is a fraction in [−1, +1]. A window
            with no valid observation yields NaN and therefore never
            satisfies New/Lost/Intensifying/Diminishing; a pixel with no
            valid observation at all ends up Non-wetland (0) unless masked
            via ``min_valid_obs``. Use when cloud masking produces
            substantial or spatially clustered NaN values.
    min_valid_obs : int, optional
        *(Only active when nan_policy="valid".)* Minimum number of non-NaN
        observations required for classification. Pixels with fewer valid
        observations are set to NaN in the output. Default ``None`` (no
        minimum enforced).
    mndwi_threshold : float, optional
        **Deprecated.** Use ``water_threshold`` instead.

    Returns
    -------
    xr.DataArray
        DataArray with the spatial shape of ``water_index`` (the ``time``
        dimension reduced) holding class codes from
        :data:`DYNAMICS_CLASSES`. The dtype is int8, except when
        ``nan_policy="valid"`` and ``min_valid_obs`` is set: the masking
        then promotes the result to float with NaN for masked pixels.

    Raises
    ------
    ValueError
        If ``water_index`` lacks a ``time`` dimension, if ``nYear`` is not
        a positive integer, if ``nYear * 2 > n_time``, if thresholds are
        out of range, or if ``nan_policy`` is not one of the accepted
        values.

    References
    ----------
    Singh, M. & Sinha, R. (2022). Remote Sensing Letters, 13(1), 1–13.
    https://doi.org/10.1080/2150704X.2021.1980919
    """
    # ------------------------------------------------------------------
    # Backward-compatibility shim
    # ------------------------------------------------------------------
    if mndwi_threshold is not None:
        warnings.warn(
            "'mndwi_threshold' is deprecated; use 'water_threshold' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        water_threshold = mndwi_threshold

    # ------------------------------------------------------------------
    # Input validation
    # ------------------------------------------------------------------
    if "time" not in water_index.dims:
        raise ValueError(
            "'water_index' must have a 'time' dimension. "
            f"Found: {water_index.dims}"
        )
    n_time = len(water_index.time)
    # nYear must be a positive integer: nYear=0 would turn the recent
    # window slice(-0, None) into the *whole* series (every wet pixel
    # would become "New"), and negative values give meaningless windows.
    if (
        isinstance(nYear, bool)
        or not isinstance(nYear, (int, np.integer))
        or nYear < 1
    ):
        raise ValueError(f"nYear must be a positive integer, got {nYear!r}.")
    if nYear * 2 > n_time:
        raise ValueError(
            f"nYear={nYear} requires at least {nYear * 2} time steps, "
            f"but the input has only {n_time}."
        )
    if not (0 <= thresholdWet <= 100) or not (0 <= thresholdPersis <= 100):
        raise ValueError("Thresholds must be in the range [0, 100].")
    if thresholdPersis <= thresholdWet:
        raise ValueError(
            f"thresholdPersis ({thresholdPersis}) must be greater than "
            f"thresholdWet ({thresholdWet})."
        )
    if nan_policy not in _VALID_NAN_POLICIES:
        raise ValueError(
            f"nan_policy must be one of {_VALID_NAN_POLICIES}, "
            f"got {nan_policy!r}."
        )

    # ------------------------------------------------------------------
    # Stage 1: Summary statistics
    # ------------------------------------------------------------------
    his_slice = water_index.isel(time=slice(0, nYear))
    rec_slice = water_index.isel(time=slice(-nYear, None))

    n_valid = None
    if nan_policy == "total":
        # Original method: NaN > threshold is False, so NaN counts as dry;
        # denominator = n_time. Window statistics are raw counts in [0, nYear].
        w_all = xr.where(water_index > water_threshold, 1, 0).sum(dim="time")
        w_his = xr.where(his_slice > water_threshold, 1, 0).sum(dim="time")
        w_rec = xr.where(rec_slice > water_threshold, 1, 0).sum(dim="time")
        w_percent = (w_all / n_time) * 100
    else:
        # Valid mode: per-pixel denominator. Window statistics are
        # fractions in [0, 1]; NaN where a window has no valid observation.
        def _wet_fraction(da, dim):
            """Fraction of valid (non-NaN) observations that were wet."""
            wet = (da > water_threshold).where(da.notnull())
            n_v = da.count(dim=dim)
            return wet.sum(dim=dim, skipna=True) / n_v.where(n_v > 0)

        w_percent = _wet_fraction(water_index, "time") * 100
        w_his = _wet_fraction(his_slice, "time")
        w_rec = _wet_fraction(rec_slice, "time")
        n_valid = water_index.count(dim="time")

    # Condition masks shared by both policies (counts vs fractions compare
    # the same way against 0 and against each other).
    is_wet = w_percent >= thresholdWet
    is_new = (w_his == 0) & (w_rec > 0)
    is_lost = (w_his > 0) & (w_rec == 0)
    is_intens = is_wet & (w_rec > w_his) & ~is_new
    is_dimin = is_wet & (w_rec < w_his) & ~is_lost

    # ------------------------------------------------------------------
    # Stage 2: Exclusive priority classification
    #
    # Every rule only touches pixels still at 0, so an already-assigned
    # class is never overwritten. This prevents additive artefacts (e.g.
    # Persistent=10 + Intensifying=5 → 15) when several conditions hold.
    # Pixels matching no rule stay 0 (Non-wetland).
    # ------------------------------------------------------------------
    rules = (
        (10, w_percent >= thresholdPersis),  # Persistent
        (2, is_new),  # New
        (3, is_lost),  # Lost
        (5, is_intens),  # Intensifying
        (4, is_dimin),  # Diminishing
        (6, is_wet),  # Intermittent
    )
    classification = xr.zeros_like(w_percent, dtype=np.int8)
    for code, condition in rules:
        classification = xr.where(
            (classification == 0) & condition,
            np.int8(code),
            classification,
        )

    # ------------------------------------------------------------------
    # Stage 3: Mask insufficient-data pixels (valid mode only)
    # ------------------------------------------------------------------
    if nan_policy == "valid" and min_valid_obs is not None:
        classification = classification.where(n_valid >= min_valid_obs)

    # ------------------------------------------------------------------
    # Sanity check: no non-NaN pixel may carry a code outside the valid set.
    # NaN (from Stage 3 masking) is dropped *before* any integer cast —
    # casting NaN to int is undefined and warns in NumPy. An explicit raise
    # (rather than `assert`) keeps the check active under `python -O`.
    # ------------------------------------------------------------------
    values = np.asarray(classification.values, dtype=float)
    present = np.unique(values[~np.isnan(values)])
    valid_codes = np.fromiter(DYNAMICS_CLASSES, dtype=float)
    invalid = present[~np.isin(present, valid_codes)]
    if invalid.size:
        raise AssertionError(
            "classify_dynamics produced invalid class codes — this is a bug. "
            f"Unexpected values: {set(invalid.astype(int).tolist())}"
        )

    # ------------------------------------------------------------------
    # Preserve CRS if available (best effort: failure only warns)
    # ------------------------------------------------------------------
    if _HAS_RIO:
        try:
            crs = water_index.rio.crs
            if crs is not None:
                classification = classification.rio.write_crs(crs)
        except Exception as e:
            warnings.warn(f"Could not write CRS to output: {e}.", stacklevel=2)

    # ------------------------------------------------------------------
    # Metadata
    # ------------------------------------------------------------------
    classification.name = (
        f"dynamics_nYear{nYear}_wet{thresholdWet}_persis{thresholdPersis}"
    )
    classification.attrs.update(
        long_name="Wetland Temporal Dynamics Class",
        nYear=nYear,
        thresholdWet=thresholdWet,
        thresholdPersis=thresholdPersis,
        water_threshold=water_threshold,
        nan_policy=nan_policy,
        n_timesteps=int(n_time),
        class_codes=str(DYNAMICS_CLASSES),
        references=(
            "Singh & Sinha (2022). Remote Sensing Letters, 13(1), 1-13. "
            "https://doi.org/10.1080/2150704X.2021.1980919"
        ),
    )
    return classification


# ---------------------------------------------------------------------------
# Wet frequency
# ---------------------------------------------------------------------------

def compute_wet_frequency(
    water_index: xr.DataArray,
    water_threshold: float = 0.0,
    nan_policy: str = "total",
    mndwi_threshold: float | None = None,
) -> xr.DataArray:
    """Return the pixel-wise wet frequency (%) across the full time series.

    A pixel counts as wet in a time step when its index value is strictly
    greater than ``water_threshold``.

    Parameters
    ----------
    water_index : xr.DataArray
        Multi-temporal water index with a non-empty ``time`` dimension.
    water_threshold : float
        Index value above which a pixel is counted as wet. Default 0.0.
    nan_policy : {"total", "valid"}
        ``"total"`` (default): denominator = total time steps; NaN = dry.
        ``"valid"``: denominator = per-pixel non-NaN count; pixels with no
        valid observation get NaN.
    mndwi_threshold : float, optional
        **Deprecated.** Use ``water_threshold`` instead.

    Returns
    -------
    xr.DataArray
        Wet frequency in percent (0–100) with the ``time`` dimension
        reduced, named ``"wet_frequency_pct"``.

    Raises
    ------
    ValueError
        If ``water_index`` lacks a ``time`` dimension, if that dimension is
        empty, or if ``nan_policy`` is not one of the accepted values.
    """
    if mndwi_threshold is not None:
        warnings.warn(
            "'mndwi_threshold' is deprecated; use 'water_threshold' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        water_threshold = mndwi_threshold

    # Validate the time axis first (same order as classify_dynamics) so a
    # missing or empty axis yields a clear ValueError instead of an opaque
    # attribute error or a 0/0 division full of NaN.
    if "time" not in water_index.dims:
        raise ValueError(
            "'water_index' must have a 'time' dimension. "
            f"Found: {water_index.dims}"
        )
    n_time = len(water_index.time)
    if n_time == 0:
        raise ValueError("'water_index' has an empty 'time' dimension.")
    if nan_policy not in _VALID_NAN_POLICIES:
        raise ValueError(
            f"nan_policy must be one of {_VALID_NAN_POLICIES}, got {nan_policy!r}."
        )

    if nan_policy == "total":
        # NaN > threshold is False, so missing observations count as dry.
        wet_count = xr.where(water_index > water_threshold, 1, 0).sum(dim="time")
        freq = (wet_count / n_time) * 100
    else:
        # Keep NaN where the input is NaN so it is excluded from the sum,
        # and divide by the per-pixel valid count (NaN where that is 0).
        wet = (water_index > water_threshold).where(water_index.notnull())
        n_v = water_index.count(dim="time")
        freq = wet.sum(dim="time", skipna=True) / n_v.where(n_v > 0) * 100

    freq.name = "wet_frequency_pct"
    freq.attrs.update(
        long_name="Wet Frequency (%)",
        nan_policy=nan_policy,
        water_threshold=water_threshold,
    )
    return freq


# ---------------------------------------------------------------------------
# Temporal aggregation utility
# ---------------------------------------------------------------------------

def aggregate_time(
    da: "xr.DataArray | xr.Dataset",
    freq: str = "annual",
    method: str = "median",
) -> "xr.DataArray | xr.Dataset":
    """Collapse a time series to one pixel-wise composite per period.

    Use this before classification to replace raw overpasses with regular
    composites, for example:

    - **Dynamics**: annual composites built from every available scene.
    - **WCT**: monthly or seasonal composites, each then classified with
      :func:`~wetlandmapper.classify_wct` /
      :func:`~wetlandmapper.classify_wct_ema`.

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset
        Data carrying a ``time`` dimension (GEE-fetched or built locally).
    freq : {"annual", "monthly", "seasonal", "all"}
        Length of each aggregation period:

        ``"annual"``
            One composite per calendar year (pandas alias ``"YE"``,
            labelled at year-end).
        ``"monthly"``
            One composite per calendar month (pandas alias ``"ME"``).
        ``"seasonal"``
            One composite per meteorological season per year — DJF
            (Dec–Jan–Feb), MAM (Mar–Apr–May), JJA (Jun–Jul–Aug), SON
            (Sep–Oct–Nov) — via quarterly resampling anchored to December
            (pandas alias ``"QS-DEC"``).
        ``"all"``
            No aggregation: ``da`` is returned as-is (the very same object).
    method : {"median", "mean", "max", "min"}
        Name of the reduction applied within each period. Default
        ``"median"``.

    Returns
    -------
    xr.DataArray or xr.Dataset
        Same type as ``da`` with one ``time`` step per period. The
        ``temporal_aggregation`` and ``aggregation_method`` attributes are
        set on the result. For ``"seasonal"``, each time label is the first
        day of its quarter (e.g. ``2003-12-01`` for DJF 2004).

    Raises
    ------
    ValueError
        If ``freq`` or ``method`` is not a recognised option. ``method`` is
        validated even when ``freq="all"``.

    Examples
    --------
    Annual MNDWI composites from a dense time series:

    >>> from wetlandmapper.dynamics import aggregate_time
    >>> mndwi_annual = aggregate_time(mndwi_ts, freq="annual")
    >>> dynamics = classify_dynamics(mndwi_annual, nYear=3)

    Seasonal composites for WCT classification:

    >>> from wetlandmapper import compute_indices, classify_wct_ema
    >>> from wetlandmapper.dynamics import aggregate_time
    >>> indices_ts = fetch(aoi, "2010-01-01", "2023-12-31",
    ...                    index=["MNDWI","NDVI","NDTI"])
    >>> seasonal = aggregate_time(indices_ts, freq="seasonal")
    >>> # Classify each season independently
    >>> for t in seasonal.time:
    ...     wct = classify_wct_ema(seasonal.sel(time=t))

    Notes
    -----
    A pixel that is NaN (cloud or no-data) in every scene of a period
    stays NaN in that period's composite; otherwise NaNs are skipped
    (``skipna=True`` is the xarray default for float data).
    """
    allowed_freqs = {"annual", "monthly", "seasonal", "all"}
    allowed_methods = {"median", "mean", "max", "min"}

    # Both arguments are validated before the "all" short-circuit.
    if freq not in allowed_freqs:
        raise ValueError(f"freq must be one of {allowed_freqs}. Got {freq!r}.")
    if method not in allowed_methods:
        raise ValueError(f"method must be one of {allowed_methods}. Got {method!r}.")

    if freq == "all":
        return da

    # QS-DEC anchors quarters to December: DJF / MAM / JJA / SON
    rule = {"annual": "YE", "monthly": "ME", "seasonal": "QS-DEC"}[freq]

    # `method` is one of the resampler's reduction names, so dispatch by name.
    reducer = getattr(da.resample(time=rule), method)
    result = reducer()

    # Carry the variable name forward (DataArray inputs only).
    if isinstance(da, xr.DataArray) and da.name:
        result.name = da.name

    result.attrs.update(temporal_aggregation=freq, aggregation_method=method)
    return result