Source code for pytorch_utils.datasets

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Literal, Optional

import numpy as np
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.utils.validation import check_is_fitted
from torch.utils.data import Dataset

from pytorch_utils.pandas.utils import DataFrameRow, ListDataFrameRows
from pytorch_utils.dataset_configurations import (
    AugmentedBernoulliDatasetConfig,
    DataAugmentationConfig,
)


class MLStage(Enum):
    fit = auto()
    validate = auto()
    test = auto()
    predict = auto()
@dataclass(frozen=True)
class AugmentedBernoulliDataset(Dataset):
    """
    Implements storage-efficient data augmentation of Bernoulli samples
    (binary outcomes: successful or not) as well as data transformations
    (e.g., scaling, encoding, ...).

    If `is_success` is set to None, only features are generated (labels are
    dropped). This is useful for prediction sets.
    """

    data: ListDataFrameRows
    # If None, only features are generated; if True, labels are set to 1;
    # if False, labels are set to 0.
    is_success: Optional[bool]
    augmented_col: str
    fitted_preprocessing_pipeline: Optional[Pipeline] = None
    data_augmentation_scaling_factors: np.ndarray = np.array([1.0])
    label_col: str = "success_labels"  # name of the label column
    labels_dtype: np.dtype = np.dtype("int32")
    sample_weight_col: Optional[str] = None
    min_augmented_value: float = -float("inf")
    max_augmented_value: float = float("inf")

    def __post_init__(self) -> None:
        # Sort the data augmentation scaling factors.
        # This ensures that two dataclasses that only differ by the ordering of
        # `data_augmentation_scaling_factors` are considered equal.
        object.__setattr__(
            self,
            "data_augmentation_scaling_factors",
            sorted(self.data_augmentation_scaling_factors),
        )

        # Check attributes
        if self.augmented_col not in self.data.columns:
            raise ValueError("`augmented_col` should be a column of dataframe `data`")
        if self.fitted_preprocessing_pipeline:
            check_is_fitted(
                self.fitted_preprocessing_pipeline,
                msg="fitted_preprocessing_pipeline is not fitted",
            )
        if len(self.data_augmentation_scaling_factors) == 0:
            raise ValueError(
                "`data_augmentation_scaling_factors` should not be an empty array"
            )

        # Compute augmented_df_indices and augmented_df_lengths_cumsum based on
        # data_augmentation_scaling_factors.
        object.__setattr__(
            self, "augmented_df_indices", self._build_augmented_df_indices()
        )
        object.__setattr__(
            self,
            "augmented_df_lengths_cumsum",
            self._build_augmented_df_lengths_cumsum(),
        )
        object.__setattr__(
            self,
            "indices_rows_mapping",
            {row.index: i for i, row in enumerate(self.data)},
        )

        # Set pipeline output to pandas
        if self.fitted_preprocessing_pipeline:
            self.fitted_preprocessing_pipeline.set_output(transform="pandas")
    @classmethod
    def from_config(
        cls,
        config: AugmentedBernoulliDatasetConfig,
        ml_stage: Literal[MLStage.fit, MLStage.validate, MLStage.test],
        fitted_preprocessing_pipeline: Optional[Pipeline] = None,
        label_col: str = "success_labels",
        labels_dtype: np.dtype = np.dtype("int32"),
        sample_weight_col: Optional[str] = None,
    ) -> AugmentedBernoulliDataset:
        if ml_stage not in [MLStage.fit, MLStage.validate, MLStage.test]:
            raise ValueError(
                "ml_stage should be one of MLStage.fit, MLStage.validate or MLStage.test"
            )
        return cls(
            data=ListDataFrameRows.from_pandas_df(
                config.training_data
                if ml_stage is MLStage.fit
                else config.validation_data
                if ml_stage is MLStage.validate
                else config.test_data
            ),
            is_success=config.is_success,
            augmented_col=config.augmented_col,
            fitted_preprocessing_pipeline=fitted_preprocessing_pipeline,
            data_augmentation_scaling_factors=config.data_augmentation_scaling_factors,
            label_col=label_col,
            labels_dtype=labels_dtype,
            sample_weight_col=sample_weight_col,
            min_augmented_value=config.data_augmentation_config.min_value,
            max_augmented_value=config.data_augmentation_config.max_value,
        )
    @property
    def dataframe(self) -> pd.DataFrame:
        return self.data.dataframe
    def clear_data(self) -> AugmentedBernoulliDataset:
        object.__setattr__(self, "data", self.data[:0])
        return self
    def _build_augmented_df_indices(self) -> List[pd.Index]:
        df = self.data.build_dataframe([self.augmented_col])
        return [
            df[
                DataAugmentationConfig.scaling_filter(
                    df,
                    scaling_factor,
                    self.augmented_col,
                    min_value=self.min_augmented_value,
                    max_value=self.max_augmented_value,
                )
            ].index
            for scaling_factor in self.data_augmentation_scaling_factors
        ]

    def _build_augmented_df_lengths_cumsum(self) -> np.ndarray:
        return np.cumsum(
            [len(indices) for indices in getattr(self, "augmented_df_indices")]
        )

    @property
    def raw_feature_names(self):
        return self.data.columns

    @property
    def transformed_feature_names(self):
        return (
            self.fitted_preprocessing_pipeline.get_feature_names_out()
            if self.fitted_preprocessing_pipeline
            else self.data.columns
        )

    def __getitem__(self, index) -> DataFrameRow:
        """
        Implicit assumption in the following implementation: the preprocessing
        pipeline does not modify the number of rows.
        """
        if index > len(self) - 1:
            raise IndexError("single positional indexer is out-of-bounds")

        # Compute actual index
        # --------------------
        augmented_df_id = np.searchsorted(
            getattr(self, "augmented_df_lengths_cumsum"), index, side="right"
        )
        augmented_df_index = (
            index - getattr(self, "augmented_df_lengths_cumsum")[augmented_df_id - 1]
            if augmented_df_id > 0
            else index
        )
        actual_index = getattr(self, "augmented_df_indices")[augmented_df_id][
            augmented_df_index
        ]
        row_number = getattr(self, "indices_rows_mapping")[actual_index]

        # Get raw features
        # ----------------
        raw_features: DataFrameRow = self.data[row_number]

        # Apply scaling factor
        # --------------------
        scaling_factor = self.data_augmentation_scaling_factors[augmented_df_id]
        raw_features = raw_features.set(
            column=self.augmented_col,
            value=scaling_factor * raw_features[self.augmented_col],
        )

        # Apply transformations
        # ---------------------
        sample = (
            DataFrameRow.from_single_row_df(
                self.fitted_preprocessing_pipeline.transform(raw_features.single_row_df)
            )
            if self.fitted_preprocessing_pipeline
            else raw_features
        )

        # Add label
        # ---------
        if self.is_success is not None:
            sample = sample.set(
                column=self.label_col, value=self.is_success, dtype=self.labels_dtype
            )

        # Add weights back if removed by preprocessing pipeline
        # -----------------------------------------------------
        if (self.sample_weight_col is not None) and (self.sample_weight_col not in sample):
            sample = sample.set(
                column=self.sample_weight_col,
                value=raw_features[self.sample_weight_col],
                dtype=raw_features.dtypes[self.sample_weight_col],
            )

        return sample

    def __len__(self):
        return self.augmented_df_lengths_cumsum[-1]
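A minimal usage sketch (not part of the module above): it constructs the dataset directly from a pandas DataFrame rather than via `from_config`. The column names (`amount`, `channel`) are purely illustrative, the preprocessing pipeline and sample weights are omitted, and the length/indexing comments assume that `DataAugmentationConfig.scaling_filter` keeps every row when the scaled value stays within the (here infinite) bounds.

import numpy as np
import pandas as pd

from pytorch_utils.datasets import AugmentedBernoulliDataset
from pytorch_utils.pandas.utils import ListDataFrameRows

# Three raw rows; `amount` is the (hypothetical) column that gets rescaled.
raw_df = pd.DataFrame({"amount": [10.0, 20.0, 30.0], "channel": ["web", "app", "web"]})

dataset = AugmentedBernoulliDataset(
    data=ListDataFrameRows.from_pandas_df(raw_df),
    is_success=True,  # every generated sample is labelled 1
    augmented_col="amount",
    data_augmentation_scaling_factors=np.array([0.5, 1.0, 2.0]),
)

# With the default infinite bounds, each scaling factor is assumed to keep all
# rows, so each factor contributes len(raw_df) samples:
assert len(dataset) == 3 * len(raw_df)

# Index 4 falls in the second augmented block (factor 1.0), offset 1:
# np.searchsorted([3, 6, 9], 4, side="right") == 1 and 4 - 3 == 1.
sample = dataset[4]  # DataFrameRow with the scaled `amount` and a `success_labels` column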