Grass format example implementation

For commentary on this implementation, see Supporting additional formats.

# > imports
import enum
from functools import cached_property
from typing import Optional, Tuple, Dict

import numpy as np
import xarray as xr
from shapely.geometry import Polygon
from shapely.geometry.base import BaseGeometry

from emsarray.formats import Format, Specificity
from emsarray.masking import blur_mask
from emsarray.types import Pathish
from emsarray.utils import linearise_dimensions
# <


class GrassGridKind(enum.Enum):
    """
    The kinds of grid that variables in a Grass dataset can be defined on.
    """
    # Blade variables have their grid dimensions ordered ``(weft, warp)``.
    # This is the default grid kind used by :class:`Grass`.
    blade = 'blade'
    # Meadow variables have their grid dimensions ordered ``(warp, weft)``.
    meadow = 'meadow'


# Native index of a single Grass cell: ``(grid kind, warp, weft)``.
GrassIndex = Tuple[GrassGridKind, int, int]


class Grass(Format[GrassGridKind, GrassIndex]):
    """
    Example :class:`Format` implementation for the 'Grass' convention.

    A Grass dataset has two grid dimensions, ``warp`` and ``weft``,
    and two grid kinds defined over them:
    blade variables are laid out as ``(weft, warp)`` and
    meadow variables are laid out as ``(warp, weft)``.
    Linear indexes follow the row-major order of each layout.
    """

    #: All the grid kinds this dataset has
    grid_kinds = frozenset(GrassGridKind)

    #: Indicates the grid kind of cells
    default_grid_kind = GrassGridKind.blade

    @classmethod
    def check_dataset(cls, dataset: xr.Dataset) -> Optional[int]:
        """
        Check whether a dataset follows the Grass convention.

        Returns ``Specificity.HIGH`` when the 'Conventions' global attribute
        is exactly ``'Grass 1.0'``, otherwise ``None``.
        """
        # A dataset without a 'Conventions' attribute is simply not a Grass
        # dataset, so look it up with .get() rather than raising KeyError.
        if dataset.attrs.get('Conventions') == 'Grass 1.0':
            return Specificity.HIGH
        return None

    def ravel_index(self, index: GrassIndex) -> int:
        """Make a linear index from a native index"""
        kind, warp, weft = index
        # Meadow indexes are transposed relative to blade indexes:
        # meadows are (warp, weft) row-major, blades are (weft, warp).
        if kind is GrassGridKind.meadow:
            return warp * self.dataset.dims['weft'] + weft
        else:
            return weft * self.dataset.dims['warp'] + warp

    def unravel_index(
        self,
        index: int,
        grid_kind: Optional[GrassGridKind] = None,
    ) -> GrassIndex:
        """
        Make a native index from a linear index.

        This is the inverse of :meth:`ravel_index`.
        If ``grid_kind`` is not given, ``default_grid_kind`` is used.
        """
        if grid_kind is None:
            grid_kind = self.default_grid_kind
        if grid_kind is GrassGridKind.meadow:
            warp, weft = divmod(index, self.dataset.dims['weft'])
        else:
            weft, warp = divmod(index, self.dataset.dims['warp'])
        return (grid_kind, warp, weft)

    def get_grid_kind_and_size(
        self, data_array: xr.DataArray,
    ) -> Tuple[GrassGridKind, int]:
        """
        For the given DataArray from this Dataset,
        find out what kind of grid it is, and the linear size of that grid.

        The grid kind is determined from the last two dimensions
        of the data array.
        Both grid kinds have the same size: ``warp * weft``.

        Raises :exc:`ValueError` if the last two dimensions match neither
        the blade nor the meadow layout.
        """
        size = self.dataset.dims['warp'] * self.dataset.dims['weft']
        if data_array.dims[-2:] == ('warp', 'weft'):
            return GrassGridKind.meadow, size
        if data_array.dims[-2:] == ('weft', 'warp'):
            return GrassGridKind.blade, size
        raise ValueError(
            "DataArray does not appear to be either a blade or meadow grid")

    def make_linear(self, data_array: xr.DataArray) -> xr.DataArray:
        """
        Make the given DataArray linear in its grid dimensions.

        The grid dimensions are passed to ``linearise_dimensions()``
        in the order used by the grid kind of the data array.
        Raises :exc:`ValueError` if the data array is not on a known grid.
        """
        grid_kind, _size = self.get_grid_kind_and_size(data_array)
        if grid_kind is GrassGridKind.meadow:
            dimensions = ['warp', 'weft']
        else:
            dimensions = ['weft', 'warp']
        return linearise_dimensions(data_array, dimensions)

    def selector_for_index(self, index: GrassIndex) -> Dict[str, int]:
        """
        Make a selector for a particular index.
        This selector can be passed to Dataset.isel().

        The grid kind in the index does not affect the selector,
        as both grid kinds share the ``warp`` and ``weft`` dimensions.
        """
        _kind, warp, weft = index
        return {'warp': warp, 'weft': weft}

    @cached_property
    def polygons(self) -> np.ndarray:
        """
        An array with one polygon per blade cell, in blade linear index order.

        The array is computed on first access and cached on the instance.
        """
        def make_polygon_for_cell(warp: int, weft: int) -> Polygon:
            # Implementation left as an exercise for the reader
            return Polygon(...)

        # Iterate weft-major so that the polygon at position
        # ``weft * n_warp + warp`` agrees with ravel_index() for blades,
        # which is how make_clip_mask() interprets linear indexes.
        dims = self.dataset.dims
        return np.array([
            make_polygon_for_cell(warp, weft)
            for weft in range(dims['weft'])
            for warp in range(dims['warp'])
        ])

    def make_clip_mask(
        self,
        clip_geometry: BaseGeometry,
        buffer: int = 0,
    ) -> xr.Dataset:
        """
        Make a mask of the cells that intersect a clip geometry.

        ``clip_geometry`` is the region to keep.
        ``buffer`` is passed to ``blur_mask()`` as ``size`` when positive,
        to grow the masks around the clip region.

        Returns a dataset with two boolean variables:
        ``blades`` with dimensions ``(weft, warp)`` and
        ``meadows`` with dimensions ``(warp, weft)``.
        True marks a cell to keep.
        """
        # Find all the blades that intersect the clip geometry.
        # NOTE(review): assumes self.spatial_index.query() yields
        # (item, polygon) pairs, where each item has `linear_index` and
        # `index` attributes - confirm against the base class.
        intersecting_blades = [
            item
            for item, polygon in self.spatial_index.query(clip_geometry)
            if polygon.intersects(clip_geometry)
        ]
        # Get all the linear indexes of the intersecting blades.
        # The explicit integer dtype matters when nothing intersects:
        # an empty array defaults to float64, which numpy rejects as an index.
        blade_indexes = np.array(
            [item.linear_index for item in intersecting_blades], dtype=int)
        # Find all the meadows associated with each intersecting blade.
        # NOTE(review): get_meadows_for_blade() is not defined in this class;
        # presumably it must be supplied by the implementer - confirm.
        meadow_indexes = np.unique(np.array([
            self.ravel_index(meadow_index)
            for item in intersecting_blades
            for meadow_index in self.get_meadows_for_blade(item.index)
        ], dtype=int))

        warp = self.dataset.dims['warp']
        weft = self.dataset.dims['weft']

        # Make a 2d array of which blades to keep.
        # ravel() on a freshly allocated array is a view, so assigning
        # through it updates the 2d array in place.
        keep_blades = np.zeros((weft, warp), dtype=bool)
        keep_blades.ravel()[blade_indexes] = True

        # Same for meadows
        keep_meadows = np.zeros((warp, weft), dtype=bool)
        keep_meadows.ravel()[meadow_indexes] = True

        # Blur the masks a bit if the clip region needs buffering
        if buffer > 0:
            keep_blades = blur_mask(keep_blades, size=buffer)
            keep_meadows = blur_mask(keep_meadows, size=buffer)

        # Make a dataset out of these masks
        return xr.Dataset(
            data_vars={
                'blades': xr.DataArray(data=keep_blades, dims=['weft', 'warp']),
                'meadows': xr.DataArray(data=keep_meadows, dims=['warp', 'weft']),
            },
        )

    def apply_clip_mask(self, clip_mask: xr.Dataset, work_dir: Pathish) -> xr.Dataset:
        """
        Apply a clip mask made by :meth:`make_clip_mask` to the dataset.

        This is an unimplemented placeholder: it currently returns ``None``
        rather than a dataset.
        """
        # You're on your own, here.
        # This depends entirely on how the mask and datasets interact.
        pass