Source code for wetlandmapper.gee

"""
gee.py

# Copyright (c) 2026, Manudeo Singh          #
# Author: Manudeo Singh, March 2026          #
------
Optional Google Earth Engine (GEE) data acquisition module.

Supports **all Landsat missions** (4, 5, 7, 8, 9), Sentinel-2, and **MODIS** (Terra/Aqua),
including a ``"LandsatAll"`` option that automatically merges available missions for any
requested date range.

Two retrieval functions
-----------------------
fetch(aoi, ...)
    Downloads data immediately via GEE's ``getDownloadURL`` API.
    Best for small-to-medium AOIs (up to ~100 km²).

fetch_xee(aoi, ...)
    Opens the GEE collection as a **lazy Dask-backed xarray** via ``xee``.
    No data is transferred until ``.compute()`` is called.
    Suitable for large AOIs and long time series.

Temporal aggregation
--------------------
Both functions accept a ``temporal_aggregation`` parameter:

    ``"all"``      — every individual scene (default)
    ``"annual"``   — one median composite per calendar year
    ``"monthly"``  — one median composite per calendar month
    ``"seasonal"`` — one composite per meteorological season (DJF/MAM/JJA/SON)

Sensors and date coverage
------------------------------------
+--------------+---------------------+---------------------------+--------------------+
| sensor=      | GEE collection      | Operational dates         | Band family        |
+--------------+---------------------+---------------------------+--------------------+
| "Landsat4"   | LT04/C02/T1_L2      | 1982-08-22 – 1993-12-14   | TM (SR_B1–B5,B7)   |
| "Landsat5"   | LT05/C02/T1_L2      | 1984-03-16 – 2013-06-05   | TM (SR_B1–B5,B7)   |
| "Landsat7"   | LE07/C02/T1_L2      | 1999-04-15 – 2022-04-06   | ETM+ (SR_B1–B5,B7) |
|              |                     |   SLC failure: 2003-05-31 | use_slc_off=False  |
| "Landsat8"   | LC08/C02/T1_L2      | 2013-04-11 – present      | OLI (SR_B2–B7)     |
| "Landsat9"   | LC09/C02/T1_L2      | 2021-10-31 – present      | OLI-2 (SR_B2–B7)   |
| "LandsatAll" | merged above        | 1982 – present            | auto-harmonised    |
| "Sentinel2"  | S2_SR_HARMONIZED    | 2015-06-27 – present      | MSI (B2–B12)       |
| "MODIS_Terra"| MOD09A1             | 2000-02-24 – present      | MODIS (500m)       |
| "MODIS_Aqua" | MYD09A1             | 2002-07-04 – present      | MODIS (500m)       |
| "MODISAll"   | merged MODIS        | 2002-07-04 – present      | MODIS (500m)       |
+--------------+---------------------+---------------------------+--------------------+

``"Landsat"`` is an alias for ``"Landsat8"`` for backward compatibility.

Landsat 7 SLC-off note
-----------------------
The Scan Line Corrector (SLC) on Landsat 7 ETM+ failed on 2003-05-31.
Images acquired after this date have wedge-shaped data gaps covering roughly
22 % of each scene.  Use ``use_slc_off=False`` (default) to exclude these
images and use only the good-quality 1999–2003 record.  Set ``use_slc_off=True``
to include post-failure images (useful when other sensors have no coverage, e.g.
1999–2012 before Landsat 8).

MODIS note
----------
MODIS provides 500m resolution surface reflectance composites (8-day intervals).
Use ``"MODISAll"`` to automatically merge Terra and Aqua for continuous coverage.
MODIS has coarser resolution and a shorter temporal record (2000–present) than
Landsat (1982–present), but a much denser time series. Suitable for regional-scale
studies where 500m resolution is adequate.
Band mapping differs from Landsat — MODIS stores Red as Band 1, NIR as Band 2,
Blue as Band 3, Green as Band 4. Cloud masking uses StateQA bits 0-1 (cloud state)
and bit 2 (shadow). Scale factor is 0.0001 with no offset (unlike Landsat C02 L2's
-0.2 offset). AWEIsh and AWEInsh are computed server-side for MODIS since all
required bands are available.

DEM masking
-----------
Server-side DEM masking using Copernicus GLO-30 can be activated with the
``dem_mask`` parameter in :func:`fetch`. This applies terrain filters directly
in GEE using ``ee.Terrain.slope()`` and ``reduceNeighborhood``, avoiding the
need to download the DEM. The snow question is addressed by ``min_temp_c`` in
climate-adaptive mode (ERA5 precipitation includes snowfall; filtering to months
≥5°C removes cold months where precipitation is frozen) and by ``max_elevation_m``
(above the local glaciation line, always mask regardless of other criteria).

Requirements
------------
- earthengine-api  : always required
- rasterio         : required by fetch()
- xee + dask       : required by fetch_xee()

Install:  ``pip install wetlandmapper[gee]``
Authenticate once:  ``earthengine authenticate``
"""

from __future__ import annotations

import datetime
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence, TypeAlias, cast

import numpy as np
import xarray as xr

# ---------------------------------------------------------------------------
# Optional dependency guard
# ---------------------------------------------------------------------------

if TYPE_CHECKING:
    import ee as ee_module

    EEGeometry: TypeAlias = ee_module.Geometry
    EEImage: TypeAlias = ee_module.Image
    EEImageCollection: TypeAlias = ee_module.ImageCollection
    EENumber: TypeAlias = ee_module.Number
else:
    EEGeometry = Any
    EEImage = Any
    EEImageCollection = Any
    EENumber = Any

try:
    import ee as _ee

    ee = cast(Any, _ee)
    _HAS_EE = True
except ImportError:
    _HAS_EE = False
    ee = cast(Any, None)

# Names exported by ``from <this module> import *``.
__all__ = ["fetch", "fetch_xee", "authenticate", "init"]


# ---------------------------------------------------------------------------
# Sensor configuration
# ---------------------------------------------------------------------------

# Internal band-family keys (not all exposed to the user)
_BAND_MAP: dict[str, dict[str, str]] = {
    # Landsat 4/5 TM and Landsat 7 ETM+: reflective bands SR_B1..B5 and B7
    # (B6 is the thermal band and is not mapped).
    "LandsatTM_ETM": dict(
        blue="SR_B1",
        green="SR_B2",
        red="SR_B3",
        nir="SR_B4",
        swir="SR_B5",  # SWIR1 (band 5 on TM/ETM+)
        swir2="SR_B7",  # SWIR2 (band 7 on TM/ETM+)
        qa="QA_PIXEL",
    ),
    # Landsat 8/9 OLI: the extra coastal/aerosol B1 shifts every band up by one.
    "LandsatOLI": dict(
        blue="SR_B2",
        green="SR_B3",
        red="SR_B4",
        nir="SR_B5",
        swir="SR_B6",  # SWIR1
        swir2="SR_B7",
        qa="QA_PIXEL",
    ),
    # Sentinel-2 MSI
    "Sentinel2": dict(
        blue="B2",
        green="B3",
        red="B4",
        nir="B8",
        swir="B11",  # SWIR1 (20 m; resampled by GEE to 10 m)
        swir2="B12",
        qa="QA60",
    ),
    # MODIS Terra/Aqua MOD09A1 / MYD09A1 (8-day 500m surface reflectance).
    # Scale: multiply by 0.0001 (stored as int16, range -100 to 16000).
    "MODIS_500m": dict(
        blue="sur_refl_b03",  # 459-479 nm
        green="sur_refl_b04",  # 545-565 nm
        red="sur_refl_b01",  # 620-670 nm
        nir="sur_refl_b02",  # 841-876 nm
        swir="sur_refl_b06",  # SWIR1 1628-1652 nm
        swir2="sur_refl_b07",  # SWIR2 2105-2155 nm
        qa="StateQA",
    ),
    # Common band names used internally once LandsatAll has been harmonised.
    "_harmonised": dict(
        blue="blue",
        green="green",
        red="red",
        nir="nir",
        swir="swir1",
        swir2="swir2",
        qa="qa",
    ),
}

# User-facing sensor name → Earth Engine collection id
_COLLECTION_ID: dict[str, str] = dict(
    Landsat4="LANDSAT/LT04/C02/T1_L2",
    Landsat5="LANDSAT/LT05/C02/T1_L2",
    Landsat7="LANDSAT/LE07/C02/T1_L2",
    Landsat8="LANDSAT/LC08/C02/T1_L2",
    Landsat9="LANDSAT/LC09/C02/T1_L2",
    Sentinel2="COPERNICUS/S2_SR_HARMONIZED",
    MODIS_Terra="MODIS/061/MOD09A1",
    MODIS_Aqua="MODIS/061/MYD09A1",
)

# User-facing sensor name → internal band-family key, written family by family.
# LandsatAll and MODISAll are handled separately and have no entry here.
_SENSOR_BAND_FAMILY: dict[str, str] = {
    sensor_name: family
    for family, sensor_names in (
        ("LandsatTM_ETM", ("Landsat4", "Landsat5", "Landsat7")),
        ("LandsatOLI", ("Landsat8", "Landsat9")),
        ("Sentinel2", ("Sentinel2",)),
        ("MODIS_500m", ("MODIS_Terra", "MODIS_Aqua")),
    )
    for sensor_name in sensor_names
}

# Backward-compatible alias
_SENSOR_ALIASES: dict[str, str] = dict(Landsat="Landsat8")

