# Source code for mlexplainer.explainers.shap.multilabel

"""MultilabelMLExplainer for multilabel classification tasks.
This module provides an implementation of the BaseMLExplainer for multilabel classification tasks,
including methods to explain numerical and categorical features using SHAP values.
"""

from typing import Callable, List, Union

import matplotlib.pyplot as plt
from pandas import DataFrame, Series
from streamlit import pyplot

from mlexplainer.core import BaseMLExplainer
from mlexplainer.explainers.shap.wrapper import ShapWrapper
from mlexplainer.visualization import (
    plot_feature_target_numerical_multilabel,
    plot_feature_target_categorical_multilabel,
    plot_shap_values_numerical_multilabel,
    plot_shap_values_categorical_multilabel,
)
from mlexplainer.utils.data_processing import (
    get_index_of_features,
    calculate_min_max_value,
)
from mlexplainer.validation.feature_interpretation import (
    validate_single_feature_interpretation,
)


class MultilabelMLExplainer(BaseMLExplainer):
    """SHAP-based explainer for a classification target with several modalities.

    SHAP values are computed once at construction time for the training
    frame. ``explain`` then draws a global importance chart and per-feature
    plots, and ``correctness_features`` checks each feature/modality pair
    with ``validate_single_feature_interpretation``.
    """

    def __init__(
        self,
        x_train: DataFrame,
        y_train: Series,
        features: List[str],
        model: Callable,
        global_explainer: bool = True,
        local_explainer: bool = True,
    ):
        """Store the training data and pre-compute SHAP values.

        Args:
            x_train (DataFrame): Training feature values.
            y_train (Series): Training target values; one modality per row.
            features (List[str]): Names of the features to interpret.
            model (Callable): The fitted model to explain.
            global_explainer (bool): Whether ``explain`` draws the global
                importance chart. Defaults to True.
            local_explainer (bool): Whether ``explain`` draws the
                per-feature plots. Defaults to True.

        Raises:
            ValueError: If ``y_train`` has fewer than two unique values.

        Note:
            Any other input validation (missing data, unknown features, ...)
            is delegated to ``BaseMLExplainer.__init__``; see that class for
            the exact errors it raises.
        """
        # Only run the cardinality check when a target was actually supplied.
        # A missing target is left to the base-class initializer instead of
        # crashing here with an AttributeError on ``None.nunique()``.
        if y_train is not None and y_train.nunique() < 2:
            raise ValueError(
                "y_train must have at least two unique values for multilabel "
                "classification."
            )

        super().__init__(
            x_train,
            y_train,
            features,
            model,
            global_explainer,
            local_explainer,
        )

        self.shap_values_train = ShapWrapper(self.model).calculate(
            dataframe=self.x_train, features=self.features
        )
        # Modalities are listed in order of first appearance in y_train.
        self.modalities = self.y_train.unique()
        self.ymean_train = self.y_train.mean()

    def explain(
        self, features_to_explain: Union[list[str], None] = None, **kwargs
    ):
        """Draw the global and/or per-feature explanation plots.

        Args:
            features_to_explain: Features to draw per-feature plots for.
                Defaults to every feature the explainer was built with.
                Names that are neither numerical nor categorical features
                of the explainer are ignored.
            **kwargs: Plot customization forwarded to the plotting steps:
                - figsize: Figure size (default: (15, 8))
                - dpi: Dots per inch (default: 100 for numerical plots,
                  200 for categorical plots)
                - q: Number of quantiles for numerical plots (default: 20)
                - threshold_nb_values: Threshold on the number of distinct
                  values, forwarded to the numerical plot (default: 15)
                - color: Color of the categorical plot
                  (default: (0.28, 0.18, 0.71))
                - demo_mode: If true, each figure is also sent to
                  ``streamlit.pyplot`` (default: False)
        """
        if features_to_explain is None:
            features_to_explain = self.features

        if self.global_explainer:
            self._explain_global_features(**kwargs)

        if self.local_explainer:
            self._explain_numerical(features_to_explain, **kwargs)
            self._explain_categorical(features_to_explain, **kwargs)

    def correctness_features(
        self,
        q=None,
    ) -> dict:
        """Validate the interpretation of every feature, modality by modality.

        For each feature and each modality, a one-vs-rest binary target is
        built and passed, together with the matching SHAP values, to
        ``validate_single_feature_interpretation``.

        Args:
            q (int): Number of quantiles for continuous features, forwarded
                to the validator. Defaults to None.

        Returns:
            dict: ``{feature: {modality: validation result}}``. When no SHAP
            values are available every entry is ``False``.
        """
        shap_values = self.shap_values_train
        if shap_values is None:
            return {
                feature: {modality: False for modality in self.modalities}
                for feature in self.features
            }

        # A 3-D array carries one slice per modality on its last axis;
        # otherwise the same 2-D values are reused for every modality.
        is_multi_output = shap_values.ndim == 3

        results = {}
        for feature in self.features:
            feature_index = get_index_of_features(self.x_train, feature)

            feature_results = {}
            # NOTE(review): position ``i`` follows ``y_train.unique()`` order
            # (first appearance). This assumes the SHAP output axis uses the
            # same ordering as the modalities - confirm against ShapWrapper.
            for i, modality in enumerate(self.modalities):
                # One-vs-rest target for this modality.
                y_binary = (self.y_train == modality).astype(int)
                ymean_binary = y_binary.mean()

                if is_multi_output:
                    feature_shap_values = shap_values[:, feature_index, i]
                else:
                    feature_shap_values = shap_values[:, feature_index]

                feature_results[modality] = (
                    validate_single_feature_interpretation(
                        x_train=self.x_train,
                        y_binary=y_binary,
                        feature=feature,
                        feature_shap_values=feature_shap_values,
                        numerical_features=self.numerical_features,
                        ymean_binary=ymean_binary,
                        q=q,
                    )
                )

            results[feature] = feature_results

        return results

    def _explain_global_features(self, **kwargs):
        """Plot the relative global importance of every feature.

        Importance is the mean absolute SHAP value of a feature, expressed
        as a percentage of the sum over all features.

        Args:
            **kwargs: Plot customization:
                - figsize: Figure size (default: (15, 8))
                - demo_mode: If true, the figure is also sent to
                  ``streamlit.pyplot`` (default: False)
        """
        shap_values = self.shap_values_train
        if shap_values.ndim == 3:
            # Take magnitudes BEFORE averaging over the modality axis:
            # averaging signed values first lets opposite-signed
            # contributions cancel each other out.
            absolute_shap_values = DataFrame(
                abs(shap_values).mean(axis=2), columns=self.features
            )
        else:
            absolute_shap_values = DataFrame(
                shap_values, columns=self.features
            ).apply(abs)

        # Share of each feature in the total mean absolute SHAP value.
        mean_per_feature = absolute_shap_values.mean()
        total_importance = mean_per_feature.sum()
        relative_importance = (
            (mean_per_feature.divide(total_importance) * 100)
            .reset_index(drop=False)
            .rename(columns={"index": "features", 0: "importances"})
            .sort_values(by="importances", ascending=True)
        )

        figsize = kwargs.get("figsize", (15, 8))
        fig, ax = plt.subplots(1, 1, figsize=figsize)

        ax.barh(
            relative_importance["features"], relative_importance["importances"]
        )

        # Adjacent literals are concatenated into ONE string; a comma here
        # would pass a tuple and matplotlib would render its repr.
        ax.set_title(
            "Global Feature Importance for Multilabel Classification"
            " (Mean of the absolute SHAP values)"
        )
        ax.set_xlabel("Relative Importance (%)")
        ax.set_ylabel("Features")

        # Annotate each bar with its percentage.
        for feature_name, importance in zip(
            relative_importance["features"],
            relative_importance["importances"],
        ):
            ax.text(
                importance,
                feature_name,
                s=" " + str(round(importance, 1)) + "%.",
                va="center",
            )

        if kwargs.get("demo_mode", False):
            pyplot(fig)

        plt.show()
        # Release the figure, as the per-feature plotting methods do.
        plt.close(fig)

    def _explain_numerical(self, features_to_explain: list[str], **kwargs):
        """Plot target and SHAP values against each numerical feature.

        Args:
            features_to_explain: Candidate feature names; only those listed
                in ``self.numerical_features`` are plotted.
            **kwargs: Plot customization:
                - figsize: Figure size (default: (15, 8))
                - dpi: Dots per inch (default: 100)
                - q: Number of quantiles for plotting (default: 20)
                - threshold_nb_values: Threshold on the number of distinct
                  values, forwarded to the plotting helper (default: 15)
                - demo_mode: If true, each figure is also sent to
                  ``streamlit.pyplot`` (default: False)
        """
        # Options do not depend on the feature, so read them once.
        figsize = kwargs.get("figsize", (15, 8))
        dpi = kwargs.get("dpi", 100)
        q = kwargs.get("q", 20)
        threshold_nb_values = kwargs.get("threshold_nb_values", 15)
        demo_mode = kwargs.get("demo_mode", False)

        numerical_features_to_explain = [
            feature
            for feature in features_to_explain
            if feature in self.numerical_features
        ]

        for feature in numerical_features_to_explain:
            min_value_train, max_value_train = calculate_min_max_value(
                self.x_train, feature
            )
            # A tenth of the observed range, passed to both plot helpers.
            delta = (max_value_train - min_value_train) / 10

            fig, axes = plot_feature_target_numerical_multilabel(
                self.x_train,
                self.y_train,
                feature,
                q,
                delta,
                figsize,
                dpi,
                threshold_nb_values=threshold_nb_values,
            )

            # Overlay the SHAP values on the axes created above.
            plot_shap_values_numerical_multilabel(
                x_train=self.x_train,
                y_train=self.y_train,
                feature=feature,
                shap_values_train=self.shap_values_train,
                axes=axes,
                delta=delta,
            )

            if demo_mode:
                pyplot(fig)

            plt.show()
            plt.close(fig)

    def _explain_categorical(self, features_to_explain: list[str], **kwargs):
        """Plot target and SHAP values against each categorical feature.

        Feature values are converted to strings for display and missing
        entries are labelled ``"missing_value"``. The conversion is applied
        to a copy; ``self.x_train`` is left untouched.

        Args:
            features_to_explain: Candidate feature names; only those listed
                in ``self.categorical_features`` are plotted.
            **kwargs: Plot customization:
                - figsize: Figure size (default: (15, 8))
                - dpi: Dots per inch (default: 200)
                - color: Color for the plot (default: (0.28, 0.18, 0.71))
                - demo_mode: If true, each figure is also sent to
                  ``streamlit.pyplot`` (default: False)
        """
        # Options do not depend on the feature, so read them once.
        figsize = kwargs.get("figsize", (15, 8))
        dpi = kwargs.get("dpi", 200)
        color = kwargs.get("color", (0.28, 0.18, 0.71))
        demo_mode = kwargs.get("demo_mode", False)

        categorical_features_to_explain = [
            feature
            for feature in features_to_explain
            if feature in self.categorical_features
        ]

        for feature in categorical_features_to_explain:
            column = self.x_train[feature]
            # Locate missing entries BEFORE the string cast: once cast, a
            # missing value may already be an ordinary string such as "nan"
            # and could no longer be detected.
            missing_mask = column.isna()
            display_column = column.astype(str).mask(
                missing_mask, "missing_value"
            )
            # Plot from a relabelled copy so the training frame keeps its
            # original dtypes and values.
            x_display = self.x_train.assign(**{feature: display_column})

            fig, axes = plot_feature_target_categorical_multilabel(
                x_display,
                self.y_train,
                feature,
                self.modalities,
                figsize,
                dpi,
                color,
            )

            # Overlay the SHAP values on the axes created above.
            plot_shap_values_categorical_multilabel(
                x_train=x_display,
                y_train=self.y_train,
                feature=feature,
                axes=axes,
                shap_values_train=self.shap_values_train,
            )

            if demo_mode:
                pyplot(fig)

            plt.show()
            plt.close(fig)