Source code for sklearn_xarray.datasets

""" ``sklearn_xarray.datasets`` """

import numpy as np
import pandas as pd
import xarray as xr


[docs]def load_dummy_dataarray(): """ Load a DataArray for demonstration purposes. """ return xr.DataArray( np.random.random((100, 10)), coords={"sample": range(100), "feature": range(10)}, dims=("sample", "feature"), )
[docs]def load_dummy_dataset(): """ Load a Dataset for demonstration purposes. """ return xr.Dataset( {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, coords={"sample": range(100), "feature": range(10)}, )
[docs]def load_digits_dataarray(load_images=False, nan_probability=0): """ Load a the 'digits' dataset from sklearn as a DataArray. Parameters ---------- load_images : bool, optional If true, the DataArray will contain the two-dimensional images as data instead of the vectorized samples. nan_probability : float between 0 and 1 The probability with which a sample is injected with NaN values. For demonstration purposes only. """ from sklearn.datasets import load_digits if load_images: bunch = load_digits() data = bunch.images if nan_probability > 0: for i in range(data.shape[0]): if np.random.rand(1) < nan_probability: data[i, 0, 0] = np.nan return xr.DataArray( data, coords={ "sample": range(data.shape[0]), "row": range(data.shape[1]), "col": range(data.shape[2]), "digit": (["sample"], bunch.target), }, dims=("sample", "row", "col"), ) else: data, target = load_digits(return_X_y=True) if nan_probability > 0: for i in range(data.shape[0]): if np.random.rand(1) < nan_probability: data[i, 0] = np.nan return xr.DataArray( data, coords={ "sample": range(data.shape[0]), "feature": range(data.shape[1]), "digit": (["sample"], target), }, dims=("sample", "feature"), )
[docs]def load_wisdm_dataarray( url="http://www.cis.fordham.edu/wisdm/includes/" "datasets/latest/WISDM_ar_latest.tar.gz", file="WISDM_ar_v1.1/WISDM_ar_v1.1_raw.txt", folder="data/", tmp_file="widsm.tar.gz", ): """ Load the WISDM activity recognition dataset. Parameters ---------- url : str, optional The URL of the dataset. file : str, optional The file containing the data. folder : str, optional The folder where the data will be downloaded and extracted to. tmp_file : str, optional The name of the temporary .tar file in the folder. Returns ------- X: xarray DataArray The loaded dataset. """ import os import tarfile import six.moves.urllib.request as ul if not os.path.isfile(os.path.join(folder, file)): ul.urlretrieve(url, tmp_file) tar = tarfile.open(tmp_file) tar.extractall(folder) tar.close() os.remove(tmp_file) column_names = ["subject", "activity", "timestamp", "x", "y", "z"] df = pd.read_csv( os.path.join(folder, file), header=None, names=column_names, comment=";", ) time = pd.date_range(start=0, periods=df.shape[0], freq="50ms") coords = { "subject": ("sample", df.subject), "activity": ("sample", df.activity), "sample": time, "axis": ["x", "y", "z"], } X = xr.DataArray(df.iloc[:, 3:6], coords=coords, dims=("sample", "axis")) return X