# Source code for wetlandmapper.terrain

"""
terrain.py
----------
DEM-based terrain analysis for masking high-altitude artefacts in wetland mapping.

Glaciers, permanent snowpacks, and steep mountain terrain produce strong
positive wetness index signals that are not wetlands. This module provides
tools to identify and mask these artefacts using a Digital Elevation Model
(DEM) by characterising local terrain flatness — the key topographic
criterion that separates true wetlands (flat, low-lying areas) from
high-altitude false positives.

Three flatness metrics are provided, each serving a different purpose
-------------------------------------------------------------------
compute_slope(dem)
    Most broadly applicable. A wetland pixel almost never sits on a slope >5°.
    Computed from numpy gradient on the DEM with approximate metre conversion
    for geographic coordinates. Independent of window size.

compute_tpi(dem, window)
    Adds discrimination between flat plateau (near-zero TPI, high elevation)
    and valley bottom (negative TPI). Slope alone can't separate a flat glacier
    from a flat floodplain. TPI can, if you combine it with an elevation ceiling.

compute_local_range(dem, window)
    Directly replicates the GEE rolling-window approach. Most intuitive to
    threshold (e.g. "< 30 m variation in a 5×5 window"). Maximum minus minimum
    elevation within an NxN window. A low local range indicates flat terrain.

map_dem_depressions(raw_dem, filled_dem, ...)
    Depression mapping from raw and pit-filled DEM using integer division and
    binary reclassification (depression=1, non-depression=0), following the
    protocol described by Sinha et al. (2017, Current Science).

mask_terrain_artifacts(wetness, dem, ...)
    Combines any or all three metrics plus an elevation ceiling. Emits a warning
    if the mask retains <10% of pixels (likely thresholds are too strict).

Notes
-----
All functions operate on xarray DataArrays with spatial dimensions
``(y, x)`` or ``(lat, lon)``. CRS is preserved where rioxarray is available.
No GEE dependency — these functions are designed for use on locally downloaded
data (from :func:`wetlandmapper.gee.fetch` or any other source).

For server-side DEM masking within GEE (without downloading the DEM), use
the ``dem_mask`` parameter of :func:`wetlandmapper.gee.fetch`.
"""

from __future__ import annotations

import warnings

import numpy as np
import xarray as xr

try:
    import rioxarray  # noqa: F401

    _HAS_RIO = True
except ImportError:
    _HAS_RIO = False

__all__ = [
    "compute_slope",
    "compute_tpi",
    "compute_local_range",
    "map_dem_depressions",
    "mask_terrain_artifacts",
]


# ---------------------------------------------------------------------------
# Internal helpers


def _check_dem(dem: xr.DataArray) -> None:
    """Validate that ``dem`` is a DataArray with a recognised spatial dimension.

    The check passes as soon as *any* dimension is named ``y``, ``x``,
    ``lat`` or ``lon``; it does not verify that a complete pair is present.

    Raises
    ------
    TypeError
        If ``dem`` is not an :class:`xarray.DataArray`.
    ValueError
        If none of the dimension names is a recognised spatial name.
    """
    if not isinstance(dem, xr.DataArray):
        raise TypeError(f"DEM must be an xarray.DataArray, got {type(dem)}")

    recognised = {"y", "x", "lat", "lon"}
    if recognised.isdisjoint(dem.dims):
        raise ValueError(
            "DEM must contain spatial dimensions 'y'/'x' or 'lat'/'lon'. "
            f"Found dimensions: {dem.dims}"
        )


def _spatial_dims(da: xr.DataArray) -> tuple[str, str]:
    """Return the (y_dim, x_dim) names, supporting both y/x and lat/lon."""
    if "y" in da.dims and "x" in da.dims:
        return "y", "x"
    if "lat" in da.dims and "lon" in da.dims:
        return "lat", "lon"
    raise ValueError(
        f"Cannot identify spatial dimensions in {da.dims}. "
        "Expected 'y'/'x' or 'lat'/'lon'."
    )


# ---------------------------------------------------------------------------
# Slope