# Reflectance = DN * scale + offset. All Landsat C02 L2 families share one
# formula; Sentinel-2 and MODIS simply divide by 10000.
_SCALE_FACTOR: dict[str, dict[str, float]] = {
    "LandsatTM_ETM": dict(scale=0.0000275, offset=-0.2),
    "LandsatOLI": dict(scale=0.0000275, offset=-0.2),
    "Sentinel2": dict(scale=0.0001, offset=0.0),
    "MODIS_500m": dict(scale=0.0001, offset=0.0),
}

# Image property holding the scene cloud-cover percentage, per band family
_CLOUD_COVER_PROP: dict[str, str] = dict(
    LandsatTM_ETM="CLOUD_COVER",
    LandsatOLI="CLOUD_COVER",
    Sentinel2="CLOUDY_PIXEL_PERCENTAGE",
    MODIS_500m="CLOUD_COVER",  # not used — MODIS uses pixel-level QA
)

# Approximate operational (start, end) dates used for LandsatAll auto-selection
_LANDSAT_DATE_RANGES: dict[str, tuple[str, str]] = dict(
    Landsat4=("1982-08-22", "1993-12-15"),
    Landsat5=("1984-03-16", "2013-06-06"),
    Landsat7=("1999-04-15", "2022-04-07"),
    Landsat8=("2013-04-11", "2099-01-01"),
    Landsat9=("2021-10-31", "2099-01-01"),
)

# Date after which Landsat 7 SLC-off images should be excluded
_L7_SLC_FAILURE_DATE = "2003-06-01"

# Meteorological seasons: name → (months, label_month, label_day)
_SEASONS: dict[str, tuple[list[int], int, int]] = dict(
    DJF=([12, 1, 2], 1, 15),
    MAM=([3, 4, 5], 4, 15),
    JJA=([6, 7, 8], 7, 15),
    SON=([9, 10, 11], 10, 15),
)

# Accepted values for the user-facing string options
_VALID_AGGREGATIONS = {"all", "annual", "monthly", "seasonal"}
_VALID_INDICES = {"MNDWI", "NDWI", "NDVI", "NDTI", "AWEIsh", "AWEInsh"}
_VALID_REDUCTION_METHODS = {"median", "mean", "percentile"}
_VALID_SINGLE_SENSORS = {*_COLLECTION_ID, *_SENSOR_ALIASES}
_ALL_VALID_SENSORS = _VALID_SINGLE_SENSORS.union({"LandsatAll", "MODISAll"})


# ---------------------------------------------------------------------------
# Authentication helpers
# ---------------------------------------------------------------------------


def authenticate() -> None:
    """Run interactive GEE authentication (opens browser; only needed once).

    Calls ``_require_ee()`` first and then delegates to ``ee.Authenticate()``.
    """
    # NOTE(review): _require_ee() is defined outside this view; presumably it
    # raises when earthengine-api is not installed — confirm.
    _require_ee()
    ee.Authenticate()
def init(project: str | None = None) -> None:
    """Initialise the GEE Python client.

    Parameters
    ----------
    project : str, optional
        GEE cloud project ID. Required for accounts created after 2023.
        When omitted (or empty), ``ee.Initialize()`` is called without a
        project argument.
    """
    _require_ee()
    if project:
        ee.Initialize(project=project)
    else:
        ee.Initialize()
