"""MultilabelMLExplainer for multilabel classification tasks.
This module provides an implementation of the BaseMLExplainer for multilabel classification tasks,
including methods to explain numerical and categorical features using SHAP values.
"""
from typing import Callable, List, Union
import matplotlib.pyplot as plt
from pandas import DataFrame, Series
from streamlit import pyplot
from mlexplainer.core import BaseMLExplainer
from mlexplainer.explainers.shap.wrapper import ShapWrapper
from mlexplainer.visualization import (
plot_feature_target_numerical_multilabel,
plot_feature_target_categorical_multilabel,
plot_shap_values_numerical_multilabel,
plot_shap_values_categorical_multilabel,
)
from mlexplainer.utils.data_processing import (
get_index_of_features,
calculate_min_max_value,
)
from mlexplainer.validation.feature_interpretation import (
validate_single_feature_interpretation,
)
class MultilabelMLExplainer(BaseMLExplainer):
    """MultilabelMLExplainer for multilabel classification tasks.

    The target ``y_train`` is a single Series; each of its unique values is
    treated as a modality. SHAP values for ``features`` are computed once at
    construction time (``shap_values_train``) and reused by every plotting
    and validation method.
    """

    def __init__(
        self,
        x_train: DataFrame,
        y_train: Series,
        features: List[str],
        model: Callable,
        global_explainer: bool = True,
        local_explainer: bool = True,
    ):
        """
        Initialize the MultilabelMLExplainer with training data, features, and model.

        Args:
            x_train (DataFrame): Training feature values.
            y_train (Series): Training target values; every unique value is
                treated as a modality.
            features (List[str]): List of feature names to interpret.
            model (Callable): The machine learning model to explain.
            global_explainer (bool): Whether to use a global explainer.
                Defaults to True.
            local_explainer (bool): Whether to use a local explainer.
                Defaults to True.

        Raises:
            ValueError: If y_train has fewer than two unique values.
            ValueError: If x_train or y_train is None, if no features are
                provided, or if a feature is not present in x_train
                (presumably validated by BaseMLExplainer — confirm).
        """
        # Skip this check for a None target so the base-class validation can
        # report it, instead of failing here with an AttributeError.
        if y_train is not None and y_train.nunique() < 2:
            raise ValueError(
                "y_train must have at least two unique values for multilabel classification."
            )
        super().__init__(
            x_train,
            y_train,
            features,
            model,
            global_explainer,
            local_explainer,
        )
        self.shap_values_train = ShapWrapper(self.model).calculate(
            dataframe=self.x_train, features=self.features
        )
        self.modalities = self.y_train.unique()
        try:
            # Only meaningful for numeric labels.
            self.ymean_train = self.y_train.mean()
        except (TypeError, ValueError):
            # Non-numeric labels (e.g. strings) have no mean; previously this
            # crashed the constructor.
            self.ymean_train = None

    def explain(
        self, features_to_explain: Union[list[str], None] = None, **kwargs
    ):
        """Explain the features for multilabel classification.

        Plots the global feature importance (if ``global_explainer``) and then
        the per-feature numerical and categorical interpretations (if
        ``local_explainer``).

        Args:
            features_to_explain (list[str] | None): Features to interpret
                locally. Defaults to all features in ``self.features``. The
                global importance plot always covers every feature.
            **kwargs: Additional keyword arguments for customization, such as:
                - figsize: Tuple for figure size (default: (15, 8))
                - dpi: Dots per inch for the plot (default: 100 for numerical,
                  200 for categorical features)
                - q: Number of quantiles for plotting (default: 20)
                - threshold_nb_values: Threshold for number of values in
                  categorical features (default: 15)
                - color: Color for categorical plots
                  (default: (0.28, 0.18, 0.71))
                - demo_mode: If True, also render each figure with
                  ``streamlit.pyplot`` (default: False)
        """
        if features_to_explain is None:
            features_to_explain = self.features
        if self.global_explainer:
            self._explain_global_features(**kwargs)
        if self.local_explainer:
            self._explain_numerical(features_to_explain, **kwargs)
            self._explain_categorical(features_to_explain, **kwargs)

    def correctness_features(
        self,
        q=None,
    ) -> dict:
        """Analyze the correctness of the analysis for every feature.

        This method validates interpretation consistency between actual target
        rates and SHAP values for all features in the explainer, one binary
        (one-vs-rest) target per modality.

        Args:
            q (int): Number of quantiles for continuous features.
                If None, uses adaptive quantiles. Defaults to None.

        Returns:
            dict: Dictionary with feature names as keys and, as values,
                dictionaries mapping each modality to its correctness result.
                If no SHAP values are available, every result is False.
        """
        shap_values = self.shap_values_train
        if shap_values is None:
            return {
                feature: {modality: False for modality in self.modalities}
                for feature in self.features
            }

        # Loop invariant: whether SHAP values carry a trailing modality axis.
        is_multi_output = shap_values.ndim == 3
        # The one-vs-rest target depends only on the modality, so build it once
        # per modality instead of once per (feature, modality) pair.
        per_modality = []
        for modality in self.modalities:
            y_binary = (self.y_train == modality).astype(int)
            per_modality.append((modality, y_binary, y_binary.mean()))

        results = {}
        for feature in self.features:
            # NOTE(review): the index is looked up in x_train's columns while
            # SHAP values were computed for self.features; presumably the two
            # orders agree — confirm.
            feature_index = get_index_of_features(self.x_train, feature)
            feature_results = {}
            for i, (modality, y_binary, ymean_binary) in enumerate(
                per_modality
            ):
                if is_multi_output:
                    feature_shap_values = shap_values[:, feature_index, i]
                else:
                    # Single output: the same SHAP values serve every modality.
                    feature_shap_values = shap_values[:, feature_index]
                feature_results[modality] = (
                    validate_single_feature_interpretation(
                        x_train=self.x_train,
                        y_binary=y_binary,
                        feature=feature,
                        feature_shap_values=feature_shap_values,
                        numerical_features=self.numerical_features,
                        ymean_binary=ymean_binary,
                        q=q,
                    )
                )
            results[feature] = feature_results
        return results

    def _explain_global_features(self, **kwargs):
        """Plot the global feature importance based on SHAP values.

        A feature's importance is its mean absolute SHAP value. The mean is
        taken over samples and, for multi-output SHAP values, also over
        modalities. It is expressed as a percentage of the total over all
        features.

        Args:
            **kwargs: Additional keyword arguments for customization, such as:
                - figsize: Tuple for figure size (default: (15, 8))
                - demo_mode: If True, also render the figure with
                  ``streamlit.pyplot`` (default: False)
        """
        shap_values = self.shap_values_train
        if shap_values.ndim == 3:
            # Take the absolute value BEFORE averaging over modalities.
            # Averaging signed values lets contributions of opposite sign to
            # different modalities cancel out. For example, they sum to zero
            # when the explained outputs sum to a constant, as class
            # probabilities do. That would hide important features.
            absolute_shap_values = DataFrame(
                abs(shap_values).mean(axis=2), columns=self.features
            )
        else:
            absolute_shap_values = DataFrame(
                shap_values, columns=self.features
            ).abs()

        # Relative importance (%) of each feature, sorted ascending so the
        # most important feature ends up at the top of the horizontal bars.
        mean_absolute_shap_values = absolute_shap_values.mean()
        relative_importance = (
            (
                mean_absolute_shap_values.divide(mean_absolute_shap_values.sum())
                * 100
            )
            .reset_index(drop=False)
            .rename(columns={"index": "features", 0: "importances"})
            .sort_values(by="importances", ascending=True)
        )

        figsize = kwargs.get("figsize", (15, 8))
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        ax.barh(
            relative_importance["features"], relative_importance["importances"]
        )
        # One string: a tuple here would be rendered as its repr.
        ax.set_title(
            "Global Feature Importance for Multilabel Classification"
            " (Mean of the absolute SHAP values)"
        )
        ax.set_xlabel("Relative Importance (%)")
        ax.set_ylabel("Features")
        for _, row in relative_importance.iterrows():
            ax.text(
                row.importances,
                row.features,
                s=" " + str(round(row.importances, 1)) + "%.",
                va="center",
            )
        if kwargs.get("demo_mode", False):
            pyplot(fig)
        plt.show()
        # Release the figure, as the per-feature explainers do.
        plt.close(fig)

    def _explain_numerical(self, features_to_explain: list[str], **kwargs):
        """Interpret numerical features for multilabel classification.

        For each numerical feature in ``features_to_explain``, plots the
        feature/target relationship and overlays the SHAP values.

        Args:
            features_to_explain (list[str]): Candidate features; only those in
                ``self.numerical_features`` are plotted.
            **kwargs: Additional keyword arguments for customization, such as:
                - figsize: Tuple for figure size (default: (15, 8))
                - dpi: Dots per inch for the plot (default: 100)
                - q: Number of quantiles for plotting (default: 20)
                - threshold_nb_values: Threshold for number of values in
                  categorical features (default: 15)
                - demo_mode: If True, also render each figure with
                  ``streamlit.pyplot`` (default: False)
        """
        # Plot settings are identical for every feature: read them once.
        figsize = kwargs.get("figsize", (15, 8))
        dpi = kwargs.get("dpi", 100)
        q = kwargs.get("q", 20)
        threshold_nb_values = kwargs.get("threshold_nb_values", 15)
        demo_mode = kwargs.get("demo_mode", False)

        numerical_features_to_explain = [
            feature
            for feature in features_to_explain
            if feature in self.numerical_features
        ]
        for feature in numerical_features_to_explain:
            min_value_train, max_value_train = calculate_min_max_value(
                self.x_train, feature
            )
            # One tenth of the feature's range, passed to both plotting
            # helpers (presumably used as an axis margin — TODO confirm).
            delta = (max_value_train - min_value_train) / 10

            fig, axes = plot_feature_target_numerical_multilabel(
                self.x_train,
                self.y_train,
                feature,
                q,
                delta,
                figsize,
                dpi,
                threshold_nb_values=threshold_nb_values,
            )
            plot_shap_values_numerical_multilabel(
                x_train=self.x_train,
                y_train=self.y_train,
                feature=feature,
                shap_values_train=self.shap_values_train,
                axes=axes,
                delta=delta,
            )
            if demo_mode:
                pyplot(fig)
            plt.show()
            plt.close(fig)

    def _explain_categorical(self, features_to_explain: list[str], **kwargs):
        """Interpret categorical features for multilabel classification.

        For each categorical feature in ``features_to_explain``, plots the
        feature/target relationship and overlays the SHAP values. Values are
        cast to ``str`` and missing values are labelled ``"missing_value"``
        for display. This is done on a copy of the training data, so
        ``self.x_train`` is left unchanged.

        Args:
            features_to_explain (list[str]): Candidate features; only those in
                ``self.categorical_features`` are plotted.
            **kwargs: Additional keyword arguments for customization, such as:
                - figsize: Tuple for figure size (default: (15, 8))
                - dpi: Dots per inch for the plot (default: 200)
                - color: Color for the plot (default: (0.28, 0.18, 0.71))
                - demo_mode: If True, also render each figure with
                  ``streamlit.pyplot`` (default: False)
        """
        # Plot settings are identical for every feature: read them once.
        figsize = kwargs.get("figsize", (15, 8))
        dpi = kwargs.get("dpi", 200)
        color = kwargs.get("color", (0.28, 0.18, 0.71))
        demo_mode = kwargs.get("demo_mode", False)

        categorical_features_to_explain = [
            feature
            for feature in features_to_explain
            if feature in self.categorical_features
        ]
        if not categorical_features_to_explain:
            return

        # Work on a copy so the display cleanup below does not alter the stored
        # training data (and hence later calls such as correctness_features).
        x_train = self.x_train.copy()
        for feature in categorical_features_to_explain:
            column = x_train[feature]
            # Detect missing values on the ORIGINAL column: once cast to str,
            # NaN/None become "nan"/"None" and can no longer be found.
            x_train[feature] = column.astype(str).where(
                column.notna(), "missing_value"
            )

            fig, axes = plot_feature_target_categorical_multilabel(
                x_train,
                self.y_train,
                feature,
                self.modalities,
                figsize,
                dpi,
                color,
            )
            plot_shap_values_categorical_multilabel(
                x_train=x_train,
                y_train=self.y_train,
                feature=feature,
                axes=axes,
                shap_values_train=self.shap_values_train,
            )
            if demo_mode:
                pyplot(fig)
            plt.show()
            plt.close(fig)