def _pixel_size_metres(
    dem: xr.DataArray,
    y_coords: np.ndarray,
    x_coords: np.ndarray,
) -> tuple[float, float]:
    """Return the approximate ``(dy, dx)`` pixel spacing of ``dem`` in metres.

    When rioxarray reports a *projected* CRS, the coordinate spacing is used
    as-is (assumes the projection's linear unit is the metre — TODO confirm
    for CRSs that use feet). Otherwise the coordinates are treated as decimal
    degrees and converted with the 111 320 m-per-degree approximation, with
    the x spacing scaled by the cosine of the mean latitude.
    """
    dy = abs(float(np.mean(np.diff(y_coords))))
    dx = abs(float(np.mean(np.diff(x_coords))))

    crs = None
    if _HAS_RIO:
        try:
            crs = dem.rio.crs
        except Exception:
            # Best effort: an unreadable CRS is treated like a missing one.
            crs = None
    if crs is not None and getattr(crs, "is_projected", False):
        return dy, dx

    mid_lat_rad = np.deg2rad(float(np.mean(y_coords)))
    return dy * 111_320.0, dx * 111_320.0 * float(np.cos(mid_lat_rad))


def compute_slope(
    dem: xr.DataArray,
    units: str = "degrees",
) -> xr.DataArray:
    """Compute terrain slope from a DEM DataArray.

    Uses the central-difference gradient (``numpy.gradient``) along the two
    spatial dimensions. Low values indicate flat ground suitable for wetlands.

    Pixel spacing is taken from the mean coordinate step. Coordinates are
    interpreted as decimal degrees and converted to metres approximately,
    unless rioxarray reports a projected CRS, in which case the coordinate
    spacing is used directly.

    Parameters
    ----------
    dem : xr.DataArray
        Digital Elevation Model with spatial dimensions ``(y, x)`` or
        ``(lat, lon)``. Elevation units should be metres. Additional
        (non-spatial) dimensions and any dimension order are accepted; the
        gradient is always taken along the spatial axes.
    units : {"degrees", "radians", "percent"}
        Output slope units. Default ``"degrees"``.

    Returns
    -------
    xr.DataArray
        Slope values with the same dimensions and coordinates as ``dem``.
        Name: ``"slope"``.

    Raises
    ------
    TypeError
        If ``dem`` is not an ``xarray.DataArray``.
    ValueError
        If ``units`` is not recognised, a spatial dimension has fewer than two
        values, or the coordinate spacing is zero or non-finite.
    """
    _check_dem(dem)
    if units not in {"degrees", "radians", "percent"}:
        raise ValueError(
            f"units must be one of 'degrees', 'radians', or 'percent'. "
            f"Got {units!r}."
        )

    y_dim, x_dim = _spatial_dims(dem)
    y_coords = dem[y_dim].values
    x_coords = dem[x_dim].values

    if y_coords.size < 2 or x_coords.size < 2:
        raise ValueError(
            "DEM must contain at least two values in each spatial dimension."
        )

    dy_m, dx_m = _pixel_size_metres(dem, y_coords, x_coords)
    if not (np.isfinite(dy_m) and np.isfinite(dx_m) and dy_m > 0 and dx_m > 0):
        # A zero step would silently turn every slope into inf/NaN.
        raise ValueError(
            "DEM coordinates must have a finite, non-zero spacing along each "
            "spatial dimension."
        )

    elev = dem.values.astype(float)

    # Address the spatial axes explicitly so that extra dimensions (time,
    # band, ...) and (x, y)-ordered arrays are handled correctly.
    y_axis = dem.get_axis_num(y_dim)
    x_axis = dem.get_axis_num(x_dim)
    grad_y, grad_x = np.gradient(elev, dy_m, dx_m, axis=(y_axis, x_axis))

    slope_rad = np.arctan(np.sqrt(grad_x**2 + grad_y**2))

    if units == "degrees":
        slope_val = np.degrees(slope_rad)
    elif units == "radians":
        slope_val = slope_rad
    else:
        slope_val = np.tan(slope_rad) * 100.0

    result = xr.DataArray(slope_val, dims=dem.dims, coords=dem.coords)
    result.name = "slope"
    result.attrs.update(
        long_name=f"Terrain Slope ({units})",
        units=units,
        source="Computed from DEM using central-difference gradient",
    )

    if _HAS_RIO:
        # Best effort only: failing to propagate the CRS must not discard an
        # otherwise valid slope raster.
        try:
            crs = dem.rio.crs
            if crs is not None:
                result = result.rio.write_crs(crs)
        except Exception:
            pass

    return result


# ---------------------------------------------------------------------------
# TPI — Topographic Position Index


