Source code for mlexplainer.utils.quantiles

"""Utility functions for quantile calculations in ML Explainer."""

from typing import Optional, Union


from pandas import DataFrame, Series, isna, notna, qcut, merge
from numpy import arange, inf, isclose, nan


[docs] def is_in_quantile( value: int, quantile_values: list[Union[int, float]] ) -> Union[int, float]: """Return the quantile of a value, given a list of quantiles. Args: value (int): Search value. quantile_values (list[Union[int, float]]): List of quantiles. Returns: Union[int, float]: Upper bound of the quantile. """ if isna(value): return -1 for quantile in quantile_values: if value <= quantile: return quantile # return the last quantile if value exceeds all return inf
[docs] def nb_min_quantiles(x: DataFrame, q: Optional[int] = None) -> int: """Calculate the number of quantiles to use for a feature. Args: x (DataFrame): DataFrame to calculate the number of quantiles for. q (int, optional): Number of quantiles. Defaults to None. Returns: int: Final number of quantiles to use. """ if q is None: nb_quantile = 20 # min 5% of observations in each quantile min_size = 0 # min size after cutting while min_size < 0.01: cutting = qcut(x, nb_quantile, duplicates="drop").value_counts( dropna=False, normalize=True ) # using only not NaN cut: cannot cut according to these type of observations not_nan = [u for u in cutting.index if notna(u)] min_size = cutting[not_nan].min() # using a smaller value of quantile to respect constraint nb_quantile -= 1 else: nb_quantile = q # correction for NaN value: if it is relevant to consider it as a class or not pct_nan = x.isna().mean() nb_quantile = min(int(nb_quantile * (1 - pct_nan)) + 1, nb_quantile) return nb_quantile
[docs] def group_values( x: Series, y: Series, q: Optional[int], threshold_nb_values: float = 15 ) -> tuple[DataFrame, int]: """Create a new DataFrame of cut values. This function groups the values of a feature into quantiles and computes the mean of the target variable for each group. It also counts the number of observations in each group. Args: x (Series): Feature values. y (Series): Target values. q (int): Number of quantiles. Returns: DataFrame: Grouped values with statistics. int: Used quantiles. """ # Check if the series is full of missing values if x.isna().all(): results = DataFrame( {"group": [nan], "target": [y.mean()], "volume": [len(x)]} ) return results, 1 df = DataFrame({"value": x, "target": y}) if q is None: df["group"] = df["value"] elif x.nunique() <= threshold_nb_values: df["group"] = df["value"] else: q = nb_min_quantiles(x, q) quantiles = arange(1 / q, 1, 1 / q) quantiles = quantiles[~isclose(quantiles, 1)].flatten() quantiles_values = list(df["value"].quantile(quantiles)) + [inf] df["group"] = df["value"].apply( lambda u: is_in_quantile(u, quantiles_values) ) # handle missing value replaced_nan_value = df["group"].min() - 10 df.loc[df["value"].isna(), "group"] = replaced_nan_value # compute statistics stats_groupby = df.groupby("group").mean().sort_index().reset_index() volume = ( df["group"] .value_counts() .to_frame() .sort_index() .reset_index() .sort_index() ) results = merge(volume, stats_groupby, how="left", on="group") # if there is at least one missing value in the dataset if df["value"].isna().any(): results["group"] = results["group"].replace( results["group"].min(), nan ) return results.sort_index(), q