"""
plotting.py
# Copyright (c) 2026, Manudeo Singh #
# Author: Manudeo Singh, March 2026 #
-----------
Convenience functions for visualising WetlandMapper outputs.
All functions:
- Respect geographic coordinates when present (``x``/``y`` or ``lon``/``lat``),
so axes ticks show real coordinate values rather than pixel indices.
- Detect y-axis direction (ascending vs descending) and set ``origin``
accordingly so images are never vertically flipped.
- Place the class legend outside the plot area to avoid obscuring data.
- Return (fig, ax) so callers can further customise or save.
Coordinate conventions
----------------------
:func:`wetlandmapper.gee.fetch`
Returns arrays with dims ``(time, y, x)`` where ``y`` is **descending**
(north → south, standard raster convention).
:func:`wetlandmapper.gee.fetch_xee`
After the built-in ``sortby`` fix, also returns ``y`` descending.
If you have an older array with ``lat``/``lon`` dims, rename them first:
``da = da.rename({"lat": "y", "lon": "x"})``.
"""
from __future__ import annotations
import numpy as np
# Public API of this module: the names exported by ``from <module> import *``.
# Every helper whose name starts with an underscore is internal.
__all__ = [
    "plot_dynamics",
    "plot_wct",
    "plot_index",
    "plot_wet_frequency",
]
# ---------------------------------------------------------------------------
# Lazy matplotlib import
# ---------------------------------------------------------------------------
def _get_mpl():
try:
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
return plt, mcolors, mpatches
except ImportError:
raise ImportError(
"matplotlib is required for plotting. "
"Install with: pip install matplotlib"
)
# ---------------------------------------------------------------------------
# Coordinate helpers
# ---------------------------------------------------------------------------
def _spatial_coords(da):
"""Return (x_values, y_values) arrays from a DataArray, or (None, None)."""
# Support both (y, x) and (lat, lon) naming conventions
x_coord = None
y_coord = None
for xname in ("x", "lon", "longitude"):
if xname in da.coords:
x_coord = da.coords[xname].values
break
for yname in ("y", "lat", "latitude"):
if yname in da.coords:
y_coord = da.coords[yname].values
break
return x_coord, y_coord
def _imshow_extent(da):
    """Derive the ``extent`` argument for ``imshow`` from spatial coordinates.

    Returns
    -------
    tuple of float or None
        ``(left, right, bottom, top)``, or ``None`` when either the x or the
        y coordinate is missing, in which case ``imshow`` falls back to
        pixel-index axes.
    """
    xs, ys = _spatial_coords(da)
    if xs is None or ys is None:
        return None

    def _half_step(values):
        # Half the spacing between the first two samples; a single sample
        # gives no spacing to measure, so no padding is applied.
        if len(values) > 1:
            return float(np.abs(values[1] - values[0])) / 2
        return 0

    # Pad by half a pixel on every side so the extent spans the full pixel
    # footprint rather than only the pixel centres.
    pad_x = _half_step(xs)
    pad_y = _half_step(ys)
    return (
        float(xs.min()) - pad_x,
        float(xs.max()) + pad_x,
        float(ys.min()) - pad_y,
        float(ys.max()) + pad_y,
    )
def _imshow_origin(da):
    """Choose the ``origin`` argument for ``imshow`` from the y coordinate.

    Returns ``"upper"`` when the first y value is greater than the last
    (descending, north→south) and ``"lower"`` otherwise. When there is no y
    coordinate, or it has fewer than two values, ``"upper"`` is returned as
    the default.
    """
    _, ys = _spatial_coords(da)
    if ys is None or len(ys) <= 1:
        # Nothing to infer a direction from (e.g. plain numpy-like input)
        return "upper"
    if ys[0] > ys[-1]:
        return "upper"
    return "lower"