def compute_tpi(
    dem: xr.DataArray,
    window: int = 5,
) -> xr.DataArray:
    """Compute the Topographic Position Index (TPI).

    ``TPI = elevation - focal_mean(elevation, window x window)``

    A positive TPI marks a pixel that sits above its neighbourhood (hilltop,
    ridge); a negative TPI marks a valley or depression; values near zero
    indicate flat terrain or a mid-slope position.

    Parameters
    ----------
    dem : xr.DataArray
        Digital Elevation Model with spatial dimensions ``(y, x)`` or
        ``(lat, lon)``.
    window : int
        Side length, in pixels, of the square focal window. Must be an
        integer >= 3; odd sizes give a symmetric neighbourhood. Default 5.

    Returns
    -------
    xr.DataArray
        TPI in the elevation units of ``dem``. Name: ``"TPI"``.

    Raises
    ------
    ValueError
        If ``window`` is not an ``int`` or is smaller than 3.
    """
    _check_dem(dem)
    if not (isinstance(window, int) and window >= 3):
        raise ValueError("window must be an integer >= 3.")

    row_dim, col_dim = _spatial_dims(dem)
    window_spec = {row_dim: window, col_dim: window}

    # min_periods=1 keeps edge pixels by averaging over the partial window.
    roller = dem.rolling(window_spec, center=True, min_periods=1)
    neighbourhood_mean = roller.mean()

    tpi = dem - neighbourhood_mean
    tpi.name = "TPI"
    tpi.attrs.update(
        {
            "long_name": "Topographic Position Index",
            "window_size": window,
            "interpretation": (
                "Positive = hilltop/ridge; "
                "Negative = valley/depression; "
                "Near-zero = flat terrain or mid-slope"
            ),
        }
    )
    return tpi


# ---------------------------------------------------------------------------
# Local elevation range


def compute_local_range(
    dem: xr.DataArray,
    window: int = 5,
) -> xr.DataArray:
    """Compute the local elevation range (max - min) in a rolling window.

    Small values indicate flat terrain. This mirrors the rolling-window
    approach used in GEE scripts to keep only flat neighbourhoods for
    wetland mapping.

    Parameters
    ----------
    dem : xr.DataArray
        Digital Elevation Model with spatial dimensions ``(y, x)`` or
        ``(lat, lon)``.
    window : int
        Side length, in pixels, of the square window. Must be an integer
        >= 3. Default 5.

    Returns
    -------
    xr.DataArray
        Local elevation range in the units of ``dem``. The ``units``
        attribute is copied from ``dem.attrs`` and falls back to ``"m"``.
        Name: ``"local_range"``.

    Raises
    ------
    ValueError
        If ``window`` is not an ``int`` or is smaller than 3.
    """
    _check_dem(dem)
    if not (isinstance(window, int) and window >= 3):
        raise ValueError("window must be an integer >= 3.")

    row_dim, col_dim = _spatial_dims(dem)
    # One rolling object serves both reductions; min_periods=1 keeps edges.
    roller = dem.rolling(
        {row_dim: window, col_dim: window},
        center=True,
        min_periods=1,
    )
    highest = roller.max()
    lowest = roller.min()

    local_range = highest - lowest
    local_range.name = "local_range"
    local_range.attrs.update(
        {
            "long_name": "Local Elevation Range",
            "window_size": window,
            "units": dem.attrs.get("units", "m"),
            "interpretation": (
                "Low values indicate flat terrain; "
                "high values indicate steep or rough terrain"
            ),
        }
    )
    return local_range
