Source code for sklearn_xarray.target

"""``sklearn_xarray.target``"""

import numpy as np


[docs]class Target(object): """ A pointer to xarray coordinates or variables to be used as a target. Parameters ---------- coord : str, optional The coordinate or variable that holds the data of the target. If not specified, the target will be the entire DataArray/Dataset. transform_func : callable, optional A function that transforms the coordinate values to an sklearn-compatible type and shape. If not specified, the coordinate(s) will be used as-is. transformer : sklearn transformer, optional **Deprecated**, use ``transform_func=Transformer().fit_transform`` instead. lazy : bool, optinonal If true, the target coordinate is only transformed by the transformer when needed. The transformer can implement a ``get_transformed_shape`` method that returns the shape after the transformation of the provided coordinate without actually transforming the data. dim : str or sequence of str, optional When set, multi-dimensional coordinates will be reduced to this dimension/these dimensions. reduce_func : callable, optional A callable that reduces the coordinate(s) to the dimension(s) in ``dim``. If not specified, the values along dimensions not in ``dim`` will be reduced to the first element in each of these dimensions. """ def __init__( self, coord=None, transform_func=None, transformer=None, lazy=False, dim=None, reduce_func=None, ): self.transform_func = transform_func self.coord = coord self.lazy = lazy self.reduce_func = reduce_func self.dim = dim self.transformer = transformer if transformer is not None: import warnings warnings.simplefilter("always", DeprecationWarning) warnings.warn( "The transformer argument is deprecated and will be removed " "in a future version. Use " "transform_func=Transformer().fit_transform instead.", DeprecationWarning, ) warnings.simplefilter("ignore", DeprecationWarning) self.transform_func = self.transformer.fit_transform self.values = None def __getitem__(self, key): import copy self._check_assigned() new_obj = copy.copy(self) if self.lazy: new_obj.values = self.transform_func(self.values)[key] new_obj.lazy = False else: new_obj.values = self.values[key] return new_obj def __call__(self, X): return self.assign_to(X) def __str__(self): if self.values is None: if self.coord is None: return "Unassigned sklearn_xarray.Target without coordinate." else: return ( 'Unassigned sklearn_xarray.Target with coordinate "' + self.coord + '".' ) else: return "sklearn_xarray.Target with data:\n" + self.values.__str__() def __repr__(self): return self.__str__() def __array__(self, dtype=None): self._check_assigned() if not self.lazy or self.transform_func is None: return np.array(self.values, dtype=dtype) else: return np.array(self.transform_func(self.values), dtype=dtype) def _check_assigned(self): """ Check if this instance has been assigned data. """ if self.values is None and self.lazy: raise ValueError("This instance has not been assigned any data.") def _reduce(self, values): """ Reduce values to dimension(s). """ if self.dim is None: return values if isinstance(self.dim, str): dim = [self.dim] else: dim = self.dim if self.reduce_func is None: for d in values.dims: if d not in dim: values = values.isel(**{d: 0}) return values else: other_dims = [d for d in values.dims if d not in dim] return values.reduce(self.reduce_func, dim=other_dims) @property def shape(self): """ The shape of the transformed target. """ self._check_assigned() if ( self.lazy and self.transformer is not None and hasattr(self.transformer, "get_transformed_shape") ): return self.transformer.get_transformed_shape(self.values) else: return self.__array__().shape @property def ndim(self): """ The number of dimensions of the transformed target. """ self._check_assigned() if ( self.lazy and self.transformer is not None and hasattr(self.transformer, "get_transformed_shape") ): return len(self.transformer.get_transformed_shape(self.values)) else: return self.__array__().ndim
[docs] def assign_to(self, X): """ Assign this target to a DataArray or Dataset. Parameters ---------- X : xarray DataArray or Dataset The data whose coordinate is used as the target. Returns ------- self: The target itself. """ if self.coord is not None: self.values = self._reduce(X[self.coord]) else: self.values = self._reduce(X) if not self.lazy and self.transform_func is not None: self.values = self.transform_func(self.values) return self