Add validations for explainable model arguments in MimicExplainer #354

Open · wants to merge 5 commits into main
8 changes: 8 additions & 0 deletions python/interpret_community/common/constants.py
@@ -159,6 +159,14 @@ class LightGBMParams(object):
"""Provide constants for LightGBM."""

CATEGORICAL_FEATURE = 'categorical_feature'
N_JOBS = 'n_jobs'
ALL = [CATEGORICAL_FEATURE, N_JOBS]


class LinearExplainableModelParams(object):
"""Provide constants for LinearExplainableModel."""
SPARSE_DATA = 'sparse_data'
ALL = [SPARSE_DATA]


class ShapValuesOutput(str, Enum):
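For reference, a minimal sketch of how these per-surrogate constant classes can be used to build an explainable_model_args dict, assuming the LinearExplainableModelParams class added above is available once this change merges:

    from interpret_community.common.constants import LightGBMParams, \
        LinearExplainableModelParams

    # Each surrogate model has its own recognized argument names; the ALL
    # lists are what the new validation in MimicExplainer iterates over.
    lgbm_args = {LightGBMParams.N_JOBS: -1}
    linear_args = {LinearExplainableModelParams.SPARSE_DATA: True}

    print(LightGBMParams.ALL)                # ['categorical_feature', 'n_jobs']
    print(LinearExplainableModelParams.ALL)  # ['sparse_data']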
57 changes: 48 additions & 9 deletions python/interpret_community/mimic/mimic_explainer.py
@@ -19,9 +19,8 @@
    transform_with_datamapper

from ..common.blackbox_explainer import BlackBoxExplainer
-
from .model_distill import _model_distill, _inverse_soft_logit
-from .models import LGBMExplainableModel
+from .models import LGBMExplainableModel, LinearExplainableModel
from ..explanation.explanation import _create_local_explanation, _create_global_explanation, \
    _aggregate_global_from_local_explanation, _aggregate_streamed_local_explanations, \
    _create_raw_feats_global_explanation, _create_raw_feats_local_explanation, \
@@ -30,7 +29,7 @@
from ..dataset.dataset_wrapper import DatasetWrapper
from ..common.constants import ExplainParams, ExplainType, ModelTask, \
    ShapValuesOutput, MimicSerializationConstants, ExplainableModelType, \
-    LightGBMParams, Defaults, Extension, ResetIndex
+    LightGBMParams, Defaults, Extension, ResetIndex, LinearExplainableModelParams
import logging
import json

@@ -236,6 +235,8 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
"""
if transformations is not None and explain_subset is not None:
raise ValueError("explain_subset not supported with transformations")
self._validate_explainable_model_args(explainable_model=explainable_model,
explainable_model_args=explainable_model_args)
self.reset_index = reset_index
self._datamapper = None
if transformations is not None:
@@ -250,8 +251,7 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
        wrapped_model, eval_ml_domain = _wrap_model(model, initialization_examples, model_task, is_function)
        super(MimicExplainer, self).__init__(wrapped_model, is_function=is_function,
                                             model_task=eval_ml_domain, **kwargs)
-        if explainable_model_args is None:
-            explainable_model_args = {}

        if categorical_features is None:
            categorical_features = []
        self._logger.debug('Initializing MimicExplainer')
Expand Down Expand Up @@ -288,7 +288,6 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
            # Index the categorical string columns for training data
            self._column_indexer = initialization_examples.string_index(columns=categorical_features)
            self._one_hot_encoder = None
-            explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features

Collaborator: why was this removed?

Collaborator: Oh, I see, you moved it to line 347. I guess that's ok; my only slight concern is that we are now doing the same checks in multiple places:

    is_tree_model = explainable_model.explainable_model_type == ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE
    if is_tree_model and self._supports_categoricals(explainable_model):

but it's not expensive, so I think it's ok.

        else:
            # One-hot-encode categoricals for models that don't support categoricals natively
            self._column_indexer = initialization_examples.string_index(columns=categorical_features)
@@ -304,15 +303,55 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
        if isinstance(training_data, DenseData):
            training_data = training_data.data

-        explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag
-        if self._supports_shap_values_output(explainable_model):
-            explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output
+        explainable_model_args = self._supplement_explainable_model_args(
+            explainable_model=explainable_model,
+            explainable_model_args=explainable_model_args,
+            categorical_features=categorical_features,
+            shap_values_output=shap_values_output)
        self.surrogate_model = _model_distill(self.function, explainable_model, training_data,
                                              original_training_data, explainable_model_args)
        self._method = self.surrogate_model._method
        self._original_eval_examples = None
        self._allow_all_transformations = allow_all_transformations

    def _validate_explainable_model_args(self, explainable_model, explainable_model_args):
        if explainable_model_args is None:
            return

        if explainable_model == LGBMExplainableModel:
            for linear_param in LinearExplainableModelParams.ALL:
                if linear_param in explainable_model_args:
                    raise Exception(linear_param +
                                    " found in params for LightGBM explainable model")

        if explainable_model == LinearExplainableModel:
            for lightgbm_param in LightGBMParams.ALL:
                if lightgbm_param in explainable_model_args:
                    raise Exception(lightgbm_param +
                                    " found in params for Linear explainable model")

        # Concatenate the per-model lists so membership is tested against the
        # individual argument names rather than against the lists themselves
        all_supported_explainable_model_args = LightGBMParams.ALL + LinearExplainableModelParams.ALL
        for explainable_model_arg in explainable_model_args:
            if explainable_model_arg not in all_supported_explainable_model_args:
                raise Exception(
                    "Found unsupported explainable model argument " + explainable_model_arg)

    def _supplement_explainable_model_args(self, explainable_model, explainable_model_args,
                                           categorical_features, shap_values_output):
        if explainable_model_args is None:
            explainable_model_args = {}

        if explainable_model.explainable_model_type == ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE and \
                self._supports_categoricals(explainable_model):
            explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features

        explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag

        if self._supports_shap_values_output(explainable_model):
            explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output

        return explainable_model_args

    def _get_surrogate_model_predictions(self, evaluation_examples):
        """Return the predictions given by the surrogate model.

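To make the new behavior concrete, here is a minimal sketch of the validation failing fast at construction time. The toy data and model are placeholders, the imports mirror those in the test module below, and the exact error message is an assumption based on the code above:

    from sklearn.datasets import make_regression
    from sklearn.linear_model import LinearRegression
    from interpret_community.mimic.mimic_explainer import MimicExplainer
    from interpret_community.mimic.models.lightgbm_model import LGBMExplainableModel
    from interpret_community.common.constants import LinearExplainableModelParams

    X, y = make_regression(n_samples=200, n_features=10)
    model = LinearRegression().fit(X, y)

    # 'sparse_data' belongs to LinearExplainableModel, so pairing it with the
    # LightGBM surrogate should now raise immediately at construction time
    # instead of failing later inside model distillation.
    try:
        MimicExplainer(model, X, LGBMExplainableModel,
                       explainable_model_args={LinearExplainableModelParams.SPARSE_DATA: True},
                       augment_data=False)
    except Exception as exc:
        print(exc)  # "sparse_data found in params for LightGBM explainable model"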
26 changes: 25 additions & 1 deletion test/test_mimic_explainer.py
@@ -18,7 +18,8 @@
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sys import platform
-from interpret_community.common.constants import ShapValuesOutput, ModelTask
+from interpret_community.common.constants import ShapValuesOutput, ModelTask, \
+    LinearExplainableModelParams, LightGBMParams
from interpret_community.mimic.models.lightgbm_model import LGBMExplainableModel
from interpret_community.mimic.models.linear_model import LinearExplainableModel
from common_utils import create_timeseries_data, LIGHTGBM_METHOD, \
@@ -540,6 +541,29 @@ def test_dense_wide_data(self, mimic_explainer):
        global_explanation = explainer.explain_global(df_X)
        assert global_explanation.method == LIGHTGBM_METHOD

    @pytest.mark.parametrize("error_config",
                             [(LGBMExplainableModel, {LinearExplainableModelParams.SPARSE_DATA: True}),
                              (LinearExplainableModel, {LightGBMParams.N_JOBS: -1}),
                              (LinearExplainableModel, {LightGBMParams.CATEGORICAL_FEATURE: []}),
                              (LGBMExplainableModel, {"unsupported": True}),
                              (LinearExplainableModel, {"unsupported": True})])
    def test_validate_explainable_model_args(self, error_config, mimic_explainer):
        num_features = 100
        num_rows = 1000
        test_size = 0.2
        X, y = make_regression(n_samples=num_rows, n_features=num_features)
        x_train, x_test, y_train, _ = train_test_split(X, y, test_size=test_size, random_state=42)

        model = LinearRegression(normalize=True)
        model.fit(x_train, y_train)

        explainable_model = error_config[0]
        explainable_model_args = error_config[1]
        with pytest.raises(Exception):
            mimic_explainer(model, x_train, explainable_model,
                            explainable_model_args=explainable_model_args,
                            augment_data=False)

    @property
    def iris_overall_expected_features(self):
        return [['petal length', 'petal width', 'sepal width', 'sepal length'],
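For contrast with the error cases exercised above, a sketch of a combination that should pass the new validation, under the same toy setup as the test (n_jobs is a recognized LightGBM argument, so construction proceeds to distillation as before):

    from sklearn.datasets import make_regression
    from sklearn.linear_model import LinearRegression
    from interpret_community.mimic.mimic_explainer import MimicExplainer
    from interpret_community.mimic.models.lightgbm_model import LGBMExplainableModel
    from interpret_community.common.constants import LightGBMParams

    X, y = make_regression(n_samples=200, n_features=10)
    model = LinearRegression().fit(X, y)

    # A LightGBM-specific argument paired with the LightGBM surrogate is valid,
    # so no exception is raised and the explainer can be used normally.
    explainer = MimicExplainer(model, X, LGBMExplainableModel,
                               explainable_model_args={LightGBMParams.N_JOBS: -1},
                               augment_data=False)
    global_explanation = explainer.explain_global(X)
    print(global_explanation.method)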