Source code for xoak.tree_adapters
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
import numpy as np
# ``TreeAdapter`` is the base class of every adapter defined in this module.
# It is imported from xarray; if that import fails, an empty placeholder class
# with the same name is defined instead so the subclasses below can still be
# declared.
try:
    from xarray.indexes.nd_point_index import TreeAdapter  # type: ignore
except ImportError:
    # Placeholder: provides no behavior of its own.
    class TreeAdapter: ...
if TYPE_CHECKING:
import pys2index
import sklearn.neighbors
[docs]
class S2PointTreeAdapter(TreeAdapter):
    """Adapter that lets :py:class:`~xarray.indexes.NDPointIndex` use a
    :py:class:`pys2index.S2PointIndex` as its underlying tree.
    """

    # Wrapped index object, created in ``__init__``.
    _s2point_index: pys2index.S2PointIndex

    def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
        """Build the wrapped ``S2PointIndex`` from ``points``.

        ``options`` is accepted but not used by this adapter.
        """
        # Imported lazily so that pys2index is only needed when this adapter
        # is actually instantiated.
        from pys2index import S2PointIndex

        index = S2PointIndex(points)
        self._s2point_index = index

    def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Query the wrapped index with ``points`` and return its result as is."""
        result = self._s2point_index.query(points)
        return result

    def equals(self, other: S2PointTreeAdapter) -> bool:
        """Return True when both wrapped indexes report equal cell ids."""
        own_ids = self._s2point_index.get_cell_ids()
        other_ids = other._s2point_index.get_cell_ids()
        return np.array_equal(own_ids, other_ids)
[docs]
class SklearnKDTreeAdapter(TreeAdapter):
    """Adapter that lets :py:class:`~xarray.indexes.NDPointIndex` use a
    :py:class:`sklearn.neighbors.KDTree` as its underlying tree.
    """

    # Wrapped tree object, created in ``__init__``.
    _kdtree: sklearn.neighbors.KDTree

    def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
        """Build the wrapped ``KDTree`` from ``points``.

        Every entry of ``options`` is forwarded to the ``KDTree`` constructor
        as a keyword argument.
        """
        # Imported lazily so that scikit-learn is only needed when this
        # adapter is actually instantiated.
        from sklearn.neighbors import KDTree

        tree = KDTree(points, **options)
        self._kdtree = tree

    def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Query the wrapped tree with ``points`` and return its result as is."""
        result = self._kdtree.query(points)
        return result

    def equals(self, other: SklearnKDTreeAdapter) -> bool:
        """Return True when both wrapped trees hold equal ``data`` arrays.

        Only the ``data`` attribute is compared; the options used to build the
        trees are not taken into account.
        """
        own_data = self._kdtree.data
        other_data = other._kdtree.data
        return np.array_equal(own_data, other_data)
[docs]
class SklearnBallTreeAdapter(TreeAdapter):
    """Adapter that lets :py:class:`~xarray.indexes.NDPointIndex` use a
    :py:class:`sklearn.neighbors.BallTree` as its underlying tree.
    """

    # Wrapped tree object, created in ``__init__``.
    _balltree: sklearn.neighbors.BallTree

    def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
        """Build the wrapped ``BallTree`` from ``points``.

        Every entry of ``options`` is forwarded to the ``BallTree``
        constructor as a keyword argument.
        """
        # Imported lazily so that scikit-learn is only needed when this
        # adapter is actually instantiated.
        from sklearn.neighbors import BallTree

        tree = BallTree(points, **options)
        self._balltree = tree

    def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Query the wrapped tree with ``points`` and return its result as is."""
        result = self._balltree.query(points)
        return result

    def equals(self, other: SklearnBallTreeAdapter) -> bool:
        """Return True when both wrapped trees hold equal ``data`` arrays.

        Only the ``data`` attribute is compared; the options used to build the
        trees (e.g. a metric passed through ``options``) are not taken into
        account.
        """
        own_data = self._balltree.data
        other_data = other._balltree.data
        return np.array_equal(own_data, other_data)
[docs]
class SklearnGeoBallTreeAdapter(TreeAdapter):
    """Adapter that lets :py:class:`~xarray.indexes.NDPointIndex` use a
    :py:class:`sklearn.neighbors.BallTree` built with the 'haversine' metric.

    It is intended for indexing a set of latitude / longitude points.

    When building the index, the coordinates must be given in the latitude,
    longitude order.

    Latitude and longitude values must be given in degrees for both index and
    query points; this adapter converts them to radians before handing them to
    the tree.
    """

    # Wrapped tree object, created in ``__init__``.
    _balltree: sklearn.neighbors.BallTree

    def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
        """Build the wrapped ``BallTree`` from ``points`` given in degrees.

        ``options`` is copied and forwarded to the ``BallTree`` constructor as
        keyword arguments, with ``metric`` always set to ``"haversine"`` (any
        ``metric`` entry supplied by the caller is overridden).
        """
        # Imported lazily so that scikit-learn is only needed when this
        # adapter is actually instantiated.
        from sklearn.neighbors import BallTree

        # Work on a copy so the caller's mapping is never modified.
        tree_kwargs = dict(options)
        tree_kwargs["metric"] = "haversine"

        points_rad = np.deg2rad(points)
        self._balltree = BallTree(points_rad, **tree_kwargs)

    def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Convert ``points`` from degrees to radians, query the wrapped tree
        and return its result as is.
        """
        # NOTE(review): the returned distances are presumably expressed in
        # radians (haversine metric) -- confirm against the scikit-learn docs.
        points_rad = np.deg2rad(points)
        return self._balltree.query(points_rad)

    def equals(self, other: SklearnGeoBallTreeAdapter) -> bool:
        """Return True when both wrapped trees hold equal ``data`` arrays."""
        own_data = self._balltree.data
        other_data = other._balltree.data
        return np.array_equal(own_data, other_data)