def map_dem_depressions(
    raw_dem: xr.DataArray,
    filled_dem: xr.DataArray,
    *,
    require_integer: bool = True,
    apply_cleanup: bool = True,
    cleanup_window: int = 3,
    min_neighbours: int = 2,
) -> xr.DataArray:
    """Map topographic depressions from raw and pit-filled DEMs.

    This implements a depression protocol used in floodplain wetland mapping:

    1. Integer-divide ``raw_dem / filled_dem``.
    2. Pixels with value 1 are unchanged terrain (no pit).
    3. Pixels with value 0 are depressions (raw < filled).
    4. Reclassify to a binary mask: depression=1, non-depression=0.

    For positive elevations, ``raw // filled == 0`` is exactly the condition
    ``raw < filled``. The comparison is evaluated directly here so that pits
    at or below zero elevation (where floor division no longer yields 0) are
    detected as well; results for positive elevations are unchanged.

    Parameters
    ----------
    raw_dem : xr.DataArray
        Original (unfilled) DEM.
    filled_dem : xr.DataArray
        Pit-filled DEM created from ``raw_dem``.
    require_integer : bool
        If ``True`` (default), both DEMs must have integer dtype. If
        ``False``, values are truncated to ``int64`` before comparison.
    apply_cleanup : bool
        If ``True`` (default), remove isolated one-pixel/very small speckles
        using a neighbourhood-count filter.
    cleanup_window : int
        Square rolling window size for cleanup. Default 3.
    min_neighbours : int
        Minimum number of depression pixels (including self) within
        ``cleanup_window`` to retain a depression pixel. Default 2.

    Returns
    -------
    xr.DataArray
        Binary depression mask with values {0, 1} (``uint8``).
        Name: ``"depression_mask"``.

    Raises
    ------
    TypeError
        If an input is not a DataArray, or is not integer-typed while
        ``require_integer`` is ``True``.
    ValueError
        If the DEMs differ in dimensions or shape, or the cleanup parameters
        are invalid while ``apply_cleanup`` is ``True``.

    Notes
    -----
    This method performs best in low-relief floodplains and may be less
    reliable in rugged terrain. Residual speckle can still occur and may need
    additional post-processing for specific study areas.

    References
    ----------
    Sinha, R., Saxena, S., & Singh, M. (2017). Protocols for Riverine Wetland
    Mapping and Classification Using Remote Sensing and GIS. Current Science,
    112(7), 1544-1552. http://www.jstor.org/stable/24912702
    """
    _check_dem(raw_dem)
    _check_dem(filled_dem)

    if raw_dem.dims != filled_dem.dims or raw_dem.shape != filled_dem.shape:
        raise ValueError(
            "raw_dem and filled_dem must have identical dimensions and shape."
        )

    if require_integer:
        if not np.issubdtype(raw_dem.dtype, np.integer):
            raise TypeError("raw_dem must be integer dtype when require_integer=True.")
        if not np.issubdtype(filled_dem.dtype, np.integer):
            raise TypeError(
                "filled_dem must be integer dtype when require_integer=True."
            )

    if apply_cleanup:
        if not isinstance(cleanup_window, int) or cleanup_window < 3:
            raise ValueError("cleanup_window must be an integer >= 3.")
        if not isinstance(min_neighbours, int) or min_neighbours < 1:
            raise ValueError("min_neighbours must be an integer >= 1.")

    # Record missing/non-finite cells before the integer cast: casting NaN or
    # inf to int64 yields arbitrary values that must never count as a pit.
    valid = np.isfinite(raw_dem) & np.isfinite(filled_dem)
    raw = raw_dem.where(valid, 0).astype(np.int64)
    filled = filled_dem.where(valid, 0).astype(np.int64)

    # Equivalent to the integer-division protocol (raw // filled == 0) for
    # positive elevations, and also correct for zero/negative elevations.
    is_depression = (raw < filled) & valid
    depression_mask = is_depression.astype(np.uint8)

    if apply_cleanup:
        y_dim, x_dim = _spatial_dims(depression_mask)
        # Count of depression pixels (self included) in each window.
        neighbour_count = depression_mask.rolling(
            {y_dim: cleanup_window, x_dim: cleanup_window},
            center=True,
            min_periods=1,
        ).sum()
        depression_mask = xr.where(
            (depression_mask == 1) & (neighbour_count < min_neighbours),
            0,
            depression_mask,
        ).astype(np.uint8)

    depression_mask.name = "depression_mask"
    depression_mask.attrs.update(
        long_name="DEM depression mask from raw vs pit-filled DEM",
        values="1=depression/wetland candidate, 0=flat/non-depression",
        method="integer_division_raw_over_filled",
        cleanup_applied=bool(apply_cleanup),
        cleanup_window=int(cleanup_window),
        min_neighbours=int(min_neighbours),
    )

    if _HAS_RIO:
        # Best effort only: a CRS that cannot be propagated should not
        # invalidate the computed mask.
        try:
            crs = raw_dem.rio.crs
            if crs is not None:
                depression_mask = depression_mask.rio.write_crs(crs)
        except Exception:
            pass

    return depression_mask