def _get_2d(da):
"""Squeeze or select to get a 2-D array for imshow."""
if "time" in da.dims:
da = da.isel(time=0)
# Drop any remaining length-1 dims
for dim in list(da.dims):
if dim not in ("y", "x", "lat", "lon", "latitude", "longitude"):
if da.sizes[dim] == 1:
da = da.isel({dim: 0})
return da
# ---------------------------------------------------------------------------
# Colormap / norm helpers
# ---------------------------------------------------------------------------
def _build_cmap_and_norm(class_codes, class_colors):
    """Create a discrete colormap and matching norm for integer class codes.

    Parameters
    ----------
    class_codes : dict
        Mapping whose keys are the class codes; only the keys are used here.
    class_colors : dict
        Mapping from class code to a Matplotlib colour.

    Returns
    -------
    cmap, norm, codes
        A ``ListedColormap`` with one colour per code (in ascending code
        order), a ``BoundaryNorm`` with edges half a unit either side of the
        codes, and the sorted list of codes.
    """
    _, mcolors, _ = _get_mpl()
    codes = sorted(class_codes.keys())
    palette = mcolors.ListedColormap([class_colors[code] for code in codes])
    # One edge below the smallest code, then one edge just above every code,
    # so each code falls inside its own bin.
    edges = [codes[0] - 0.5]
    edges.extend(code + 0.5 for code in codes)
    return palette, mcolors.BoundaryNorm(edges, palette.N), codes