# --------------------------------------------------------------------------- # Cloud masking # --------------------------------------------------------------------------- def _mask_landsat_clouds(image: EEImage) -> EEImage: """Mask clouds and cloud shadows using QA_PIXEL bits 3 and 4 (Landsat C02 L2).""" qa = image.select("QA_PIXEL") mask = ( qa.bitwiseAnd(1 << 3) .eq(0) # bit 3 = cloud .And(qa.bitwiseAnd(1 << 4).eq(0)) # bit 4 = cloud shadow ) return image.updateMask(mask) def _mask_sentinel2_clouds(image: EEImage) -> EEImage: """Mask opaque clouds (bit 10) and cirrus (bit 11) using QA60 (S2 SR).""" qa = image.select("QA60") mask = qa.bitwiseAnd(1 << 10).eq(0).And(qa.bitwiseAnd(1 << 11).eq(0)) return image.updateMask(mask) def _mask_modis_clouds(image: EEImage) -> EEImage: """Mask clouds and cloud shadows using MODIS StateQA bits. StateQA bit layout (MOD09A1 / MYD09A1): Bits 0-1: cloud state (00 = clear, 01 = cloudy, 10 = mixed) Bit 2: cloud shadow (1 = shadow present) Only pixels where bits 0-1 == 0 (clear) AND bit 2 == 0 (no shadow) are retained. """ qa = image.select("StateQA") cloud_state = qa.bitwiseAnd(3).eq(0) # bits 0-1: clear only cloud_shadow = qa.bitwiseAnd(1 << 2).eq(0) # bit 2: no shadow return image.updateMask(cloud_state.And(cloud_shadow)) # --------------------------------------------------------------------------- # Index computation (server-side, on already-scaled reflectance) # --------------------------------------------------------------------------- def _add_indices(image: EEImage, bands: dict[str, str]) -> EEImage: """Add supported index bands to a GEE image using normalizedDifference. Indices are evaluated server-side as (A - B) / (A + B). 
MNDWI = (Green - SWIR1) / (Green + SWIR1) — sensitive to open water NDWI = (Green - NIR ) / (Green + NIR ) — classic open-water index NDVI = (NIR - Red ) / (NIR + Red ) — sensitive to green vegetation NDTI = (Red - Green) / (Red + Green) — sensitive to turbid / sediment water Band assignments per sensor family: +-------+-----------------------+-----------------------+-------------------+ | Index | TM / ETM+ | OLI (L8/L9) | Sentinel-2 | +-------+-----------------------+-----------------------+-------------------+ | MNDWI | (SR_B2 - SR_B5) / sum | (SR_B3 - SR_B6) / sum | (B3 - B11) / sum | | NDWI | (SR_B2 - SR_B4) / sum | (SR_B3 - SR_B5) / sum | (B3 - B8 ) / sum | | NDVI | (SR_B4 - SR_B3) / sum | (SR_B5 - SR_B4) / sum | (B8 - B4 ) / sum | | NDTI | (SR_B3 - SR_B2) / sum | (SR_B4 - SR_B3) / sum | (B4 - B3 ) / sum | +-------+-----------------------+-----------------------+-------------------+ """ mndwi = image.normalizedDifference([bands["green"], bands["swir"]]).rename("MNDWI") ndwi = image.normalizedDifference([bands["green"], bands["nir"]]).rename("NDWI") ndvi = image.normalizedDifference([bands["nir"], bands["red"]]).rename("NDVI") ndti = image.normalizedDifference([bands["red"], bands["green"]]).rename("NDTI") # AWEIsh = Blue + 2.5*Green - 1.5*(NIR + SWIR1) - 0.25*SWIR2 # Note: Landsat C02 L2 reflectance is already scaled (0.0000275*DN - 0.2); # the 0.0001 constant from the original Feyisa et al. (2014) formula is omitted. 
aweish = ( image.select(bands["blue"]) .add(image.select(bands["green"]).multiply(2.5)) .subtract( image.select(bands["nir"]) .add(image.select(bands["swir"])) .multiply(1.5) ) .subtract(image.select(bands["swir2"]).multiply(0.25)) .rename("AWEIsh") ) # AWEInsh = 4*(Green - SWIR1) - (0.25*NIR + 2.75*SWIR1) aweinsh = ( image.select(bands["green"]) .subtract(image.select(bands["swir"])) .multiply(4.0) .subtract( image.select(bands["nir"]).multiply(0.25) .add(image.select(bands["swir"]).multiply(2.75)) ) .rename("AWEInsh") ) derived_bands = ee.Image.cat([mndwi, ndwi, ndvi, ndti, aweish, aweinsh]) return cast(EEImage, image.addBands(derived_bands)) def _add_custom_indices( image: EEImage, bands: dict[str, str], custom_indices: dict[str, str], ) -> EEImage: """Add user-defined index bands from ee.Image.expression formulas. Formula symbols are the harmonised band names: ``blue``, ``green``, ``red``, ``nir``, ``swir`` (alias ``swir1``), ``swir2``, and ``qa``. """ variables = { "blue": image.select(bands["blue"]), "green": image.select(bands["green"]), "red": image.select(bands["red"]), "nir": image.select(bands["nir"]), "swir": image.select(bands["swir"]), "swir1": image.select(bands["swir"]), "swir2": image.select(bands["swir2"]), "qa": image.select(bands["qa"]), } new_bands = [ image.expression(formula, variables).rename(name) for name, formula in custom_indices.items() ] derived_bands = ee.Image.cat(new_bands) return cast(EEImage, image.addBands(derived_bands)) # --------------------------------------------------------------------------- # LandsatAll: build merged, harmonised collection # --------------------------------------------------------------------------- def _build_landsat_all( ee_geom: EEGeometry, start: str, end: str, max_cloud_cover: float, use_slc_off: bool, ) -> EEImageCollection: """Build a merged Landsat 4-9 collection with harmonised band names. 
Each mission's sub-collection is filtered to the requested date range (intersected with the mission's operational window), cloud-masked, scaled to surface reflectance, and bands renamed to a common scheme: blue, green, red, nir, swir1, swir2, qa The merged collection can then be processed identically regardless of which missions contributed images to a given period. Parameters ---------- ee_geom : ee.Geometry Study area geometry (used for ``filterBounds``). start, end : str ISO 8601 date strings. max_cloud_cover : float Maximum cloud cover (%) per image. use_slc_off : bool If False, Landsat 7 images after 2003-05-31 are excluded. Returns ------- ee.ImageCollection Merged, harmonised collection with bands: ``MNDWI``, ``NDVI``, ``NDTI`` (only index bands are kept). """ sub_collections = [] for mission in ["Landsat4", "Landsat5", "Landsat7", "Landsat8", "Landsat9"]: op_start, op_end = _LANDSAT_DATE_RANGES[mission] # Intersect requested range with operational window eff_start = max(start[:10], op_start) eff_end = min(end[:10], op_end) if eff_start >= eff_end: continue # mission not active in requested period # For Landsat 7, optionally exclude SLC-off data if mission == "Landsat7" and not use_slc_off: eff_end = min(eff_end, _L7_SLC_FAILURE_DATE) if eff_start >= eff_end: continue # SLC-on period entirely outside requested range col = ( ee.ImageCollection(_COLLECTION_ID[mission]) .filterBounds(ee_geom) .filterDate(eff_start, eff_end) .filter(ee.Filter.lt("CLOUD_COVER", max_cloud_cover)) ) band_family = _SENSOR_BAND_FAMILY[mission] sf = _SCALE_FACTOR[band_family] bm = _BAND_MAP[band_family] # Cloud mask + scale to reflectance col = col.map(_mask_landsat_clouds) col = col.map( lambda img: ( img.multiply(sf["scale"]) .add(sf["offset"]) .copyProperties(img, ["system:time_start"]) ) ) # Rename to common harmonised band names col = col.map( lambda img: img.select( [ bm["blue"], bm["green"], bm["red"], bm["nir"], bm["swir"], bm["swir2"], bm["qa"], ], ["blue", "green", "red", 
"nir", "swir1", "swir2", "qa"], ) ) sub_collections.append(col) if not sub_collections: raise RuntimeError( "No Landsat missions have data in the requested date range " f"[{start}, {end}] with use_slc_off={use_slc_off}." ) # Merge all sub-collections merged = sub_collections[0] for col in sub_collections[1:]: merged = merged.merge(col) # Add spectral indices using harmonised band names harm_bands = _BAND_MAP["_harmonised"] merged = merged.map(lambda img: _add_indices(img, harm_bands)) return merged def _build_modis_all( ee_geom: EEGeometry, start: str, end: str, max_cloud_cover: float, ) -> EEImageCollection: """Merge MODIS Terra (MOD09A1) and Aqua (MYD09A1) into one collection. Both collections use the same MODIS_500m band family. Terra has an equatorial crossing time of ~10:30 AM and Aqua ~1:30 PM, giving approximately twice the sampling frequency when combined. Both are 8-day composites so the merged collection has ~16-day sampling. Parameters ---------- ee_geom : ee.Geometry start, end : str ISO 8601 date strings. max_cloud_cover : float Not used for per-image filtering (MODIS uses pixel-level QA only); kept for API consistency. Returns ------- ee.ImageCollection Merged Terra + Aqua collection with MNDWI / NDVI / NDTI bands. 
""" bm = _BAND_MAP["MODIS_500m"] sf = _SCALE_FACTOR["MODIS_500m"] def _prep(collection_id): col = ( ee.ImageCollection(collection_id).filterBounds(ee_geom).filterDate(start, end) ) col = col.map(_mask_modis_clouds) col = col.map( lambda img: ( img.multiply(sf["scale"]).copyProperties(img, ["system:time_start"]) ) ) col = col.map(lambda img: _add_indices(img, bm)) return col terra = _prep("MODIS/061/MOD09A1") aqua = _prep("MODIS/061/MYD09A1") return terra.merge(aqua) # --------------------------------------------------------------------------- # Server-side temporal compositing with empty-period safeguard # --------------------------------------------------------------------------- def _make_nan_image(bands: list[str], timestamp: EENumber) -> EEImage: """Create a constant all-masked image with the specified band names. Used as a fallback when a compositing period contains no valid images. The image has all pixels masked so NaN propagates naturally into the xarray output — no actual NaN constant is needed. """ # Build a multi-band constant image, then mask all pixels img = ee.Image.cat([ee.Image.constant(0).rename(b) for b in bands]).updateMask( ee.Image.constant(0) ) # mask = 0 everywhere → all pixels masked # Explicit float cast ensures type consistency with real bands return img.float().set("system:time_start", timestamp) def _normalize_reduction_method(reduction_method: str) -> str: """Validate and normalize the collection reduction method.""" method = reduction_method.lower() if method not in _VALID_REDUCTION_METHODS: raise ValueError( f"reduction_method must be one of {sorted(_VALID_REDUCTION_METHODS)}, " f"got {reduction_method!r}." 
) return method def _validate_percentile(percentile: float) -> None: """Validate percentile input for percentile reduction.""" if not 0.0 <= percentile <= 100.0: raise ValueError(f"percentile must be between 0 and 100 inclusive, got {percentile!r}.") def _format_percentile_token(percentile: float) -> str: """Format a percentile value into a stable band-suffix token.""" if float(percentile).is_integer(): return str(int(percentile)) return str(percentile).replace(".", "_") def _build_composites( collection: EEImageCollection, temporal_aggregation: str, start: str, end: str, index_bands: list[str], reduction_method: str = "median", percentile: float = 50.0, ) -> EEImageCollection: """Reduce an ImageCollection to one composite per chosen period. Parameters ---------- collection : ee.ImageCollection Pre-processed collection with index bands added (MNDWI, NDVI, NDTI). temporal_aggregation : str One of ``"all"``, ``"annual"``, ``"monthly"``, ``"seasonal"``. start, end : str ISO 8601 date strings for the full requested range. index_bands : list of str Band names present in the collection (e.g. ``["MNDWI", "NDVI"]``). reduction_method : str One of "median", "mean", "percentile". Default "median". percentile : float Percentile value if reduction_method is "percentile". Default 50.0. Returns ------- ee.ImageCollection One image per period, ``system:time_start`` set to the period midpoint. For ``"all"``, the input collection is returned unchanged. 
""" reduction_method = _normalize_reduction_method(reduction_method) if reduction_method == "percentile": _validate_percentile(percentile) if temporal_aggregation == "all": collection = collection.map( lambda img: img.float().set("system:time_start", img.get("system:time_start")) ) return collection start_dt = datetime.date.fromisoformat(start[:10]) end_dt = datetime.date.fromisoformat(end[:10]) start_yr = start_dt.year end_yr = end_dt.year images = [] def _safe_composite(period_col: EEImageCollection, timestamp: EENumber) -> EEImage: """Return a composite or a masked fallback image, server-side.""" if reduction_method == "median": real = period_col.median().float().set("system:time_start", timestamp) elif reduction_method == "mean": real = period_col.mean().float().set("system:time_start", timestamp) elif reduction_method == "percentile": percentile_token = _format_percentile_token(percentile) reducer = ee.Reducer.percentile( [percentile], [f"p{percentile_token}"], ) real = ( period_col.reduce(reducer) .select( [f"{band}_p{percentile_token}" for band in index_bands], index_bands, ) .float() .set("system:time_start", timestamp) ) else: raise AssertionError(f"Unhandled reduction_method: {reduction_method}") fallback = _make_nan_image(index_bands, timestamp) return ee.Image(ee.Algorithms.If(period_col.size().gt(0), real, fallback)) if temporal_aggregation == "annual": for yr in range(start_yr, end_yr + 1): col = collection.filterDate(f"{yr}-01-01", f"{yr + 1}-01-01") ts = ee.Date.fromYMD(yr, 7, 1).millis() images.append(_safe_composite(col, ts)) elif temporal_aggregation == "monthly": cur = start_dt.replace(day=1) while cur.year < end_yr or (cur.year == end_yr and cur.month <= end_dt.month): yr, mo = cur.year, cur.month nxt = ( datetime.date(yr + 1, 1, 1) if mo == 12 else datetime.date(yr, mo + 1, 1) ) col = collection.filterDate(cur.isoformat(), nxt.isoformat()) ts = ee.Date.fromYMD(yr, mo, 15).millis() images.append(_safe_composite(col, ts)) cur = nxt elif 
temporal_aggregation == "seasonal": for yr in range(start_yr, end_yr + 1): for sname, (months, tag_mo, tag_day) in _SEASONS.items(): season_parts = [] for mo in months: mo_yr = yr - 1 if (sname == "DJF" and mo == 12) else yr mo_start = datetime.date(mo_yr, mo, 1) mo_end = ( datetime.date(mo_yr + 1, 1, 1) if mo == 12 else datetime.date(mo_yr, mo + 1, 1) ) season_parts.append( collection.filterDate(mo_start.isoformat(), mo_end.isoformat()) ) merged = season_parts[0] for sc in season_parts[1:]: merged = merged.merge(sc) ts = ee.Date.fromYMD(yr, tag_mo, tag_day).millis() images.append(_safe_composite(merged, ts)) if not images: raise RuntimeError( f"No composites generated for temporal_aggregation={temporal_aggregation!r}, " f"date range [{start}, {end}]." ) return ee.ImageCollection.fromImages(images) # --------------------------------------------------------------------------- # Core download helpers # --------------------------------------------------------------------------- def _parse_aoi(aoi: "dict | str | Path") -> EEGeometry: """Convert an AOI to an ``ee.Geometry``. Accepts three forms: 1. **GeoJSON dict** — plain geometry, Feature, or FeatureCollection. 2. **Shapefile path** — ``str`` or ``pathlib.Path`` pointing to a ``.shp`` file (or any format readable by ``geopandas``/``fiona``). All features are dissolved into a single geometry (union) so the AOI always has one contiguous boundary. 3. **GeoJSON file path** — a ``.geojson`` / ``.json`` file is read with ``geopandas`` and treated the same as a shapefile. Parameters ---------- aoi : dict, str, or Path The area of interest. If a file path, the file is read with ``geopandas`` (must be installed: ``pip install geopandas``). The CRS is reprojected to WGS 84 (EPSG:4326) if necessary. Returns ------- ee.Geometry GEE geometry suitable for ``filterBounds`` and ``getDownloadURL``. Raises ------ ImportError If a file path is provided but ``geopandas`` is not installed. 
FileNotFoundError If the specified file does not exist. ValueError If the dissolved geometry is empty (e.g. empty shapefile). """ from pathlib import Path as _Path # ── File path branch ──────────────────────────────────────────────────── if isinstance(aoi, (str, _Path)): path = _Path(aoi) if not path.exists(): raise FileNotFoundError(f"AOI file not found: {path}") try: import geopandas as gpd except ImportError: raise ImportError( "Reading a shapefile or GeoJSON file requires geopandas.\n" "Install it with: pip install geopandas" ) gdf = gpd.read_file(path) # Reproject to WGS 84 if needed if gdf.crs is not None and gdf.crs.to_epsg() != 4326: gdf = gdf.to_crs(epsg=4326) # Dissolve all features into a single geometry (union) dissolved = gdf.dissolve() if dissolved.empty or dissolved.geometry.iloc[0] is None: raise ValueError(f"AOI file is empty or has no geometry: {path}") geom = cast(Any, dissolved.geometry.iloc[0]) geom_json = geom.__geo_interface__ return ee.Geometry(geom_json) # ── GeoJSON dict branch ───────────────────────────────────────────────── geom_type = aoi.get("type", "") if geom_type == "FeatureCollection": return ee.FeatureCollection(aoi).geometry() elif geom_type == "Feature": return ee.Geometry(aoi["geometry"]) else: return ee.Geometry(aoi) def _resolve_sensor(sensor: str) -> str: """Resolve sensor aliases and validate the sensor name.""" sensor = _SENSOR_ALIASES.get(sensor, sensor) if sensor not in _ALL_VALID_SENSORS: raise ValueError( f"Unknown sensor {sensor!r}. " f"Valid options: {sorted(_ALL_VALID_SENSORS)}" ) return sensor def _build_single_sensor_collection( sensor: str, ee_geom: EEGeometry, start: str, end: str, max_cloud_cover: float, use_slc_off: bool, ) -> tuple[EEImageCollection, dict[str, str]]: """Build a cloud-masked, scaled collection for a single sensor. Returns ------- collection : ee.ImageCollection Pre-processed collection (cloud-masked, reflectance-scaled) with MNDWI / NDVI / NDTI index bands added. 
bands : dict Band-name mapping used (needed only for documentation; _add_indices has already been applied). """ band_family = _SENSOR_BAND_FAMILY[sensor] bm = _BAND_MAP[band_family] sf = _SCALE_FACTOR[band_family] cloud_prop = _CLOUD_COVER_PROP[band_family] # For Landsat 7 SLC-off filter effective_end = end if sensor == "Landsat7" and not use_slc_off: effective_end = min(end[:10], _L7_SLC_FAILURE_DATE) if start[:10] >= effective_end: raise RuntimeError( "Landsat 7 SLC-on data (before 2003-06-01) is not available " f"in the requested date range [{start}, {end}]. " "Set use_slc_off=True to include post-failure data." ) if end[:10] > _L7_SLC_FAILURE_DATE: warnings.warn( f"Landsat 7 SLC-off images after {_L7_SLC_FAILURE_DATE} " "are excluded (use_slc_off=False). " "Only the 1999-2003 good-quality record will be used.", UserWarning, stacklevel=4, ) collection = ee.ImageCollection(_COLLECTION_ID[sensor]) collection = collection.filterBounds(ee_geom).filterDate(start, effective_end) if sensor not in ("MODIS_Terra", "MODIS_Aqua"): collection = collection.filter(ee.Filter.lt(cloud_prop, max_cloud_cover)) if sensor == "Sentinel2": collection = collection.map(_mask_sentinel2_clouds) collection = collection.map( lambda img: ( img.multiply(sf["scale"]).copyProperties(img, ["system:time_start"]) ) ) elif sensor in ("MODIS_Terra", "MODIS_Aqua"): collection = collection.map(_mask_modis_clouds) collection = collection.map( lambda img: ( img.multiply(sf["scale"]).copyProperties(img, ["system:time_start"]) ) ) else: collection = collection.map(_mask_landsat_clouds) collection = collection.map( lambda img: ( img.multiply(sf["scale"]) .add(sf["offset"]) .copyProperties(img, ["system:time_start"]) ) ) collection = collection.map(lambda img: _add_indices(img, bm)) return collection, bm def _build_processed_collection( aoi: "dict | str", start: str, end: str, sensor: str, index: "str | Sequence[str]", custom_indices: dict[str, str] | None, max_cloud_cover: float, temporal_aggregation: 
str, use_slc_off: bool, climate_adaptive: bool, min_precip_mm: float, min_temp_c: float, hydroperiod_months: int, wetness_index: str, wetness_threshold: float, dem_mask: bool, max_slope_deg: float | None, max_tpi_m: float | None, tpi_window_px: int, max_local_range_m: float | None, local_range_window_px: int, max_elevation_m: float | None, months: list[int] | None = None, reduction_method: str = "median", percentile: float = 50.0, ) -> tuple[EEImageCollection, EEGeometry, list[str]]: """Build and process collection for fetch/fetch_xee with shared behavior.""" sensor = _resolve_sensor(sensor) reduction_method = _normalize_reduction_method(reduction_method) if reduction_method == "percentile": _validate_percentile(percentile) if custom_indices is None: custom_indices = {} if not isinstance(custom_indices, dict): raise TypeError("custom_indices must be a dict[str, str] or None.") for name, formula in custom_indices.items(): if not isinstance(name, str) or not name.strip(): raise ValueError("Each custom index name must be a non-empty string.") if not isinstance(formula, str) or not formula.strip(): raise ValueError( f"Formula for custom index {name!r} must be a non-empty string." ) reserved = set(custom_indices).intersection(_VALID_INDICES) if reserved: raise ValueError( "custom_indices cannot redefine built-in indices. " f"Reserved names: {sorted(reserved)}" ) indices_list = [index] if isinstance(index, str) else list(index) valid_indices = _VALID_INDICES.union(set(custom_indices)) bad = set(indices_list) - valid_indices if bad: raise ValueError( f"Unknown index/indices: {bad}. " f"Valid built-ins: {_VALID_INDICES}; custom: {set(custom_indices)}" ) if climate_adaptive and wetness_index not in indices_list: raise ValueError( "wetness_index must be included in index when climate_adaptive=True. " f"Got wetness_index={wetness_index!r}, index={indices_list!r}." 
) ee_geom = _parse_aoi(aoi) sensor_bands: dict[str, str] # Build collection (merged or single sensor) if sensor == "LandsatAll": collection = _build_landsat_all(ee_geom, start, end, max_cloud_cover, use_slc_off) sensor_bands = _BAND_MAP["_harmonised"] elif sensor == "MODISAll": collection = _build_modis_all(ee_geom, start, end, max_cloud_cover) sensor_bands = _BAND_MAP["MODIS_500m"] elif sensor in ("MODIS_Terra", "MODIS_Aqua"): bm = _BAND_MAP["MODIS_500m"] sf = _SCALE_FACTOR["MODIS_500m"] raw = ( ee.ImageCollection(_COLLECTION_ID[sensor]) .filterBounds(ee_geom) .filterDate(start, end) ) raw = raw.map(_mask_modis_clouds) raw = raw.map( lambda img: ( img.multiply(sf["scale"]) .copyProperties(img, ["system:time_start"]) ) ) collection = raw.map(lambda img: _add_indices(img, bm)) sensor_bands = bm else: collection, sensor_bands = _build_single_sensor_collection( sensor, ee_geom, start, end, max_cloud_cover, use_slc_off ) if custom_indices: collection = collection.map( lambda img: _add_custom_indices(img, sensor_bands, custom_indices) ) # Server-side DEM terrain mask if dem_mask: terrain_mask = _build_dem_mask( ee_geom, max_slope_deg=max_slope_deg, max_tpi_m=max_tpi_m, tpi_window_px=tpi_window_px, max_local_range_m=max_local_range_m, local_range_window_px=local_range_window_px, max_elevation_m=max_elevation_m, ) collection = collection.map(lambda img: img.updateMask(terrain_mask).float()) # Keep only requested index bands collection = collection.select(indices_list) # Filter by months if requested (before compositing) if months is not None: if isinstance(months, int): months_list = [months] else: months_list = months if not months_list: raise ValueError("months must contain at least one month number when provided.") if any((not isinstance(month, int)) or month < 1 or month > 12 for month in months_list): raise ValueError(f"months must be integers from 1 to 12, got {months_list!r}.") filtered = ee.ImageCollection([]) for m in months_list: filtered = 
filtered.merge(collection.filter(ee.Filter.calendarRange(m, m, "month"))) collection = filtered # Server-side temporal compositing if climate_adaptive: monthly = _build_composites( collection, "monthly", start, end, indices_list, reduction_method, percentile ) collection = _build_climate_adaptive_composites( monthly, start=start, end=end, index_bands=indices_list, wetness_index=wetness_index, wetness_threshold=wetness_threshold, min_precip_mm=min_precip_mm, min_temp_c=min_temp_c, hydroperiod_months=hydroperiod_months, ) else: collection = _build_composites( collection, temporal_aggregation, start, end, indices_list, reduction_method, percentile ) return collection, ee_geom, indices_list def _ee_image_to_dataarray( image: EEImage, ee_geom: EEGeometry, scale: int, ) -> "xr.DataArray": """Download a single-band GEE image to a numpy-backed xr.DataArray. Uses GEE's ``getDownloadURL`` + rasterio. Returns a DataArray with dims ``(y, x)`` and CRS written if rioxarray is available. Raises ------ RuntimeError Re-raises any download or rasterio error with an informative message so callers can catch and fill with NaN. 
""" import os import tempfile import urllib.request import rasterio import xarray as xr try: import rioxarray # noqa: F401 _rio = True except ImportError: _rio = False try: url = image.getDownloadURL( { "scale": scale, "region": ee_geom, "format": "GEO_TIFF", "crs": "EPSG:4326", } ) except Exception as exc: raise RuntimeError(f"GEE getDownloadURL failed: {exc}") from exc with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp: tmp_path = tmp.name try: urllib.request.urlretrieve(url, tmp_path) with rasterio.open(tmp_path) as src: data = src.read(1).astype(float) if src.nodata is not None: data[data == src.nodata] = np.nan transform = src.transform crs = src.crs ny, nx = data.shape xs = [transform.c + transform.a * (j + 0.5) for j in range(nx)] ys = [transform.f + transform.e * (i + 0.5) for i in range(ny)] da = xr.DataArray(data, dims=["y", "x"], coords={"y": ys, "x": xs}) if _rio and crs is not None: da = da.rio.write_crs(crs.to_epsg() or str(crs)) return da except Exception as exc: raise RuntimeError(f"Download/read failed: {exc}") from exc finally: try: os.unlink(tmp_path) except OSError: pass def _build_dem_mask( ee_geom: EEGeometry, max_slope_deg: float | None = 5.0, max_tpi_m: float | None = None, tpi_window_px: int = 5, max_local_range_m: float | None = None, local_range_window_px: int = 5, max_elevation_m: float | None = None, ) -> EEImage: """Build a server-side terrain flatness mask using Copernicus GLO-30 DEM. Returns a binary mask image (1 = valid flat terrain, 0 = steep/high). Used by :func:`fetch` when ``dem_mask=True`` to suppress glacier and snowpack artefacts server-side before download. Parameters ---------- ee_geom : ee.Geometry Area of interest for DEM loading. max_slope_deg : float or None Maximum slope in degrees. Pixels steeper than this are masked. Default 5.0. max_tpi_m : float or None Maximum absolute TPI (metres). Uses a focal mean kernel. Default None (disabled). tpi_window_px : int Kernel radius in pixels for TPI focal mean. 
Default 5. max_local_range_m : float or None Maximum local elevation range (metres). Default None (disabled). local_range_window_px : int Kernel radius for local range. Default 5. max_elevation_m : float or None Absolute elevation ceiling (metres). Pixels above this are always masked. Default None (disabled). Returns ------- ee.Image Single-band mask: 1 = valid terrain, 0 = artefact. """ dem = ( ee.ImageCollection("COPERNICUS/DEM/GLO30") .filterBounds(ee_geom) .select("DEM") .mean() ) mask = ee.Image.constant(1) # ── Absolute elevation ceiling ────────────────────────────────────────── if max_elevation_m is not None: mask = mask.And(dem.lte(max_elevation_m)) # ── Slope ─────────────────────────────────────────────────────────────── if max_slope_deg is not None: slope = ee.Terrain.slope(dem) mask = mask.And(slope.lte(max_slope_deg)) # ── TPI ───────────────────────────────────────────────────────────────── if max_tpi_m is not None: kernel = ee.Kernel.square(tpi_window_px, "pixels") focal_mean = dem.reduceNeighborhood( reducer=ee.Reducer.mean(), kernel=kernel ) tpi = dem.subtract(focal_mean).abs() mask = mask.And(tpi.lte(max_tpi_m)) # ── Local elevation range ──────────────────────────────────────────────── if max_local_range_m is not None: kernel = ee.Kernel.square(local_range_window_px, "pixels") local_max = dem.reduceNeighborhood(ee.Reducer.max(), kernel) local_min = dem.reduceNeighborhood(ee.Reducer.min(), kernel) local_rng = local_max.subtract(local_min) mask = mask.And(local_rng.lte(max_local_range_m)) return mask.rename("terrain_mask") # --------------------------------------------------------------------------- # Public API: fetch() # ---------------------------------------------------------------------------
def fetch(
    aoi: "dict | str",
    start: str,
    end: str,
    sensor: str = "Landsat8",
    index: "str | Sequence[str]" = "MNDWI",
    custom_indices: dict[str, str] | None = None,
    scale: int = 30,
    max_cloud_cover: float = 20.0,
    temporal_aggregation: str = "all",
    use_slc_off: bool = False,
    project: str | None = None,
    climate_adaptive: bool = False,
    min_precip_mm: float = 20.0,
    min_temp_c: float = 5.0,
    hydroperiod_months: int = 1,
    wetness_index: str = "MNDWI",
    wetness_threshold: float = 0.0,
    dem_mask: bool = False,
    max_slope_deg: float | None = 5.0,
    max_tpi_m: float | None = None,
    tpi_window_px: int = 5,
    max_local_range_m: float | None = None,
    local_range_window_px: int = 5,
    max_elevation_m: float | None = None,
    months: list[int] | None = None,
    reduction_method: str = "median",
    percentile: float = 50.0,
) -> "xr.DataArray | xr.Dataset":
    """Retrieve spectral indices from GEE as an xarray object (immediate download).

    Fetches an image collection, applies cloud masking and surface-reflectance
    scaling, computes the requested indices server-side, optionally composites
    by time period, and downloads the result one time step (and one index band)
    at a time.

    Parameters
    ----------
    aoi : dict, str, or Path
        Area of interest.  Accepts:

        - **GeoJSON dict** — plain geometry, Feature, or FeatureCollection.
        - **Shapefile path** — ``str`` / ``Path`` to a ``.shp`` (or any format
          readable by ``geopandas``).  Multiple features are dissolved into one
          boundary.  Requires ``pip install geopandas``.
        - **GeoJSON file path** — a ``.geojson`` / ``.json`` file.
    start : str
        Start date ISO 8601, e.g. ``"2000-01-01"``.
    end : str
        End date ISO 8601, e.g. ``"2023-12-31"``.  NOTE(review): for the
        single-sensor MODIS path this value is handed straight to
        ``filterDate``, whose end bound GEE documents as exclusive — confirm
        whether the other sensor builders make it inclusive.
    sensor : str
        Satellite sensor.  One of: ``"Landsat4"``, ``"Landsat5"``,
        ``"Landsat7"``, ``"Landsat8"`` (default), ``"Landsat9"``,
        ``"LandsatAll"``, ``"Sentinel2"``, ``"MODIS_Terra"``, ``"MODIS_Aqua"``,
        ``"MODISAll"``.  ``"Landsat"`` is an alias for ``"Landsat8"``.
    index : str or list of str
        One or more of ``"MNDWI"``, ``"NDWI"``, ``"NDVI"``, ``"NDTI"``,
        ``"AWEIsh"``, ``"AWEInsh"``.
        Single str → DataArray; list → Dataset (both with time dim).
    custom_indices : dict[str, str], optional
        User-defined index formulas evaluated server-side via
        ``ee.Image.expression``.  Dictionary keys are output band names and
        values are expression strings.  Available formula symbols are
        ``blue``, ``green``, ``red``, ``nir``, ``swir`` (alias ``swir1``),
        ``swir2``, and ``qa``.
        Example: ``{"NDSI": "(green - swir) / (green + swir)"}``.
        To request a custom index, include its name in ``index``.
    scale : int
        Spatial resolution in metres.  Default 30 (Landsat native).
    max_cloud_cover : float
        Maximum cloud cover (%) per image.  Default 20.
    temporal_aggregation : {"all", "annual", "monthly", "seasonal"}
        Server-side temporal compositing.  Default ``"all"`` (every scene).
        Using ``"annual"`` or ``"monthly"`` greatly reduces download volume.
    use_slc_off : bool
        Include Landsat 7 SLC-off images (acquired after 2003-05-31)?
        Default ``False``.  Only relevant when ``sensor`` is ``"Landsat7"``
        or ``"LandsatAll"``.
    project : str, optional
        GEE cloud project ID.  Only used when Earth Engine still has to be
        initialised.
    climate_adaptive : bool
        If ``True``, replace the standard temporal composite with a
        climate-adaptive annual composite guided by ERA5-Land precipitation
        and temperature.  When enabled, ``temporal_aggregation`` is ignored
        (output is always one image per year).  Default ``False``.

        When ``climate_adaptive=True``, the algorithm:

        1. Builds monthly composites internally.
        2. Joins with ERA5-Land monthly precipitation and 2m temperature.
        3. Filters months where precip >= ``min_precip_mm`` AND
           temp >= ``min_temp_c`` (excludes dry season and snow months).
        4. For each year, selects per-pixel values from the month with peak
           precipitation using ``qualityMosaic`` (captures maximum wetness
           rather than an arbitrary median).
        5. Masks pixels wet for fewer than ``hydroperiod_months`` months
           per year on average (removes transient waterlogging).
    min_precip_mm : float
        Minimum monthly precipitation (mm) to include a month in the
        composite window.  Used only when ``climate_adaptive=True``.
        Default 20 mm.
    min_temp_c : float
        Minimum monthly mean 2m temperature (degrees C) to include a month.
        Excludes frozen-ground and snow months.  Used only when
        ``climate_adaptive=True``.  Default 5 degrees C.
    hydroperiod_months : int
        Minimum number of months per year a pixel must be wet (index above
        ``wetness_threshold``) on average across the full record to be
        retained.  Pixels below this are masked as transient waterlogging.
        Used only when ``climate_adaptive=True``.  Default 1.
        Increase to 2-3 for stricter wetland delineation.
    wetness_index : str
        Which index band to use as the wetness indicator for hydroperiod
        counting.  (The per-year ``qualityMosaic`` itself is ranked on
        ERA5-Land precipitation, not on this band.)  Must be one of the
        bands in ``index``.  Default ``"MNDWI"``.
    wetness_threshold : float
        Index value above which a pixel is counted as wet for the
        hydroperiod calculation.  Default 0.0.
    dem_mask : bool
        If ``True``, apply a server-side terrain flatness mask using the
        Copernicus GLO-30 DEM before compositing.  Masks out glaciers,
        snowpacks, and steep mountain terrain that produce false wetness
        signals.  Default ``False``.
    max_slope_deg : float or None
        Maximum terrain slope (degrees) to retain when ``dem_mask=True``.
        Default 5.0.  Set to ``None`` to disable slope filtering.
    max_tpi_m : float or None
        Maximum absolute TPI (metres) when ``dem_mask=True``.
        Default ``None`` (disabled).
    tpi_window_px : int
        Focal window radius in pixels for TPI.  Default 5.
    max_local_range_m : float or None
        Maximum local elevation range (metres) in the rolling window when
        ``dem_mask=True``.  Default ``None`` (disabled).
    local_range_window_px : int
        Window radius for local elevation range.  Default 5.
    max_elevation_m : float or None
        Absolute elevation ceiling (metres).  Pixels above this elevation
        are always masked when ``dem_mask=True``.  Default ``None``.
    months : list of int, optional
        Restrict compositing to these calendar months (1=Jan, ..., 12=Dec).
        Example: [6, 7, 8] for June–August only.  Default None (all months).
    reduction_method : {"median", "mean", "percentile"}
        Collection reducer applied within each temporal aggregation window.
        Default ``"median"``.
    percentile : float
        Percentile to use when ``reduction_method="percentile"``.
        Must be between 0 and 100 inclusive.  Default 50.0.

    Returns
    -------
    xr.DataArray
        DataArray with dims ``(time, y, x)`` when a single index is requested.
    xr.Dataset
        Dataset with one variable per index, dims ``(time, y, x)``.

    Raises
    ------
    ValueError
        If ``temporal_aggregation`` is not one of the supported modes.
    RuntimeError
        If GEE reports no images for the request, if the image count cannot
        be determined, or if every time step fails to download.

    Notes
    -----
    Time steps that cannot be downloaded (or that carry no
    ``system:time_start``) are skipped with a ``UserWarning`` rather than
    raising an error.  The returned object will have fewer time steps than
    requested periods in such cases.

    Examples
    --------
    Long-record annual MNDWI for dynamics using all available Landsat missions:

    >>> mndwi = fetch(aoi, "1984-01-01", "2023-12-31",
    ...               sensor="LandsatAll",
    ...               temporal_aggregation="annual")
    >>> dynamics = classify_dynamics(mndwi, nYear=3)

    Post-monsoon WCT composite from Landsat 5 era:

    >>> indices = fetch(aoi, "2005-10-01", "2005-12-31",
    ...                 sensor="Landsat5",
    ...                 index=["MNDWI", "NDVI", "NDTI"])

    Annual MNDWI from MODIS for regional-scale analysis:

    >>> mndwi_modis = fetch(aoi, "2000-01-01", "2023-12-31",
    ...                     sensor="MODISAll",
    ...                     temporal_aggregation="annual",
    ...                     scale=500)
    """
    _require_ee()

    if temporal_aggregation not in _VALID_AGGREGATIONS:
        raise ValueError(
            f"temporal_aggregation must be one of {_VALID_AGGREGATIONS}, "
            f"got {temporal_aggregation!r}."
        )

    # Cheap round-trip to see whether an EE session is already usable;
    # initialise only when that probe fails (deliberately best-effort).
    try:
        ee.Number(1).getInfo()
    except Exception:
        init(project=project)

    collection, ee_geom, indices_list = _build_processed_collection(
        aoi=aoi,
        start=start,
        end=end,
        sensor=sensor,
        index=index,
        custom_indices=custom_indices,
        max_cloud_cover=max_cloud_cover,
        temporal_aggregation=temporal_aggregation,
        use_slc_off=use_slc_off,
        climate_adaptive=climate_adaptive,
        min_precip_mm=min_precip_mm,
        min_temp_c=min_temp_c,
        hydroperiod_months=hydroperiod_months,
        wetness_index=wetness_index,
        wetness_threshold=wetness_threshold,
        dem_mask=dem_mask,
        max_slope_deg=max_slope_deg,
        max_tpi_m=max_tpi_m,
        tpi_window_px=tpi_window_px,
        max_local_range_m=max_local_range_m,
        local_range_window_px=local_range_window_px,
        max_elevation_m=max_elevation_m,
        months=months,
        reduction_method=reduction_method,
        percentile=percentile,
    )

    n_images_info = collection.size().getInfo()
    if n_images_info is None:
        raise RuntimeError("Could not determine the number of images returned by GEE.")
    n_images = int(n_images_info)
    if n_images == 0:
        raise RuntimeError(
            f"No images found for sensor={sensor!r}, [{start}, {end}], "
            f"max_cloud_cover={max_cloud_cover}%, "
            f"temporal_aggregation={temporal_aggregation!r}."
        )

    import xarray as xr

    image_list = collection.toList(n_images)
    result_ds_list: list["xr.Dataset"] = []
    skipped = 0

    for i in range(n_images):
        img = ee.Image(image_list.get(i))
        ts = img.get("system:time_start").getInfo()
        if ts is None:
            # Without a timestamp the slice cannot be placed on the time axis.
            warnings.warn(
                f"Skipping image {i}: it has no 'system:time_start' property.",
                UserWarning,
                stacklevel=2,
            )
            skipped += 1
            continue

        # Epoch milliseconds -> naive UTC datetime64.  Built directly with
        # numpy (datetime.utcfromtimestamp is deprecated since Python 3.12);
        # the cast keeps the microsecond resolution the old conversion gave.
        dt = np.datetime64(int(ts), "ms").astype("datetime64[us]")

        band_arrays: dict[str, "xr.DataArray"] = {}
        for idx in indices_list:
            try:
                band_arrays[idx] = _ee_image_to_dataarray(
                    img.select(idx), ee_geom, scale
                )
            except RuntimeError as exc:
                warnings.warn(
                    f"Skipping time step {dt} (index '{idx}'): {exc}",
                    UserWarning,
                    stacklevel=2,
                )
                # One failed band invalidates the whole time step.
                skipped += 1
                break
        else:
            # Reached only when every requested band downloaded cleanly.
            result_ds_list.append(
                xr.Dataset(
                    {k: v.expand_dims(time=[dt]) for k, v in band_arrays.items()}
                )
            )

    if not result_ds_list:
        raise RuntimeError(
            f"All {n_images} time steps failed to download. "
            "Check your AOI size, date range, and cloud cover threshold."
        )

    if skipped:
        warnings.warn(
            f"{skipped} of {n_images} time step(s) were skipped due to "
            "download errors or empty composites.",
            UserWarning,
            stacklevel=2,
        )

    combined = xr.concat(result_ds_list, dim="time")

    # A bare string request yields a named DataArray; a sequence yields the
    # full Dataset.
    if isinstance(index, str):
        result = combined[index]
        result.name = index
        return result
    return combined
# --------------------------------------------------------------------------- # Public API: fetch_xee() # ---------------------------------------------------------------------------
def fetch_xee(
    aoi: "dict | str",
    start: str,
    end: str,
    sensor: str = "Landsat8",
    index: "str | Sequence[str]" = "MNDWI",
    custom_indices: dict[str, str] | None = None,
    scale: int = 30,
    max_cloud_cover: float = 20.0,
    temporal_aggregation: str = "all",
    use_slc_off: bool = False,
    project: str | None = None,
    climate_adaptive: bool = False,
    min_precip_mm: float = 20.0,
    min_temp_c: float = 5.0,
    hydroperiod_months: int = 1,
    wetness_index: str = "MNDWI",
    wetness_threshold: float = 0.0,
    dem_mask: bool = False,
    max_slope_deg: float | None = 5.0,
    max_tpi_m: float | None = None,
    tpi_window_px: int = 5,
    max_local_range_m: float | None = None,
    local_range_window_px: int = 5,
    max_elevation_m: float | None = None,
    chunks: dict | None = None,
    months: list[int] | None = None,
    reduction_method: str = "median",
    percentile: float = 50.0,
) -> "xr.DataArray | xr.Dataset":
    """Retrieve spectral indices from GEE as a lazy Dask-backed xarray via xee.

    Parameters
    ----------
    aoi : dict, str, or Path
        GeoJSON dict, shapefile path, or GeoJSON file path.
    start, end : str
        ISO 8601 date strings.
    sensor : str
        Sensor key — see :func:`fetch` for the full list.
        Default ``"Landsat8"``.
    index : str or list of str
        One or more of ``"MNDWI"``, ``"NDWI"``, ``"NDVI"``, ``"NDTI"``,
        ``"AWEIsh"``, ``"AWEInsh"``.
        Single str → DataArray; list → Dataset.
    custom_indices : dict[str, str], optional
        User-defined index formulas evaluated server-side via
        ``ee.Image.expression``.  Dictionary keys are output band names and
        values are expression strings.  Available formula symbols are
        ``blue``, ``green``, ``red``, ``nir``, ``swir`` (alias ``swir1``),
        ``swir2``, and ``qa``.
        Example: ``{"NDSI": "(green - swir) / (green + swir)"}``.
        To request a custom index, include its name in ``index``.
    scale : int
        Pixel resolution in metres.  Default 30.
    max_cloud_cover : float
        Maximum per-image cloud cover (%).  Default 20.
    temporal_aggregation : {"all", "annual", "monthly", "seasonal"}
        Server-side compositing mode.  Default ``"all"``.

        .. warning::
           ``"all"`` passes the raw scene collection to xee with no
           compositing.  Older xee versions may return integer time indices
           (0, 1, 2 …) instead of real timestamps for this mode.  This
           function detects the issue, emits a ``UserWarning``, and patches
           the time coordinate automatically by querying
           ``system:time_start`` from GEE.  Use ``"annual"`` or
           ``"monthly"`` to avoid the extra round-trip.
    use_slc_off : bool
        Include Landsat 7 SLC-off scenes.  Default ``False``.
    project : str, optional
        GEE cloud project ID.  Only used when Earth Engine still has to be
        initialised.
    climate_adaptive, min_precip_mm, min_temp_c, hydroperiod_months,
    wetness_index, wetness_threshold, dem_mask, max_slope_deg, max_tpi_m,
    tpi_window_px, max_local_range_m, local_range_window_px, max_elevation_m
        Same semantics as :func:`fetch`.
    chunks : dict, optional
        Dask chunk sizes, e.g. ``{"time": 1, "lon": 512, "lat": 512}``.
        A ``"time"`` entry is applied after the dataset has been opened;
        when omitted, ``{"lon": 512, "lat": 512}`` is used.
    months : list of int, optional
        Restrict compositing to these calendar months (1=Jan, ..., 12=Dec).
        Example: [6, 7, 8] for June–August only.  Default None (all months).
    reduction_method : {"median", "mean", "percentile"}
        Collection reducer applied within each temporal aggregation window.
        Default ``"median"``.
    percentile : float
        Percentile to use when ``reduction_method="percentile"``.
        Must be between 0 and 100 inclusive.  Default 50.0.

    Returns
    -------
    xr.DataArray or xr.Dataset
        Lazy object with dims ``(time, lat, lon)``.

    Raises
    ------
    ImportError
        If ``xee`` or ``dask`` is not installed.
    ValueError
        If ``temporal_aggregation`` is not one of the supported modes.

    Notes
    -----
    After calling ``.compute()``, orient dimensions with::

        da = da.rename({"lat": "y", "lon": "x"}).transpose("time", "y", "x")

    Examples
    --------
    >>> mndwi_lazy = fetch_xee(
    ...     aoi, "1984-01-01", "2023-12-31",
    ...     sensor="LandsatAll", temporal_aggregation="annual",
    ... )
    >>> mndwi = (
    ...     mndwi_lazy
    ...     .rename({"lat": "y", "lon": "x"})
    ...     .transpose("time", "y", "x")
    ...     .compute()
    ... )
    """
    _require_ee()

    # Importing xee registers the "ee" xarray backend used below.  The
    # original ImportError is chained so the real cause stays visible.
    try:
        import xee  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "xee is required for fetch_xee(). "
            "Install: pip install 'wetlandmapper[gee]'"
        ) from exc
    try:
        import dask  # noqa: F401
    except ImportError as exc:
        raise ImportError("dask is required. Install: pip install dask") from exc

    if temporal_aggregation not in _VALID_AGGREGATIONS:
        raise ValueError(
            f"temporal_aggregation must be one of {_VALID_AGGREGATIONS}, "
            f"got {temporal_aggregation!r}."
        )

    # Cheap round-trip to see whether an EE session is already usable;
    # initialise only when that probe fails (deliberately best-effort).
    try:
        ee.Number(1).getInfo()
    except Exception:
        init(project=project)

    collection, ee_geom, indices_list = _build_processed_collection(
        aoi=aoi,
        start=start,
        end=end,
        sensor=sensor,
        index=index,
        custom_indices=custom_indices,
        max_cloud_cover=max_cloud_cover,
        temporal_aggregation=temporal_aggregation,
        use_slc_off=use_slc_off,
        climate_adaptive=climate_adaptive,
        min_precip_mm=min_precip_mm,
        min_temp_c=min_temp_c,
        hydroperiod_months=hydroperiod_months,
        wetness_index=wetness_index,
        wetness_threshold=wetness_threshold,
        dem_mask=dem_mask,
        max_slope_deg=max_slope_deg,
        max_tpi_m=max_tpi_m,
        tpi_window_px=tpi_window_px,
        max_local_range_m=max_local_range_m,
        local_range_window_px=local_range_window_px,
        max_elevation_m=max_elevation_m,
        months=months,
        reduction_method=reduction_method,
        percentile=percentile,
    )

    # xee requires a bounding box — arbitrary polygon → one-pixel bug
    bounds = cast(dict[str, Any], ee_geom.bounds().getInfo())
    ring = cast(list[list[float]], bounds["coordinates"][0])
    lons = [pt[0] for pt in ring]
    lats = [pt[1] for pt in ring]
    ee_bbox = ee.Geometry.BBox(min(lons), min(lats), max(lons), max(lats))

    import numpy as np
    import pandas as pd
    import xarray as xr

    projection = ee.Projection("EPSG:4326").atScale(scale)

    default_chunks = chunks or {"lon": 512, "lat": 512}
    # The time chunk is withheld from open_dataset and applied afterwards.
    open_chunks = {k: v for k, v in default_chunks.items() if k != "time"}
    post_time_chunk = default_chunks.get("time")

    ds_lazy = xr.open_dataset(
        cast(Any, collection),
        engine="ee",
        projection=projection,
        geometry=ee_bbox,
        chunks=open_chunks or None,
    )

    # Avoid xee's warning about splitting stored time chunks at open time.
    # If the user asked for time chunking, apply it lazily after dataset creation.
    if post_time_chunk is not None and "time" in ds_lazy.dims:
        ds_lazy = ds_lazy.chunk({"time": post_time_chunk})

    # ----------------------------------------------------------------
    # Integer time-coordinate fix
    #
    # xee sometimes returns integer indices (0, 1, 2 …) instead of
    # real timestamps — a known bug with temporal_aggregation="all" and
    # some older xee versions.
    #
    # Fix strategy:
    #   1. Detect the issue via dtype check (only when a time dimension
    #      exists at all, so a time-less dataset cannot raise KeyError).
    #   2. Fetch real timestamps from GEE via aggregate_array().
    #   3. Slice ds_lazy to the number of available timestamps (GEE may
    #      return fewer images than xee allocated slots for), then
    #      assign_coords.  Slicing first avoids a shape-mismatch error
    #      if len(ms_list) < ds_lazy.sizes["time"].
    # ----------------------------------------------------------------
    if "time" in ds_lazy.dims and np.issubdtype(
        ds_lazy["time"].dtype, np.integer
    ):
        warnings.warn(
            "fetch_xee: xee returned integer time indices instead of "
            "real timestamps (known xee limitation with "
            "temporal_aggregation='all'). Fetching real timestamps from "
            "GEE and patching the time coordinate automatically.\n"
            "Tip: use temporal_aggregation='annual' or 'monthly' to avoid "
            "this extra round-trip.",
            UserWarning,
            stacklevel=2,
        )
        try:
            ms_list = collection.aggregate_array("system:time_start").getInfo()
        except Exception as exc:
            # Best-effort: keep the integer axis rather than fail the call.
            warnings.warn(
                f"fetch_xee: Could not retrieve timestamps from GEE ({exc}). "
                "Time coordinate will remain as integer indices. "
                "Fix manually after .compute():\n"
                "    da['time'] = pd.to_datetime(ms_list, unit='ms')",
                UserWarning,
                stacklevel=2,
            )
            ms_list = None

        if ms_list is not None:
            # Number of valid timestamps may be less than xee's allocated
            # time slots — always take the minimum to avoid shape mismatch.
            n_available = min(ds_lazy.sizes.get("time", 0), len(ms_list))
            real_times = pd.to_datetime(ms_list[:n_available], unit="ms")
            # Slice first, then assign — this is the correct order.
            ds_lazy = ds_lazy.isel(time=slice(n_available)).assign_coords(
                time=real_times
            )

    # xee returns lat ascending (south → north); sort descending for
    # standard raster orientation.  sortby is lazy.
    if "lat" in ds_lazy.dims:
        ds_lazy = ds_lazy.sortby("lat", ascending=False)

    # A bare string request yields a named DataArray; a sequence yields a
    # Dataset restricted to the requested bands.
    if isinstance(index, str):
        da = ds_lazy[index]
        da.name = index
        return da
    return ds_lazy[indices_list]
# ---------------------------------------------------------------------------
# Dependency guard
# ---------------------------------------------------------------------------


def _require_ee() -> None:
    """Raise ``ImportError`` with install/auth hints when ``_HAS_EE`` is false."""
    if not _HAS_EE:
        raise ImportError(
            "earthengine-api is required for the GEE module.\n"
            "Install: pip install 'wetlandmapper[gee]'\n"
            "Auth:    earthengine authenticate"
        )


# ---------------------------------------------------------------------------
# Climate-adaptive annual compositing (ERA5-Land guided)
# ---------------------------------------------------------------------------


def _build_climate_adaptive_composites(
    collection: EEImageCollection,
    start: str,
    end: str,
    index_bands: list[str],
    wetness_index: str = "MNDWI",
    wetness_threshold: float = 0.0,
    min_precip_mm: float = 20.0,
    min_temp_c: float = 5.0,
    hydroperiod_months: int = 1,
) -> EEImageCollection:
    """Build climate-adaptive annual composites guided by ERA5-Land.

    This function addresses two limitations of a naive annual median:

    1. **Season selection**: Instead of compositing all months equally, it
       identifies for each year the month with peak precipitation that also
       meets minimum temperature and rainfall thresholds.  Using
       ``qualityMosaic`` on precipitation selects, per pixel, the index value
       from the wettest climatically-valid month.  This avoids selecting
       snow-covered or drought-period images as representative of annual
       wetness.

    2. **Hydroperiod filtering**: Transient waterlogging (e.g. flooded fields
       after a storm) produces a water signal for only one or two months per
       year.  A true wetland is inundated for a sustained period.  Pixels that
       are wet during fewer than ``hydroperiod_months`` months per year on
       average across the full record are masked out.

    Parameters
    ----------
    collection : ee.ImageCollection
        Pre-processed monthly composite collection with index bands
        (e.g. MNDWI, NDVI, NDTI) already computed server-side.  Should be a
        monthly composite (``temporal_aggregation="monthly"`` applied
        upstream).
    start, end : str
        ISO 8601 date strings for the full requested range.  Only the first
        ten characters are parsed to derive the first and last year.
    index_bands : list of str
        Band names to composite (e.g. ``["MNDWI"]``).
    wetness_index : str
        Which band to use as the wetness indicator for the hydroperiod
        count.  Default ``"MNDWI"``.  Can also be ``"AWEIsh"`` or
        ``"AWEInsh"`` if those bands were computed server-side.  (The
        ``qualityMosaic`` quality band is always ``precip_mm``, not this
        band.)
    wetness_threshold : float
        Index value above which a pixel is considered wet for the
        hydroperiod count.  Default 0.0 (standard MNDWI water threshold).
    min_precip_mm : float
        Minimum monthly total precipitation (mm) for a month to be included
        in the composite window.  Months drier than this are skipped
        (dry season filter).  Default 20 mm.
    min_temp_c : float
        Minimum monthly mean 2m air temperature (degrees C) for a month to
        be included.  Months colder than this are skipped (snow/ice filter).
        Default 5 degrees C.  ERA5-Land temperature is in Kelvin internally;
        this parameter is in Celsius for user convenience.
    hydroperiod_months : int
        Minimum number of months per year (on average across the full
        record) that a pixel must be wet to be retained as a wetland pixel.
        Pixels below this threshold are masked — they represent transient
        waterlogging rather than persistent wetland.  Default 1.
        Increase to 2 or 3 for stricter wetland delineation.

    Returns
    -------
    ee.ImageCollection
        One image per year, ``system:time_start`` set to July 1 of each year.
        Pixels failing the hydroperiod test are masked in all images.  Years
        with no joined monthly image are filled by ``_make_nan_image``.

    Notes
    -----
    ERA5-Land data
        - Collection: ``ECMWF/ERA5_LAND/MONTHLY_AGGR``
        - ``total_precipitation_sum``: monthly total precipitation in metres
          (multiplied by 1000 to convert to mm).
        - ``temperature_2m``: monthly mean 2m air temperature in Kelvin
          (273.15 subtracted to convert to degrees C).
        - Available from 1950-01-01 to near-present at 0.1 degree (~11 km)
          resolution; GEE resamples to the Landsat grid.

    Snow exclusion
        ERA5-Land ``total_precipitation_sum`` includes both rainfall and
        snowfall.  The ``min_temp_c`` filter effectively excludes months
        where precipitation falls primarily as snow, without requiring a
        separate snowfall band.

    Hydroperiod vs wet_percent_threshold
        The ``hydroperiod_months`` parameter operates at monthly resolution
        and is applied during data acquisition.  The ``thresholdWet``
        parameter in :func:`classify_dynamics` operates at annual-composite
        resolution and is applied during classification.  Both can be used
        together for a two-stage filter.

    Examples
    --------
    Dryland application with strict hydroperiod:

    >>> mndwi = fetch(
    ...     aoi, "1984-01-01", "2023-12-31",
    ...     sensor="LandsatAll",
    ...     climate_adaptive=True,
    ...     min_precip_mm=25.0,       # wet season only
    ...     min_temp_c=10.0,          # exclude cold months
    ...     hydroperiod_months=2,     # at least 2 wet months/year
    ... )

    Temperate wetland (less strict):

    >>> mndwi = fetch(
    ...     aoi, "2000-01-01", "2023-12-31",
    ...     sensor="LandsatAll",
    ...     climate_adaptive=True,
    ...     min_precip_mm=10.0,
    ...     min_temp_c=2.0,
    ...     hydroperiod_months=1,
    ... )
    """
    start_yr = datetime.date.fromisoformat(start[:10]).year
    end_yr = datetime.date.fromisoformat(end[:10]).year

    def _with_year_month(target, date):
        # "ym" is the join key.  ee.Date.format yields a real server-side
        # string ("2005_07"); wrapping an ee.Number in ee.String is only a
        # client-side cast and does not convert the number to text.
        return target.set(
            "year", date.get("year"),
            "month", date.get("month"),
            "ym", date.format("YYYY_MM"),
        )

    # ── ERA5-Land monthly climate data ──────────────────────────────────────
    # Convert units server-side:
    #   precipitation: m -> mm  (*1000)
    #   temperature:   K -> C   (-273.15)
    era5 = (
        ee.ImageCollection("ECMWF/ERA5_LAND/MONTHLY_AGGR")
        .filterDate(start, end)
        .select(
            ["total_precipitation_sum", "temperature_2m"],
            ["precip_m", "temp_k"],
        )
    )

    def _prepare_era5(img):
        precip_mm = img.select("precip_m").multiply(1000).rename("precip_mm")
        temp_c = img.select("temp_k").subtract(273.15).rename("temp_c")
        return _with_year_month(
            img.addBands(precip_mm).addBands(temp_c), img.date()
        )

    era5 = era5.map(_prepare_era5)

    # ── Add year/month labels to the monthly index composites ───────────────
    collection = collection.map(lambda img: _with_year_month(img, img.date()))

    # ── Inner join on year-month string ─────────────────────────────────────
    join_filter = ee.Filter.equals(leftField="ym", rightField="ym")
    joined = ee.Join.inner("landsat", "era5").apply(collection, era5, join_filter)

    def _merge_pair(feature):
        # One image per joined pair: index bands plus the two climate bands,
        # carrying over the properties of the index composite.
        ls = ee.Image(feature.get("landsat"))
        clm = ee.Image(feature.get("era5"))
        return (
            ee.Image.cat(ls, clm.select(["precip_mm", "temp_c"]))
            .copyProperties(ls, ls.propertyNames())
        )

    joined_col = ee.ImageCollection(joined.map(_merge_pair))

    # Climate validity is decided per pixel: precip_mm / temp_c are raster
    # bands, not image properties, so a property-level collection filter
    # cannot express it.  (ERA5 at ~11 km resamples to the Landsat grid.)
    def _apply_climate_mask(img):
        precip_ok = img.select("precip_mm").gte(min_precip_mm)
        temp_ok = img.select("temp_c").gte(min_temp_c)
        valid = precip_ok.And(temp_ok)
        return img.updateMask(valid).copyProperties(img, img.propertyNames())

    climate_valid = joined_col.map(_apply_climate_mask)

    # ── Hydroperiod mask ─────────────────────────────────────────────────────
    # For each year, count how many climate-valid months each pixel is wet.
    # Average across years.  Mask pixels below hydroperiod_months.
    years = ee.List.sequence(start_yr, end_yr)

    def _is_wet(img):
        # Pixels masked out (cloud / climate-invalid) count as "not wet".
        return (
            img.select(wetness_index)
            .gt(wetness_threshold)
            .rename("wet")
            .unmask(0)
        )

    def _wet_months_in_year(yr):
        yr_col = climate_valid.filter(ee.Filter.eq("year", yr))
        wet_months = yr_col.map(_is_wet).sum().rename("wet_months")
        # A year with no joined months reduces to a band-less image, which
        # cannot be renamed.  Substitute a fully masked band so that such a
        # year is ignored by the multi-year mean instead of breaking it.
        no_data = ee.Image.constant(0).rename("wet_months").selfMask()
        return ee.Image(
            ee.Algorithms.If(yr_col.size().gt(0), wet_months, no_data)
        )

    wet_per_year = ee.ImageCollection(years.map(_wet_months_in_year))
    mean_wet_mths = wet_per_year.mean()
    hydro_mask = mean_wet_mths.gte(hydroperiod_months)

    # ── Per-year best-month composite via qualityMosaic on precipitation ────
    # qualityMosaic picks, per pixel, the values from the image with the
    # highest value of the quality band (here: precip_mm).  This selects the
    # index values from the wettest climate-valid month of each year.
    def _annual_composite(yr):
        yr_col = climate_valid.filter(ee.Filter.eq("year", yr))
        mid_year_ms = ee.Date.fromYMD(yr, 7, 1).millis()

        composite = (
            yr_col.qualityMosaic("precip_mm")   # highest-precip month wins
            .select(index_bands)                # keep only requested bands
            .updateMask(hydro_mask)             # apply hydroperiod mask
            .float()                            # ensure consistent band type
            .set("system:time_start", mid_year_ms)
            .set("year", yr)
        )
        fallback = _make_nan_image(index_bands, mid_year_ms)
        return ee.Image(
            ee.Algorithms.If(yr_col.size().gt(0), composite, fallback)
        )

    images = years.map(_annual_composite)
    return ee.ImageCollection(images)