# ---------------------------------------------------------------------------
# Combined terrain masking


def mask_terrain_artifacts(
    wetness: "xr.DataArray | xr.Dataset",
    dem: xr.DataArray,
    max_slope: float | None = 5.0,
    max_tpi: float | None = None,
    max_local_range: float | None = None,
    local_range_window: int = 5,
    tpi_window: int = 5,
    max_elevation: float | None = None,
    invert: bool = False,
) -> "xr.DataArray | xr.Dataset":
    """Mask wetness data using terrain flatness and elevation filters.

    Each enabled criterion is evaluated on ``dem`` and the results are AND-ed
    into a single boolean suitability mask, which is then applied to
    ``wetness`` with ``.where``. A ``UserWarning`` is emitted when that
    suitability mask (before any inversion) keeps fewer than 10% of the DEM
    pixels, since the thresholds are then likely too strict.

    Parameters
    ----------
    wetness : xr.DataArray or xr.Dataset
        Wetness or index data to mask. Can include a ``time`` dimension.
    dem : xr.DataArray
        Digital Elevation Model with spatial dimensions matching ``wetness``.
    max_slope : float or None
        Maximum slope in degrees. Pixels steeper than this are masked.
        Default 5.0. Set to ``None`` to disable slope filtering.
    max_tpi : float or None
        Maximum absolute TPI (metres). Default ``None`` (disabled).
    max_local_range : float or None
        Maximum local elevation range (metres). Default ``None`` (disabled).
    local_range_window : int
        Window size for local range. Default 5.
    tpi_window : int
        Window size for TPI. Default 5.
    max_elevation : float or None
        Absolute elevation ceiling (metres). Pixels above this are masked.
        Default ``None``.
    invert : bool
        If ``True``, invert the mask so excluded terrain is retained.

    Returns
    -------
    xr.DataArray or xr.Dataset
        ``wetness`` masked by terrain suitability. The type matches the input.

    Raises
    ------
    TypeError
        If ``wetness`` is neither a DataArray nor a Dataset, or ``dem`` is not
        a DataArray.
    ValueError
        If a window size belonging to an enabled filter is not an integer
        >= 3.

    Warns
    -----
    UserWarning
        If fewer than 10% of DEM pixels satisfy the enabled criteria.
    """
    if not isinstance(wetness, (xr.DataArray, xr.Dataset)):
        raise TypeError(
            f"wetness must be an xarray.DataArray or Dataset, got {type(wetness)}"
        )
    _check_dem(dem)

    # Window sizes are validated only for the filters that are switched on.
    if max_local_range is not None and (
        not isinstance(local_range_window, int) or local_range_window < 3
    ):
        raise ValueError("local_range_window must be an integer >= 3.")
    if max_tpi is not None and (not isinstance(tpi_window, int) or tpi_window < 3):
        raise ValueError("tpi_window must be an integer >= 3.")

    terrain_mask = xr.ones_like(dem, dtype=bool)

    if max_elevation is not None:
        terrain_mask = terrain_mask & (dem <= max_elevation)

    if max_slope is not None:
        terrain_mask = terrain_mask & (compute_slope(dem) <= max_slope)

    if max_tpi is not None:
        tpi_result = abs(compute_tpi(dem, window=tpi_window)) <= max_tpi
        terrain_mask = terrain_mask & tpi_result

    if max_local_range is not None:
        lr_result = (
            compute_local_range(dem, window=local_range_window) <= max_local_range
        )
        terrain_mask = terrain_mask & lr_result

    # Checked before inversion: the warning is about how strict the
    # thresholds are, not about which side of the mask the caller keeps.
    if terrain_mask.size:
        retained_fraction = float(terrain_mask.mean())
        if retained_fraction < 0.10:
            warnings.warn(
                f"Terrain mask retains only {retained_fraction:.1%} of DEM "
                "pixels; the thresholds may be too strict.",
                UserWarning,
                stacklevel=2,
            )

    if invert:
        terrain_mask = ~terrain_mask

    masked = wetness.where(terrain_mask)

    if _HAS_RIO and isinstance(masked, xr.DataArray):
        # Best effort only: a CRS that cannot be propagated should not
        # discard the masked result.
        try:
            crs = dem.rio.crs
            if crs is not None:
                masked = masked.rio.write_crs(crs)
        except Exception:
            pass

    return masked