# ---------------------------------------------------------------------------
# Legend helpers
# ---------------------------------------------------------------------------
def _add_outside_legend(fig, ax, patches, title, legend_loc):
"""Add a patch legend either inside (loc string)
or outside ('outside right' / 'outside bottom')."""
_, _, mpatches = _get_mpl()
if legend_loc == "outside right":
ax.legend(
handles=patches,
title=title,
fontsize=8,
title_fontsize=8.5,
framealpha=0.9,
loc="upper left",
bbox_to_anchor=(1.02, 1),
borderaxespad=0,
)
fig.tight_layout(rect=[0, 0, 0.80, 1])
elif legend_loc == "outside bottom":
n = len(patches)
ncol = min(n, 4)
ax.legend(
handles=patches,
title=title,
fontsize=8,
title_fontsize=8.5,
framealpha=0.9,
loc="upper center",
bbox_to_anchor=(0.5, -0.12),
ncol=ncol,
borderaxespad=0,
)
fig.tight_layout(rect=[0, 0.12, 1, 1])
else:
# Standard matplotlib loc string
ax.legend(
handles=patches,
title=title,
fontsize=8,
title_fontsize=8.5,
framealpha=0.9,
loc=legend_loc,
)
fig.tight_layout()
# ---------------------------------------------------------------------------
# Public plotting functions
# ---------------------------------------------------------------------------
[docs]
def plot_dynamics(
    dynamics,
    ax=None,
    title: str = "Wetland Dynamics",
    figsize: tuple = (8, 7),
    add_colorbar: bool = True,
    legend_loc: str = "outside right",
    savepath: str | None = None,
    dpi: int = 150,
):
    """Draw a wetland dynamics classification raster with a class legend.

    Parameters
    ----------
    dynamics : xr.DataArray
        Output of :func:`wetlandmapper.classify_dynamics`. When spatial
        coordinates (``x``/``y`` or ``lon``/``lat``) are present they define
        the axis extent, so ticks show coordinate values.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw into. A new figure and axes are created when
        omitted.
    title : str
        Title shown above the plot.
    figsize : tuple
        Figure size in inches; only used when a new figure is created.
    add_colorbar : bool
        When true, add a legend listing the dynamics classes.
    legend_loc : str
        Where to put the legend: ``"outside right"`` (default),
        ``"outside bottom"``, or any standard Matplotlib ``loc`` string
        such as ``"lower right"``.
    savepath : str, optional
        If given, the figure is written to this path.
    dpi : int
        Resolution used when saving. Default 150.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    from .dynamics import DYNAMICS_CLASSES, DYNAMICS_COLORS

    _, _, mpatches = _get_mpl()
    class_cmap, class_norm, _ = _build_cmap_and_norm(
        DYNAMICS_CLASSES, DYNAMICS_COLORS
    )
    fig, ax = _ensure_axes(ax, figsize)

    raster = _get_2d(dynamics)
    extent = _imshow_extent(raster)
    image_kwargs = {
        "cmap": class_cmap,
        "norm": class_norm,
        "origin": _imshow_origin(raster),
        "extent": extent,
        # Nearest-neighbour keeps class codes from being blended
        "interpolation": "nearest",
        "aspect": "auto" if extent is not None else "equal",
    }
    ax.imshow(raster.values, **image_kwargs)

    _add_xy_labels(ax, raster)
    ax.set_title(title, fontsize=12, fontweight="bold", pad=6)

    if not add_colorbar:
        fig.tight_layout()
    else:
        # Legend entries are listed from the highest class code downwards
        handles = []
        for code in sorted(DYNAMICS_CLASSES.keys(), reverse=True):
            handles.append(
                mpatches.Patch(
                    color=DYNAMICS_COLORS[code], label=DYNAMICS_CLASSES[code]
                )
            )
        _add_outside_legend(fig, ax, handles, "Dynamics Class", legend_loc)

    if savepath:
        fig.savefig(savepath, dpi=dpi, bbox_inches="tight")
    return fig, ax
[docs]
def plot_wct(
    wct,
    ax=None,
    title: str = "Wetland Cover Types",
    figsize: tuple = (8, 7),
    add_colorbar: bool = True,
    legend_loc: str = "outside right",
    savepath: str | None = None,
    dpi: int = 150,
):
    """Draw a Wetland Cover Type classification raster with a class legend.

    Parameters
    ----------
    wct : xr.DataArray
        Output of :func:`wetlandmapper.classify_wct` or
        :func:`wetlandmapper.classify_wct_ema`.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw into. A new figure and axes are created when
        omitted.
    title : str
        Title shown above the plot.
    figsize : tuple
        Figure size in inches; only used when a new figure is created.
    add_colorbar : bool
        When true, add a legend listing the cover-type classes.
    legend_loc : str
        ``"outside right"`` (default), ``"outside bottom"``, or any standard
        Matplotlib ``loc`` string.
    savepath : str, optional
        If given, the figure is written to this path.
    dpi : int
        Resolution used when saving. Default 150.

    Returns
    -------
    fig, ax
    """
    from .wct import WCT_CLASSES, WCT_COLORS

    _, _, mpatches = _get_mpl()
    class_cmap, class_norm, _ = _build_cmap_and_norm(WCT_CLASSES, WCT_COLORS)
    fig, ax = _ensure_axes(ax, figsize)

    raster = _get_2d(wct)
    extent = _imshow_extent(raster)
    image_kwargs = {
        "cmap": class_cmap,
        "norm": class_norm,
        "origin": _imshow_origin(raster),
        "extent": extent,
        # Nearest-neighbour keeps class codes from being blended
        "interpolation": "nearest",
        "aspect": "auto" if extent is not None else "equal",
    }
    ax.imshow(raster.values, **image_kwargs)

    _add_xy_labels(ax, raster)
    ax.set_title(title, fontsize=12, fontweight="bold", pad=6)

    if not add_colorbar:
        fig.tight_layout()
    else:
        # Ascending class codes, with code 0 moved to the end of the legend
        legend_order = [code for code in sorted(WCT_CLASSES.keys()) if code != 0]
        legend_order.append(0)
        handles = []
        for code in legend_order:
            handles.append(
                mpatches.Patch(color=WCT_COLORS[code], label=WCT_CLASSES[code])
            )
        _add_outside_legend(fig, ax, handles, "Cover Type", legend_loc)

    if savepath:
        fig.savefig(savepath, dpi=dpi, bbox_inches="tight")
    return fig, ax
[docs]
def plot_index(
    da,
    index_name: str = "Index",
    ax=None,
    figsize: tuple = (8, 7),
    vmin: float = -1.0,
    vmax: float = 1.0,
    cmap: str = "RdYlGn",
    time_step: int | None = None,
    savepath: str | None = None,
    dpi: int = 150,
):
    """Plot a single spectral index (MNDWI, NDVI, or NDTI).

    Parameters
    ----------
    da : xr.DataArray
        2-D or 3-D (time, y, x) index DataArray. Remaining non-spatial
        dimensions of length 1 are dropped before plotting.
    index_name : str
        Used in the title and colorbar label.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. Created if not provided.
    figsize : tuple
        Figure size in inches; only used when a new figure is created.
    vmin, vmax : float
        Lower and upper limits of the colour scale. Defaults -1.0 and 1.0.
    cmap : str
        Matplotlib colormap name. Default ``"RdYlGn"``.
    time_step : int, optional
        Select this time index (0-based) when ``da`` has a time dim.
        Defaults to the temporal mean if not provided.
    savepath : str, optional
        If given, save the figure to this path.
    dpi : int
        Resolution for saved figure. Default 150.

    Returns
    -------
    fig, ax
    """
    plt, _, _ = _get_mpl()

    # Reduce the data first so that an invalid ``time_step`` raises before
    # an empty figure has been created.
    if "time" in da.dims:
        if time_step is not None:
            da2d = da.isel(time=time_step)
            subtitle = f"t={time_step}"
        else:
            da2d = da.mean(dim="time")
            subtitle = "temporal mean"
        title_full = f"{index_name} ({subtitle})"
    else:
        da2d = da
        title_full = index_name
    # The time dim is already gone, so this only drops leftover length-1
    # non-spatial dims, matching what the other plot functions do.
    da2d = _get_2d(da2d)

    fig, ax = _ensure_axes(ax, figsize)
    extent = _imshow_extent(da2d)
    origin = _imshow_origin(da2d)
    im = ax.imshow(
        da2d.values,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        origin=origin,
        extent=extent,
        interpolation="bilinear",
        aspect="equal" if extent is None else "auto",
    )
    plt.colorbar(im, ax=ax, label=index_name, shrink=0.75, pad=0.02)
    _add_xy_labels(ax, da2d)
    ax.set_title(title_full, fontsize=12, fontweight="bold", pad=6)
    fig.tight_layout()
    if savepath:
        fig.savefig(savepath, dpi=dpi, bbox_inches="tight")
    return fig, ax
[docs]
def plot_wet_frequency(
    mndwi,
    ax=None,
    figsize: tuple = (8, 7),
    mndwi_threshold: float = 0.0,
    savepath: str | None = None,
    dpi: int = 150,
):
    """Plot wet frequency (%) computed from an MNDWI time series.

    Parameters
    ----------
    mndwi : xr.DataArray
        Multi-temporal MNDWI with a ``time`` dimension.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw into. A new figure and axes are created when
        omitted.
    figsize : tuple
        Figure size in inches; only used when a new figure is created.
    mndwi_threshold : float
        Pixels with MNDWI above this value are counted as wet. Forwarded to
        ``compute_wet_frequency``.
    savepath : str, optional
        If given, the figure is written to this path.
    dpi : int
        Resolution used when saving. Default 150.

    Returns
    -------
    fig, ax
    """
    from .dynamics import compute_wet_frequency

    plt, _, _ = _get_mpl()
    frequency = compute_wet_frequency(mndwi, mndwi_threshold=mndwi_threshold)
    fig, ax = _ensure_axes(ax, figsize)

    raster = _get_2d(frequency)
    extent = _imshow_extent(raster)
    image_kwargs = {
        "cmap": "Blues",
        # Colour scale is fixed to the 0-100 range
        "vmin": 0,
        "vmax": 100,
        "origin": _imshow_origin(raster),
        "extent": extent,
        "interpolation": "bilinear",
        "aspect": "auto" if extent is not None else "equal",
    }
    image = ax.imshow(raster.values, **image_kwargs)

    plt.colorbar(image, ax=ax, label="Wet Frequency (%)", shrink=0.75, pad=0.02)
    _add_xy_labels(ax, raster)
    ax.set_title("Wet Frequency (%)", fontsize=12, fontweight="bold", pad=6)
    fig.tight_layout()

    if savepath:
        fig.savefig(savepath, dpi=dpi, bbox_inches="tight")
    return fig, ax
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _ensure_axes(ax, figsize):
plt, _, _ = _get_mpl()
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure
return fig, ax
def _add_xy_labels(ax, da):
    """Label and style the axes for which *da* has a spatial coordinate.

    When an x coordinate exists the x axis is labelled ``"Longitude"`` and
    its tick labels are shrunk and rotated by 30 degrees; when a y
    coordinate exists the y axis is labelled ``"Latitude"`` and its tick
    labels are shrunk. The label text is fixed and does not depend on the
    coordinate name or on the coordinate values.
    """
    xs, ys = _spatial_coords(da)
    has_x = xs is not None
    has_y = ys is not None

    # Axis titles
    if has_x:
        ax.set_xlabel("Longitude", fontsize=9)
    if has_y:
        ax.set_ylabel("Latitude", fontsize=9)

    # Tick styling; the x rotation is applied whenever an x coordinate exists
    if has_x:
        ax.tick_params(axis="x", labelsize=8, rotation=30)
    if has_y:
        ax.tick_params(axis="y", labelsize=8)