hlnicholls committed on
Commit
8a6cf88
1 Parent(s): 15625ec

feat: updated interface

Browse files
__pycache__/dynamic_shap_plot.cpython-38.pyc ADDED
Binary file (3.08 kB). View file
 
__pycache__/dynamic_shap_plots.cpython-38.pyc ADDED
Binary file (8.14 kB). View file
 
__pycache__/shap_plots.cpython-38.pyc ADDED
Binary file (16.9 kB). View file
 
app.py CHANGED
@@ -7,11 +7,12 @@ import sklearn
7
  import catboost
8
  import shap
9
  from shap_plots import shap_summary_plot
10
- from dynamic_shap_plot import matplotlib_to_plotly, summary_plot_plotly_fig
11
  import plotly.tools as tls
12
- import dash_core_components as dcc
13
- import matplotlib
14
  import plotly.graph_objs as go
 
15
  try:
16
  import matplotlib.pyplot as pl
17
  from matplotlib.colors import LinearSegmentedColormap
@@ -21,133 +22,163 @@ except ImportError:
21
 
22
  st.set_option('deprecation.showPyplotGlobalUse', False)
23
 
24
- seed=42
25
 
26
  annotations = pd.read_csv("all_genes_merged_ml_data.csv")
27
- # TODO remove this placeholder when imputation is finished:
28
  annotations.fillna(0, inplace=True)
29
  annotations = annotations.set_index("Gene")
30
 
31
- # Read in best_model_fitted.pkl as catboost_model
32
- model_path = "best_model_fitted.pkl" # Update this path if your model is stored elsewhere
33
  with open(model_path, 'rb') as file:
34
  catboost_model = pickle.load(file)
35
 
36
- # For a multi-class classification model, obtaining probabilities per class
37
  probabilities = catboost_model.predict_proba(annotations)
38
-
39
- # Creating a DataFrame for these probabilities
40
- # Assuming classes are ordered as 'most likely', 'probable', and 'least likely' in the model
41
- prob_df = pd.DataFrame(probabilities,
42
- index=annotations.index,
43
- columns=['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely'])
44
-
45
- # Dynamically including all original features from annotations plus the new probability columns
46
  df_total = pd.concat([prob_df, annotations], axis=1)
47
 
 
 
 
 
48
 
49
  st.title('Blood Pressure Gene Prioritisation Post-GWAS')
50
- st.markdown("""
51
- A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.
52
-
53
-
54
- """)
55
-
56
- collect_genes = lambda x : [str(i) for i in re.split(",|,\s+|\s+", x) if i != ""]
57
 
 
 
58
  input_gene_list = st.text_input("Input a list of multiple HGNC genes (enter comma separated):")
59
  gene_list = collect_genes(input_gene_list)
60
  explainer = shap.TreeExplainer(catboost_model)
61
 
62
  @st.cache_data
63
  def convert_df(df):
64
- return df.to_csv(index=False).encode('utf-8')
65
 
66
  probability_columns = ['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely']
67
  features_list = [column for column in df_total.columns if column not in probability_columns]
68
  features = df_total[features_list]
69
 
70
- if len(gene_list) > 1:
71
- df = df_total[df_total.index.isin(gene_list)]
72
- df['Gene'] = df.index # Ensure 'Gene' is a column if it's not already
73
- df.reset_index(drop=True, inplace=True)
74
-
75
- # Including Gene, probability columns, and all other features
76
- required_columns = ['Gene'] + probability_columns + [col for col in df.columns if col not in probability_columns and col != 'Gene']
77
- df = df[required_columns]
78
- st.dataframe(df)
79
-
80
- # Assuming you want to download the genes with their probabilities
81
- output = df[['Gene'] + probability_columns]
82
- csv = convert_df(output)
83
- st.download_button(
84
- "Download Gene Prioritisation",
85
- csv,
86
- "bp_gene_prioritisation.csv",
87
- "text/csv",
88
- key='download-csv'
89
- )
90
-
91
- # For SHAP values, assuming explainer is already fitted to your model
92
- df_shap = df.drop(columns=probability_columns + ['Gene']) # Exclude non-feature columns
93
- shap_values = explainer.shap_values(df_shap)
94
-
95
- # Handle multiclass scenario: SHAP values will be a list of matrices, one per class
96
- # Plotting the summary plot for the first class as an example
97
- # You may loop through each class or handle it differently based on your needs
98
- class_index = 0 # Example: plotting for the first class
99
- shap.summary_plot(shap_values[class_index], df_shap, show=False)
100
- st.pyplot(bbox_inches='tight')
101
- st.caption("SHAP Summary Plot of All Input Genes")
102
-
103
- else:
104
- pass
105
 
 
 
106
 
107
- input_gene = st.text_input("Input an individual HGNC gene:")
108
- df2 = df_total[df_total.index == input_gene]
109
- df2['Gene'] = df2.index
110
- df2.reset_index(drop=True, inplace=True)
111
 
112
- # Ensure the DataFrame includes the CatBoost model's probability columns
113
- # And assuming all features are desired in the output
114
- probability_columns = ['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely']
115
- required_columns = ['Gene'] + probability_columns + [col for col in df2.columns if col not in probability_columns and col != 'Gene']
116
- df2 = df2[required_columns]
117
- st.dataframe(df2)
 
 
 
 
 
 
 
 
 
 
118
 
119
- if input_gene:
120
- if ' ' in input_gene or ',' in input_gene:
121
- st.write('Input Error: Please input only a single HGNC gene name with no white spaces or commas.')
122
  else:
123
- df2_shap = df_total.loc[[input_gene], [col for col in df_total.columns if col not in probability_columns + ['Gene']]]
124
-
125
- if df2_shap.shape[0] > 0: # Check if the gene exists in the DataFrame
126
- shap_values = explainer.shap_values(df2_shap)
127
-
128
- # Adjust for multiclass: Select SHAP values for the predicted class (or a specific class)
129
- predicted_class_index = catboost_model.predict(df2_shap).item() # Assuming predict returns the class index
130
- class_shap_values = shap_values[predicted_class_index]
131
- class_expected_value = explainer.expected_value[predicted_class_index]
132
 
133
- # Since force_plot doesn't directly support multiclass, consider using waterfall_plot or decision_plot
134
- # Here's an example using waterfall_plot for the first feature set's prediction
135
- shap.plots.waterfall(shap_values=class_shap_values[0], max_display=10, show=False)
136
- st.pyplot(bbox_inches='tight')
137
- else:
138
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- st.markdown("""
141
- ### Total Gene Prioritisation Results:
142
- """)
143
-
144
- df_total_output = df_total
145
- df_total_output['Gene'] = df_total_output.index
146
- df_total_output.reset_index(drop=True, inplace=True)
147
- #df_total_output = df_total_output[['Gene','XGB_Score', 'mousescore_Exomiser',
148
- # 'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
149
- # 'HIPred',
150
- # 'Cells - EBV-transformed lymphocytes_GTExTPM',
151
- # 'Pituitary_GTExTPM',
152
- # 'IPA_BP_annotation']]
153
- st.dataframe(df_total_output)
 
7
  import catboost
8
  import shap
9
  from shap_plots import shap_summary_plot
10
+ from dynamic_shap_plots import matplotlib_to_plotly, summary_plot_plotly_fig
11
  import plotly.tools as tls
12
+ from dash import dcc
13
+ import matplotlib.pyplot as plt
14
  import plotly.graph_objs as go
15
+
16
  try:
17
  import matplotlib.pyplot as pl
18
  from matplotlib.colors import LinearSegmentedColormap
 
22
 
23
  st.set_option('deprecation.showPyplotGlobalUse', False)
24
 
25
+ seed = 0
26
 
27
  annotations = pd.read_csv("all_genes_merged_ml_data.csv")
 
28
  annotations.fillna(0, inplace=True)
29
  annotations = annotations.set_index("Gene")
30
 
31
+ model_path = "best_model_fitted.pkl"
 
32
  with open(model_path, 'rb') as file:
33
  catboost_model = pickle.load(file)
34
 
 
35
  probabilities = catboost_model.predict_proba(annotations)
36
+ prob_df = pd.DataFrame(probabilities, index=annotations.index, columns=['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely'])
 
 
 
 
 
 
 
37
  df_total = pd.concat([prob_df, annotations], axis=1)
38
 
39
+ # Create tabs for navigation
40
+ with st.sidebar:
41
+ st.sidebar.title("Navigation")
42
+ tab = st.sidebar.radio("Go to", ("Gene Prioritisation", "Interactive SHAP Plot", "Supervised SHAP Clustering"))
43
 
44
  st.title('Blood Pressure Gene Prioritisation Post-GWAS')
45
+ st.markdown("""A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.""")
 
 
 
 
 
 
46
 
47
+ # Define a function to collect genes from input
48
+ collect_genes = lambda x: [str(i) for i in re.split(",|,\s+|\s+", x) if i != ""]
49
  input_gene_list = st.text_input("Input a list of multiple HGNC genes (enter comma separated):")
50
  gene_list = collect_genes(input_gene_list)
51
  explainer = shap.TreeExplainer(catboost_model)
52
 
53
  @st.cache_data
54
  def convert_df(df):
55
+ return df.to_csv(index=False).encode('utf-8')
56
 
57
  probability_columns = ['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely']
58
  features_list = [column for column in df_total.columns if column not in probability_columns]
59
  features = df_total[features_list]
60
 
61
+ # Page 1: Gene Prioritisation
62
+ if tab == "Gene Prioritisation":
63
+ if len(gene_list) > 1:
64
+ df = df_total[df_total.index.isin(gene_list)]
65
+ df['Gene'] = df.index
66
+ df.reset_index(drop=True, inplace=True)
67
+
68
+ required_columns = ['Gene'] + probability_columns + [column for column in df.columns if column not in probability_columns and column != 'Gene']
69
+ df = df[required_columns]
70
+ st.dataframe(df)
71
+
72
+ output = df[['Gene'] + probability_columns]
73
+ csv = convert_df(output)
74
+ st.download_button("Download Gene Prioritisation", csv, "bp_gene_prioritisation.csv", "text/csv", key='download-csv')
75
+
76
+ df_shap = df.drop(columns=probability_columns + ['Gene'])
77
+ shap_values = explainer.shap_values(df_shap)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ col1, col2 = st.columns(2)
80
+ class_names = ["Most likely", "Probable", "Least likely"]
81
 
82
+ with col1:
83
+ st.subheader("Global SHAP Summary Plot")
84
+ shap.summary_plot(shap_values, df_shap, plot_type="bar", class_names=class_names)
85
+ st.pyplot(bbox_inches='tight', clear_figure=True)
86
 
87
+ with col2:
88
+ st.subheader(f"{class_names[0]} Gene Prediction")
89
+ shap.summary_plot(shap_values[0], df_shap)
90
+ st.pyplot(bbox_inches='tight', clear_figure=True)
91
+
92
+ col3, col4 = st.columns(2)
93
+
94
+ with col3:
95
+ st.subheader(f"{class_names[1]} Gene Prediction")
96
+ shap.summary_plot(shap_values[1], df_shap)
97
+ st.pyplot(bbox_inches='tight', clear_figure=True)
98
+
99
+ with col4:
100
+ st.subheader(f"{class_names[2]} Gene Prediction")
101
+ shap.summary_plot(shap_values[2], df_shap)
102
+ st.pyplot(bbox_inches='tight', clear_figure=True)
103
 
 
 
 
104
  else:
105
+ pass
106
+
107
+ input_gene = st.text_input("Input an individual HGNC gene:")
108
+ if input_gene:
109
+ df2 = df_total[df_total.index == input_gene]
110
+ class_names = ["Most likely", "Probable", "Least likely"]
111
+ if not df2.empty:
112
+ df2['Gene'] = df2.index
113
+ df2.reset_index(drop=True, inplace=True)
114
 
115
+ required_columns = ['Gene'] + probability_columns + [col for col in df2.columns if col not in probability_columns and col != 'Gene']
116
+ df2 = df2[required_columns]
117
+ st.dataframe(df2)
118
+
119
+ if ' ' in input_gene or ',' in input_gene:
120
+ st.write('Input Error: Please input only a single HGNC gene name with no white spaces or commas.')
121
+ else:
122
+ df2_shap = df_total.loc[[input_gene], [col for col in df_total.columns if col not in probability_columns + ['Gene']]]
123
+ print(df2_shap.columns)
124
+ shap_values = explainer.shap_values(df2_shap)
125
+ shap.getjs()
126
+
127
+ for i in range(3):
128
+ st.subheader(f"Force Plot for {class_names[i]} Prediction")
129
+ force_plot = shap.force_plot(
130
+ explainer.expected_value[i],
131
+ shap_values[i],
132
+ df2_shap,
133
+ matplotlib=True,
134
+ show=False
135
+ )
136
+ st.pyplot(fig=force_plot)
137
+ else:
138
+ st.write("Gene not found in the dataset.")
139
+ else:
140
+ pass
141
+
142
+ st.markdown("""
143
+ ### Total Gene Prioritisation Results for All Genes:
144
+ """)
145
+
146
+ df_total_output = df_total
147
+ df_total_output['Gene'] = df_total_output.index
148
+ #df_total_output.reset_index(drop=True, inplace=True)
149
+ st.dataframe(df_total_output)
150
+ csv = convert_df(df_total_output)
151
+ st.download_button("Download Gene Prioritisation", csv, "all_genes_bp_prioritisation.csv", "text/csv", key='download-all-csv')
152
+
153
+ # Page 2: Interactive SHAP Plot
154
+
155
+ elif tab == "Interactive SHAP Plot":
156
+ st.title("Interactive SHAP Plot")
157
+ if len(gene_list) > 1:
158
+ df = df_total[df_total.index.isin(gene_list)]
159
+ df['Gene'] = df.index
160
+ df.reset_index(drop=True, inplace=True)
161
+
162
+ required_columns = ['Gene'] + probability_columns + [column for column in df.columns if column not in probability_columns and column != 'Gene']
163
+ df = df[required_columns]
164
+ st.dataframe(df)
165
+
166
+ output = df[['Gene'] + probability_columns]
167
+ csv = convert_df(output)
168
+ st.download_button("Download Gene Prioritisation", csv, "bp_gene_prioritisation.csv", "text/csv", key='download-csv')
169
+
170
+ df_shap = df.drop(columns=probability_columns + ['Gene'])
171
+ shap_values = explainer.shap_values(df_shap)
172
+
173
+ # Use shap's summary_plot function for interactivity
174
+ # summary_plot = shap.summary_plot(shap_values[0], df_shap, plot_type='interactive', max_display=10)
175
+ summary_plot = summary_plot_plotly_fig(df_shap, shap_values[0], max_display=10)
176
+ st.pyplot(summary_plot)
177
+ st.caption("SHAP Summary Plot of All Input Genes")
178
+
179
 
180
+ # Page 3: Supervised SHAP Clustering
181
+ elif tab == "Supervised SHAP Clustering":
182
+ st.title("Supervised SHAP Clustering")
183
+ # Add your code here to implement supervised SHAP clustering
184
+
 
 
 
 
 
 
 
 
 
dynamic_shap_plot.py DELETED
@@ -1,118 +0,0 @@
1
- from shap_plots import shap_summary_plot, shap_dependence_plot
2
- import plotly.tools as tls
3
- import dash_core_components as dcc
4
- import pandas as pd
5
- import numpy as np
6
- import xgboost
7
- import shap
8
- import matplotlib
9
- import plotly.graph_objs as go
10
- try:
11
- import matplotlib.pyplot as pl
12
- from matplotlib.colors import LinearSegmentedColormap
13
- from matplotlib.ticker import MaxNLocator
14
- except ImportError:
15
- pass
16
- from sklearn import preprocessing
17
-
18
- cdict1 = {
19
- 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
20
- (1.0, 0.9607843137254902, 0.9607843137254902)),
21
-
22
- 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
23
- (1.0, 0.15294117647058825, 0.15294117647058825)),
24
-
25
- 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
26
- (1.0, 0.3411764705882353, 0.3411764705882353)),
27
-
28
- 'alpha': ((0.0, 1, 1),
29
- (0.5, 1, 1),
30
- (1.0, 1, 1))
31
- } # #1E88E5 -> #ff0052
32
- red_blue = LinearSegmentedColormap('RedBlue', cdict1)
33
-
34
- def matplotlib_to_plotly(cmap, pl_entries):
35
- h = 1.0/(pl_entries-1)
36
- pl_colorscale = []
37
-
38
- for k in range(pl_entries):
39
- C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255))
40
- pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])
41
-
42
- return pl_colorscale
43
-
44
- red_blue = matplotlib_to_plotly(red_blue, 255)
45
-
46
- def summary_plot_plotly_fig(shap_values, df_shap, feature_names, max_display = 8):
47
- #data = pd.read_csv(dataset, encoding="ISO-8859-1")
48
- #X = data.drop(['target column'], axis=1)
49
-
50
- #y = data[target]
51
- #y = y/max(y)
52
-
53
- #X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
54
-
55
- #X_train.fillna((-999), inplace=True)
56
- #X_test.fillna((-999), inplace=True)
57
-
58
- #_, shap_values, feature_names = train_model_and_return_shap_values(X, y, target)
59
-
60
- mpl_fig = shap_summary_plot(shap_values, df_shap, feature_names=feature_names, max_display=20)
61
-
62
- plotly_fig = tls.mpl_to_plotly(mpl_fig)
63
-
64
- plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
65
-
66
- feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
67
- feature_order = feature_order[-min(max_display, len(feature_order)):]
68
- text = [df_shap.index[i] for i in df_shap.index]
69
- text = iter(text)
70
-
71
- for i in range(1, len(plotly_fig['data']), 2):
72
- t = text.__next__()
73
- plotly_fig['data'][i]['name'] = ''
74
- plotly_fig['data'][i]['text'] = t
75
- plotly_fig['data'][i]['hoverinfo'] = 'text'
76
- #plotly_fig['data'][i]['text'] = df_shap.index
77
- plotly_fig['data'][i]['y'] = feature_names[feature_order]
78
-
79
-
80
- colorbar_trace = go.Scatter(x=[None],
81
- y=[None],
82
- mode='markers',
83
- marker=dict(
84
- colorscale=red_blue,
85
- showscale=True,
86
- cmin=-5,
87
- cmax=5,
88
- colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
89
- ),
90
- hoverinfo='none'
91
- )
92
-
93
- plotly_fig['layout']['showlegend'] = False
94
- plotly_fig['layout']['hovermode'] = 'closest'
95
- plotly_fig['layout']['height']=600
96
- plotly_fig['layout']['width']=500
97
-
98
- plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
99
- plotly_fig['layout']['yaxis'].update(dict(visible=True))
100
- plotly_fig.add_trace(colorbar_trace)
101
- plotly_fig.layout.update(
102
- annotations=[dict(
103
- x=1.18,
104
- align="right",
105
- valign="top",
106
- text='Gene',
107
- showarrow=False,
108
- xref="paper",
109
- yref="paper",
110
- xanchor="right",
111
- yanchor="middle",
112
- textangle=-90,
113
- font=dict(family='Calibri', size=14)
114
- )
115
- ],
116
- margin=dict(t=20)
117
- )
118
- return plotly_fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dynamic_shap_plots.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shap_plots import shap_summary_plot, shap_dependence_plot
2
+ import plotly.tools as tls
3
+ import dash_core_components as dcc
4
+ import pandas as pd
5
+ from sklearn.model_selection import train_test_split
6
+ import numpy as np
7
+ import xgboost
8
+ import shap
9
+ import matplotlib
10
+ import plotly.graph_objs as go
11
+ try:
12
+ import matplotlib.pyplot as pl
13
+ from matplotlib.colors import LinearSegmentedColormap
14
+ from matplotlib.ticker import MaxNLocator
15
+ except ImportError:
16
+ pass
17
+ from sklearn import preprocessing
18
+
19
+ cdict1 = {
20
+ 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
21
+ (1.0, 0.9607843137254902, 0.9607843137254902)),
22
+
23
+ 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
24
+ (1.0, 0.15294117647058825, 0.15294117647058825)),
25
+
26
+ 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
27
+ (1.0, 0.3411764705882353, 0.3411764705882353)),
28
+
29
+ 'alpha': ((0.0, 1, 1),
30
+ (0.5, 1, 1),
31
+ (1.0, 1, 1))
32
+ } # #1E88E5 -> #ff0052
33
+ red_blue = LinearSegmentedColormap('RedBlue', cdict1)
34
+
35
+ def matplotlib_to_plotly(cmap, pl_entries):
36
+ h = 1.0/(pl_entries-1)
37
+ pl_colorscale = []
38
+
39
+ for k in range(pl_entries):
40
+ C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255))
41
+ pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])
42
+
43
+ return pl_colorscale
44
+
45
+ red_blue = matplotlib_to_plotly(red_blue, 255)
46
+
47
+ def summary_plot_plotly_fig(dataset, shap_values, target='target column', max_display = 20):
48
+ feature_names=dataset.columns
49
+ mpl_fig = shap_summary_plot(shap_values, dataset, feature_names=feature_names, max_display=20)
50
+
51
+ plotly_fig = tls.mpl_to_plotly(mpl_fig)
52
+
53
+ plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
54
+
55
+ feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
56
+ feature_order = feature_order[-min(max_display, len(feature_order)):]
57
+ text = [feature_names[i] for i in feature_order]
58
+ text = iter(text)
59
+
60
+ for i in range(1, len(plotly_fig['data']), 2):
61
+ t = text.__next__()
62
+ plotly_fig['data'][i]['name'] = ''
63
+ plotly_fig['data'][i]['text'] = t
64
+ plotly_fig['data'][i]['hoverinfo'] = 'text'
65
+
66
+ colorbar_trace = go.Scatter(x=[None],
67
+ y=[None],
68
+ mode='markers',
69
+ marker=dict(
70
+ colorscale=red_blue,
71
+ showscale=True,
72
+ cmin=-5,
73
+ cmax=5,
74
+ colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
75
+ ),
76
+ hoverinfo='none'
77
+ )
78
+
79
+ plotly_fig['layout']['showlegend'] = False
80
+ plotly_fig['layout']['hovermode'] = 'closest'
81
+ plotly_fig['layout']['height']=600
82
+ plotly_fig['layout']['width']=500
83
+
84
+ plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
85
+ plotly_fig['layout']['yaxis'].update(dict(visible=False))
86
+ plotly_fig.add_trace(colorbar_trace)
87
+ plotly_fig.layout.update(
88
+ annotations=[dict(
89
+ x=1.18,
90
+ align="right",
91
+ valign="top",
92
+ text='Feature value',
93
+ showarrow=False,
94
+ xref="paper",
95
+ yref="paper",
96
+ xanchor="right",
97
+ yanchor="middle",
98
+ textangle=-90,
99
+ font=dict(family='Calibri', size=14)
100
+ )
101
+ ],
102
+ margin=dict(t=20)
103
+ )
104
+ return plotly_fig
105
+
106
+ def train_model_and_return_shap_values(X, y, target):
107
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
108
+
109
+ X_train.fillna((-999), inplace=True)
110
+ X_test.fillna((-999), inplace=True)
111
+
112
+ # Some of values are float or integer and some object. This is why we need to cast them:
113
+ for f in X_train.columns:
114
+ if X_train[f].dtype=='object':
115
+ lbl = preprocessing.LabelEncoder()
116
+ lbl.fit(list(X_train[f].values))
117
+ X_train[f] = lbl.transform(list(X_train[f].values))
118
+
119
+ for f in X_test.columns:
120
+ if X_test[f].dtype=='object':
121
+ lbl = preprocessing.LabelEncoder()
122
+ lbl.fit(list(X_test[f].values))
123
+ X_test[f] = lbl.transform(list(X_test[f].values))
124
+
125
+ X_train=np.array(X_train)
126
+ X_test=np.array(X_test)
127
+ X_train = X_train.astype(float)
128
+ X_test = X_test.astype(float)
129
+
130
+ d_train = xgboost.DMatrix(X_train, label=y_train, feature_names=list(X))
131
+ d_test = xgboost.DMatrix(X_test, label=y_test, feature_names=list(X))
132
+
133
+ # train the model
134
+ params = {
135
+ "eta": 0.01,
136
+ "subsample": 0.5,
137
+ "base_score": np.mean(y_train),
138
+ "silent": 1
139
+ }
140
+
141
+ model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=None, early_stopping_rounds=50)
142
+ feature_names = model.feature_names
143
+ shap_values = shap.TreeExplainer(model).shap_values(pd.DataFrame(X_train, columns=X.columns))
144
+ return model, shap_values, feature_names
145
+
146
+ def dependence_plot_to_plotly_fig(dataset, target='target column', max_display=10):
147
+ data = pd.read_csv(dataset, encoding="ISO-8859-1")
148
+ X = data.drop(['target column'], axis=1)
149
+ y = data[target]
150
+ y = y/max(y)
151
+
152
+ xgb_full = xgboost.DMatrix(X, label=y)
153
+
154
+ # create a train/test split
155
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
156
+ xgb_train = xgboost.DMatrix(X_train, label=y_train)
157
+ xgb_test = xgboost.DMatrix(X_test, label=y_test)
158
+
159
+ # use validation set to choose # of trees
160
+ params = {
161
+ # "eta": 0.002,
162
+ # "max_depth": 3,
163
+ # "subsample": 0.5,
164
+ "silent": 1
165
+ }
166
+ model_train = xgboost.train(params, xgb_train, 3000, evals = [(xgb_test, "test")], verbose_eval=None)
167
+
168
+ # train final model on the full data set
169
+ params = {
170
+ # "eta": 0.002,
171
+ # "max_depth": 3,
172
+ # "subsample": 0.5,
173
+ "silent": 1
174
+ }
175
+ model = xgboost.train(params, xgb_full, 1500, evals = [(xgb_full, "test")], verbose_eval=None)
176
+ features = model.feature_names
177
+ shap_values = shap.TreeExplainer(model).shap_values(X)
178
+
179
+ feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
180
+ feature_order = feature_order[-min(max_display, len(feature_order)):]
181
+ features = [features[i] for i in feature_order[::-1]]
182
+
183
+ lis = []
184
+ for i in features:
185
+ mpl_fig, interaction_index = shap_dependence_plot(i, shap_values, X)
186
+ plotly_fig = tls.mpl_to_plotly(mpl_fig)
187
+
188
+ # The x-tick labels start by default from 0, which is not necessarily the min value of the feature.
189
+ # So, we need to increment the x-tick labels by 1. But while doing so, the y-axis gets shifted.
190
+ # To prevent that, we need to manually control the x-axis range from r_min to r_max
191
+ new_x = []
192
+ for j in plotly_fig['data'][0]['x']:
193
+ new_x.append(j)
194
+
195
+ r_min = min(plotly_fig['data'][0]['x'])
196
+ r_max = max(plotly_fig['data'][0]['x'])
197
+
198
+ plotly_fig['layout']['xaxis'].update(range=[r_min-1, r_max+1])
199
+ plotly_fig['data'][0]['x'] = tuple(new_x)
200
+
201
+ # Define the colorbar
202
+ colorbar_trace = go.Scatter(x=[None],
203
+ y=[None],
204
+ mode='markers',
205
+ marker=dict(
206
+ colorscale=red_blue,
207
+ showscale=True,
208
+ colorbar=dict(thickness=5, outlinewidth=0),
209
+ color=[min(X[X.columns[interaction_index]]), max(X[X.columns[interaction_index]])],
210
+ ),
211
+ hoverinfo='none'
212
+ )
213
+
214
+ plotly_fig['layout']['showlegend'] = False
215
+ plotly_fig['layout']['hovermode'] = 'closest'
216
+ plotly_fig['layout']['height']=380
217
+ plotly_fig['layout']['width']=450
218
+ plotly_fig['layout']['xaxis'].update(zeroline=True,
219
+ showline=True,
220
+ ticklen=4,
221
+ showgrid=False,
222
+ tickmode='linear')
223
+ title = plotly_fig['layout']['yaxis']['title']
224
+ plotly_fig['layout']['yaxis'].update(title=title.split(' -')[0])
225
+
226
+ plotly_fig.add_trace(colorbar_trace)
227
+ plotly_fig.layout.update(
228
+ annotations=[dict(
229
+ x=1.23,
230
+ align="right",
231
+ valign="top",
232
+ text=X.columns[interaction_index],
233
+ showarrow=False,
234
+ xref="paper",
235
+ yref="paper",
236
+ xanchor="right",
237
+ yanchor="middle",
238
+ textangle=-90,
239
+ font=dict(family='Calibri', size=14)
240
+ )
241
+ ],
242
+ margin=dict(t=50, b=50, l=50, r=80)
243
+ )
244
+ lis.append(plotly_fig)
245
+ return lis, features
246
+
247
+ def interaction_plot_to_plotly_fig(dataset, target_col='target column', max_display=10):
248
+ data = pd.read_csv(dataset, encoding="ISO-8859-1")
249
+ X = data.drop(['target column'], axis=1)
250
+ y = data[target_col]
251
+ y = y/max(y)
252
+
253
+ xgb_full = xgboost.DMatrix(X, label=y)
254
+
255
+ # create a train/test split
256
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
257
+ xgb_train = xgboost.DMatrix(X_train, label=y_train)
258
+ xgb_test = xgboost.DMatrix(X_test, label=y_test)
259
+
260
+ # use validation set to choose # of trees
261
+ params = {
262
+ # "eta": 0.002,
263
+ # "max_depth": 3,
264
+ # "subsample": 0.5,
265
+ "silent": 1
266
+ }
267
+ model_train = xgboost.train(params, xgb_train, 3000, evals = [(xgb_test, "test")], verbose_eval=None)
268
+
269
+ # train final model on the full data set
270
+ params = {
271
+ # "eta": 0.002,
272
+ # "max_depth": 3,
273
+ # "subsample": 0.5,
274
+ "silent": 1
275
+ }
276
+ model = xgboost.train(params, xgb_full, 1500, evals = [(xgb_full, "test")], verbose_eval=None)
277
+ features = model.feature_names
278
+ shap_values = shap.TreeExplainer(model).shap_values(X)
279
+
280
+ feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
281
+ feature_order = feature_order[-min(max_display, len(feature_order)):]
282
+ features = [features[i] for i in feature_order[::-1]]
283
+
284
+ shap_interaction_values = shap.TreeExplainer(model).shap_interaction_values(X)
285
+
286
+ lis = []
287
+ for i in features:
288
+ for j in features:
289
+ mpl_fig = pl.figure()
290
+ ax = mpl_fig.add_subplot(111)
291
+ _, interaction_index = shap_dependence_plot ( (i, j), shap_interaction_values, X.iloc[:2000,:] )
292
+ plotly_fig = tls.mpl_to_plotly(mpl_fig)
293
+
294
+ r_min = min(plotly_fig['data'][0]['x'])
295
+ r_max = max(plotly_fig['data'][0]['x'])
296
+
297
+ plotly_fig['layout']['xaxis'].update(range=[r_min-1, r_max+1])
298
+ plotly_fig['layout']['showlegend'] = False
299
+ plotly_fig['layout']['hovermode'] = 'closest'
300
+ plotly_fig['layout']['height']=380
301
+ plotly_fig['layout']['width']=450
302
+ plotly_fig['layout']['xaxis'].update(zeroline=True,
303
+ showline=True,
304
+ ticklen=4,
305
+ showgrid=False,
306
+ tickmode='linear')
307
+ plotly_fig['layout']['yaxis'].update(showline=True)
308
+
309
+ if i!=j:
310
+ # plotly_fig['layout']['height']=380
311
+ plotly_fig['layout']['width']=480
312
+ plotly_fig['layout']['yaxis']['title'] = "SHAP interaction value for {} and {}".format(i.split('-')[0], j.split('-')[0])
313
+ # Define the colorbar
314
+ colorbar_trace = go.Scatter(x=[None],
315
+ y=[None],
316
+ mode='markers',
317
+ marker=dict(
318
+ colorscale=red_blue,
319
+ showscale=True,
320
+ colorbar=dict(thickness=5, outlinewidth=0),
321
+ color=[min(X[X.columns[interaction_index]]), max(X[X.columns[interaction_index]])],
322
+ ),
323
+ hoverinfo='none'
324
+ )
325
+ plotly_fig.add_trace(colorbar_trace)
326
+ plotly_fig.layout.update(
327
+ annotations=[dict(
328
+ x=1.23,
329
+ align="right",
330
+ valign="top",
331
+ text=X.columns[interaction_index],
332
+ showarrow=False,
333
+ xref="paper",
334
+ yref="paper",
335
+ xanchor="right",
336
+ yanchor="middle",
337
+ textangle=-90,
338
+ font=dict(family='Calibri', size=14)
339
+ )
340
+ ],
341
+ margin=dict(t=30, b=30, l=60, r=80)
342
+ )
343
+ else:
344
+ plotly_fig['layout']['yaxis']['title'] = "SHAP main effect value for {}".format(i.split('-')[0])
345
+ lis.append(plotly_fig)
346
+ return lis, features
requirements.txt CHANGED
@@ -3,6 +3,7 @@ numpy==1.23.4
3
  altair==5.1.2
4
  scikit-learn==1.1.3
5
  pandas
 
6
  xgboost==1.3.3
7
  shap==0.41.0
8
  plotly
 
3
  altair==5.1.2
4
  scikit-learn==1.1.3
5
  pandas
6
+ catboost
7
  xgboost==1.3.3
8
  shap==0.41.0
9
  plotly
shap_plots.py CHANGED
@@ -1,730 +1,730 @@
1
- import warnings
2
- import iml
3
- import numpy as np
4
- from iml import Instance, Model
5
- from iml.datatypes import DenseData
6
- from iml.explanations import AdditiveExplanation
7
- from iml.links import IdentityLink
8
- from scipy.stats import gaussian_kde
9
- import matplotlib
10
- try:
11
- import matplotlib.pyplot as pl
12
- from matplotlib.colors import LinearSegmentedColormap
13
- from matplotlib.ticker import MaxNLocator
14
-
15
- cdict1 = {
16
- 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
17
- (1.0, 0.9607843137254902, 0.9607843137254902)),
18
-
19
- 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
20
- (1.0, 0.15294117647058825, 0.15294117647058825)),
21
-
22
- 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
23
- (1.0, 0.3411764705882353, 0.3411764705882353)),
24
-
25
- 'alpha': ((0.0, 1, 1),
26
- (0.5, 0.3, 0.3),
27
- (1.0, 1, 1))
28
- } # #1E88E5 -> #ff0052
29
- red_blue = LinearSegmentedColormap('RedBlue', cdict1)
30
-
31
- cdict1 = {
32
- 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
33
- (1.0, 0.9607843137254902, 0.9607843137254902)),
34
-
35
- 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
36
- (1.0, 0.15294117647058825, 0.15294117647058825)),
37
-
38
- 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
39
- (1.0, 0.3411764705882353, 0.3411764705882353)),
40
-
41
- 'alpha': ((0.0, 1, 1),
42
- (0.5, 1, 1),
43
- (1.0, 1, 1))
44
- } # #1E88E5 -> #ff0052
45
- red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
46
- except ImportError:
47
- pass
48
-
49
- labels = {
50
- 'MAIN_EFFECT': "SHAP main effect value for\n%s",
51
- 'INTERACTION_VALUE': "SHAP interaction value",
52
- 'INTERACTION_EFFECT': "SHAP interaction value for\n%s and %s",
53
- 'VALUE': "SHAP value (impact on model output)",
54
- 'VALUE_FOR': "SHAP value for\n%s",
55
- 'PLOT_FOR': "SHAP plot for %s",
56
- 'FEATURE': "Feature %s",
57
- 'FEATURE_VALUE': "Feature value",
58
- 'FEATURE_VALUE_LOW': "Low",
59
- 'FEATURE_VALUE_HIGH': "High",
60
- 'JOINT_VALUE': "Joint SHAP value"
61
- }
62
-
63
- def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
64
- color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
65
- color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
66
- """Create a SHAP summary plot, colored by feature values when they are provided.
67
-
68
- Parameters
69
- ----------
70
- shap_values : numpy.array
71
- Matrix of SHAP values (# samples x # features)
72
-
73
- features : numpy.array or pandas.DataFrame or list
74
- Matrix of feature values (# samples x # features) or a feature_names list as shorthand
75
-
76
- feature_names : list
77
- Names of the features (length # features)
78
-
79
- max_display : int
80
- How many top features to include in the plot (default is 20, or 7 for interaction plots)
81
-
82
- plot_type : "dot" (default) or "violin"
83
- What type of summary plot to produce
84
- """
85
-
86
- assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."
87
-
88
- # default color:
89
- if color is None:
90
- color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"
91
-
92
- # convert from a DataFrame or other types
93
- if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
94
- if feature_names is None:
95
- feature_names = features.columns
96
- features = features.values
97
- elif str(type(features)) == "<class 'list'>":
98
- if feature_names is None:
99
- feature_names = features
100
- features = None
101
- elif (features is not None) and len(features.shape) == 1 and feature_names is None:
102
- feature_names = features
103
- features = None
104
-
105
- if feature_names is None:
106
- feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]
107
-
108
- mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
109
-
110
- # plotting SHAP interaction values
111
- if len(shap_values.shape) == 3:
112
- if max_display is None:
113
- max_display = 7
114
- else:
115
- max_display = min(len(feature_names), max_display)
116
-
117
- sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))
118
-
119
- # get plotting limits
120
- delta = 1.0 / (shap_values.shape[1] ** 2)
121
- slow = np.nanpercentile(shap_values, delta)
122
- shigh = np.nanpercentile(shap_values, 100 - delta)
123
- v = max(abs(slow), abs(shigh))
124
- slow = -0.2
125
- shigh = 0.2
126
-
127
- # mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
128
- ax = mpl_fig.subplot(1, max_display, 1)
129
- proj_shap_values = shap_values[:, sort_inds[0], np.hstack((sort_inds, len(sort_inds)))]
130
- proj_shap_values[:, 1:] *= 2 # because off diag effects are split in half
131
- shap_summary_plot(
132
- proj_shap_values, features[:, sort_inds],
133
- feature_names=feature_names[sort_inds],
134
- sort=False, show=False, color_bar=False,
135
- auto_size_plot=False,
136
- max_display=max_display
137
- )
138
- pl.xlim((slow, shigh))
139
- pl.xlabel("")
140
- title_length_limit = 11
141
- pl.title(shorten_text(feature_names[sort_inds[0]], title_length_limit))
142
- for i in range(1, max_display):
143
- ind = sort_inds[i]
144
- pl.subplot(1, max_display, i + 1)
145
- proj_shap_values = shap_values[:, ind, np.hstack((sort_inds, len(sort_inds)))]
146
- proj_shap_values *= 2
147
- proj_shap_values[:, i] /= 2 # because only off diag effects are split in half
148
- shap_summary_plot(
149
- proj_shap_values, features[:, sort_inds],
150
- sort=False,
151
- feature_names=df_shap.columns, #["" for i in range(features.shape[1])],
152
- show=False,
153
- color_bar=False,
154
- auto_size_plot=False,
155
- max_display=max_display
156
- )
157
- pl.xlim((slow, shigh))
158
- pl.xlabel("")
159
- if i == max_display // 2:
160
- pl.xlabel(labels['INTERACTION_VALUE'])
161
- pl.title(shorten_text(feature_names[ind], title_length_limit))
162
- pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
163
- pl.subplots_adjust(hspace=0, wspace=0.1)
164
- # if show:
165
- # # pl.show()
166
- return mpl_fig
167
-
168
- if max_display is None:
169
- max_display = 20
170
-
171
- if sort:
172
- # order features by the sum of their effect magnitudes
173
- feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
174
- feature_order = feature_order[-min(max_display, len(feature_order)):]
175
- else:
176
- feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)
177
-
178
- row_height = 0.4
179
- if auto_size_plot:
180
- pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
181
- pl.axvline(x=0, color="#999999", zorder=-1)
182
-
183
- if plot_type == "dot":
184
- for pos, i in enumerate(feature_order):
185
- pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
186
- shaps = shap_values[:, i]
187
- values = None if features is None else features[:, i]
188
- inds = np.arange(len(shaps))
189
- np.random.shuffle(inds)
190
- if values is not None:
191
- values = values[inds]
192
- shaps = shaps[inds]
193
- colored_feature = True
194
- try:
195
- values = np.array(values, dtype=np.float64) # make sure this can be numeric
196
- except:
197
- colored_feature = False
198
- N = len(shaps)
199
- # hspacing = (np.max(shaps) - np.min(shaps)) / 200
200
- # curr_bin = []
201
- nbins = 100
202
- quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
203
- inds = np.argsort(quant + np.random.randn(N) * 1e-6)
204
- layer = 0
205
- last_bin = -1
206
- ys = np.zeros(N)
207
- for ind in inds:
208
- if quant[ind] != last_bin:
209
- layer = 0
210
- ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
211
- layer += 1
212
- last_bin = quant[ind]
213
- ys *= 0.9 * (row_height / np.max(ys + 1))
214
-
215
- if features is not None and colored_feature:
216
- # trim the color range, but prevent the color range from collapsing
217
- vmin = np.nanpercentile(values, 5)
218
- vmax = np.nanpercentile(values, 95)
219
- if vmin == vmax:
220
- vmin = np.nanpercentile(values, 1)
221
- vmax = np.nanpercentile(values, 99)
222
- if vmin == vmax:
223
- vmin = np.min(values)
224
- vmax = np.max(values)
225
-
226
- assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
227
- nan_mask = np.isnan(values)
228
- pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin,
229
- vmax=vmax, s=16, alpha=alpha, linewidth=0,
230
- zorder=3, rasterized=len(shaps) > 500)
231
- pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
232
- cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
233
- c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
234
- zorder=3, rasterized=len(shaps) > 500)
235
- else:
236
-
237
- pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
238
- color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)
239
-
240
- elif plot_type == "violin":
241
- for pos, i in enumerate(feature_order):
242
- pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
243
-
244
- if features is not None:
245
- global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
246
- global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
247
- for pos, i in enumerate(feature_order):
248
- shaps = shap_values[:, i]
249
- shap_min, shap_max = np.min(shaps), np.max(shaps)
250
- rng = shap_max - shap_min
251
- xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100)
252
- if np.std(shaps) < (global_high - global_low) / 100:
253
- ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
254
- else:
255
- ds = gaussian_kde(shaps)(xs)
256
- ds /= np.max(ds) * 3
257
-
258
- values = features[:, i]
259
- window_size = max(10, len(values) // 20)
260
- smooth_values = np.zeros(len(xs) - 1)
261
- sort_inds = np.argsort(shaps)
262
- trailing_pos = 0
263
- leading_pos = 0
264
- running_sum = 0
265
- back_fill = 0
266
- for j in range(len(xs) - 1):
267
-
268
- while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
269
- running_sum += values[sort_inds[leading_pos]]
270
- leading_pos += 1
271
- if leading_pos - trailing_pos > 20:
272
- running_sum -= values[sort_inds[trailing_pos]]
273
- trailing_pos += 1
274
- if leading_pos - trailing_pos > 0:
275
- smooth_values[j] = running_sum / (leading_pos - trailing_pos)
276
- for k in range(back_fill):
277
- smooth_values[j - k - 1] = smooth_values[j]
278
- else:
279
- back_fill += 1
280
-
281
- vmin = np.nanpercentile(values, 5)
282
- vmax = np.nanpercentile(values, 95)
283
- if vmin == vmax:
284
- vmin = np.nanpercentile(values, 1)
285
- vmax = np.nanpercentile(values, 99)
286
- if vmin == vmax:
287
- vmin = np.min(values)
288
- vmax = np.max(values)
289
- pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
290
- c=values, alpha=alpha, linewidth=0, zorder=1)
291
- # smooth_values -= nxp.nanpercentile(smooth_values, 5)
292
- # smooth_values /= np.nanpercentile(smooth_values, 95)
293
- smooth_values -= vmin
294
- if vmax - vmin > 0:
295
- smooth_values /= vmax - vmin
296
- for i in range(len(xs) - 1):
297
- if ds[i] > 0.05 or ds[i + 1] > 0.05:
298
- pl.fill_between([xs[i], xs[i + 1]], [pos + ds[i], pos + ds[i + 1]],
299
- [pos - ds[i], pos - ds[i + 1]], color=red_blue_solid(smooth_values[i]),
300
- zorder=2)
301
-
302
- else:
303
- parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
304
- widths=0.7,
305
- showmeans=False, showextrema=False, showmedians=False)
306
-
307
- for pc in parts['bodies']:
308
- pc.set_facecolor(color)
309
- pc.set_edgecolor('none')
310
- pc.set_alpha(alpha)
311
-
312
- elif plot_type == "layered_violin": # courtesy of @kodonnell
313
- num_x_points = 200
314
- bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype(
315
- 'int') # the indices of the feature data corresponding to each bin
316
- shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
317
- x_points = np.linspace(shap_min, shap_max, num_x_points)
318
-
319
- # loop through each feature and plot:
320
- for pos, ind in enumerate(feature_order):
321
- # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value, otherwise use bins/percentiles.
322
- # to keep simpler code, in the case of uniques, we just adjust the bins to align with the unique counts.
323
- feature = features[:, ind]
324
- unique, counts = np.unique(feature, return_counts=True)
325
- if unique.shape[0] <= layered_violin_max_num_bins:
326
- order = np.argsort(unique)
327
- thesebins = np.cumsum(counts[order])
328
- thesebins = np.insert(thesebins, 0, 0)
329
- else:
330
- thesebins = bins
331
- nbins = thesebins.shape[0] - 1
332
- # order the feature data so we can apply percentiling
333
- order = np.argsort(feature)
334
- # x axis is located at y0 = pos, with pos being there for offset
335
- y0 = np.ones(num_x_points) * pos
336
- # calculate kdes:
337
- ys = np.zeros((nbins, num_x_points))
338
- for i in range(nbins):
339
- # get shap values in this bin:
340
- shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
341
- # if there's only one element, then we can't
342
- if shaps.shape[0] == 1:
343
- warnings.warn(
344
- "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
345
- % (i, feature_names[ind]))
346
- # to ignore it, just set it to the previous y-values (so the area between them will be zero). Not ys is already 0, so there's
347
- # nothing to do if i == 0
348
- if i > 0:
349
- ys[i, :] = ys[i - 1, :]
350
- continue
351
- # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
352
- ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
353
- # scale it up so that the 'size' of each y represents the size of the bin. For continuous data this will
354
- # do nothing, but when we've gone with the unqique option, this will matter - e.g. if 99% are male and 1%
355
- # female, we want the 1% to appear a lot smaller.
356
- size = thesebins[i + 1] - thesebins[i]
357
- bin_size_if_even = features.shape[0] / nbins
358
- relative_bin_size = size / bin_size_if_even
359
- ys[i, :] *= relative_bin_size
360
- # now plot 'em. We don't plot the individual strips, as this can leave whitespace between them.
361
- # instead, we plot the full kde, then remove outer strip and plot over it, etc., to ensure no
362
- # whitespace
363
- ys = np.cumsum(ys, axis=0)
364
- width = 0.8
365
- scale = ys.max() * 2 / width # 2 is here as we plot both sides of x axis
366
- for i in range(nbins - 1, -1, -1):
367
- y = ys[i, :] / scale
368
- c = pl.get_cmap(color)(i / (
369
- nbins - 1)) if color in pl.cm.datad else color # if color is a cmap, use it, otherwise use a color
370
- pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
371
- pl.xlim(shap_min, shap_max)
372
-
373
- # draw the color bar
374
- if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
375
- import matplotlib.cm as cm
376
- m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
377
- m.set_array([0, 1])
378
- cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
379
- cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
380
- cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
381
- cb.ax.tick_params(labelsize=11, length=0)
382
- cb.set_alpha(1)
383
- cb.outline.set_visible(False)
384
- bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
385
- cb.ax.set_aspect((bbox.height - 0.9) * 20)
386
- # cb.draw_all()
387
-
388
- pl.gca().xaxis.set_ticks_position('bottom')
389
- pl.gca().yaxis.set_ticks_position('none')
390
- pl.gca().spines['right'].set_visible(False)
391
- pl.gca().spines['top'].set_visible(False)
392
- pl.gca().spines['left'].set_visible(False)
393
- pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
394
- pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
395
- pl.gca().tick_params('y', length=20, width=0.5, which='major')
396
- pl.gca().tick_params('x', labelsize=11)
397
- pl.ylim(-1, len(feature_order))
398
- pl.xlabel(labels['VALUE'], fontsize=13)
399
- pl.tight_layout()
400
- # if show:
401
- # pl.show()
402
- return mpl_fig
403
-
404
-
405
-
406
-
407
-
408
-
409
- def approx_interactions(index, shap_values, X):
410
- """ Order other features by how much interaction they seem to have with the feature at the given index.
411
-
412
- This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction
413
- index values for SHAP see the interaction_contribs option implemented in XGBoost.
414
- """
415
-
416
- if X.shape[0] > 10000:
417
- a = np.arange(X.shape[0])
418
- np.random.shuffle(a)
419
- inds = a[:10000]
420
- else:
421
- inds = np.arange(X.shape[0])
422
-
423
- x = X[inds, index]
424
- srt = np.argsort(x)
425
- shap_ref = shap_values[inds, index]
426
- shap_ref = shap_ref[srt]
427
- inc = max(min(int(len(x) / 10.0), 50), 1)
428
- interactions = []
429
- for i in range(X.shape[1]):
430
- val_other = X[inds, i][srt].astype(np.float)
431
- v = 0.0
432
- if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
433
- for j in range(0, len(x), inc):
434
- if np.std(val_other[j:j + inc]) > 0 and np.std(shap_ref[j:j + inc]) > 0:
435
- v += abs(np.corrcoef(shap_ref[j:j + inc], val_other[j:j + inc])[0, 1])
436
- interactions.append(v)
437
-
438
- return np.argsort(-np.abs(interactions))
439
-
440
-
441
-
442
-
443
-
444
-
445
-
446
- def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
447
- interaction_index="auto", color="#1E88E5", axis_color="#333333",
448
- dot_size=16, alpha=1, title=None, show=True):
449
- """
450
- Create a SHAP dependence plot, colored by an interaction feature.
451
-
452
- Parameters
453
- ----------
454
- ind : int
455
- Index of the feature to plot.
456
-
457
- shap_values : numpy.array
458
- Matrix of SHAP values (# samples x # features)
459
-
460
- features : numpy.array or pandas.DataFrame
461
- Matrix of feature values (# samples x # features)
462
-
463
- feature_names : list
464
- Names of the features (length # features)
465
-
466
- display_features : numpy.array or pandas.DataFrame
467
- Matrix of feature values for visual display (such as strings instead of coded values)
468
-
469
- interaction_index : "auto", None, or int
470
- The index of the feature used to color the plot.
471
- """
472
-
473
- # convert from DataFrames if we got any
474
- if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
475
- if feature_names is None:
476
- feature_names = features.columns
477
- features = features.values
478
- if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"):
479
- if feature_names is None:
480
- feature_names = display_features.columns
481
- display_features = display_features.values
482
- elif display_features is None:
483
- display_features = features
484
-
485
- if feature_names is None:
486
- feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]
487
-
488
- # allow vectors to be passed
489
- if len(shap_values.shape) == 1:
490
- shap_values = np.reshape(shap_values, len(shap_values), 1)
491
- if len(features.shape) == 1:
492
- features = np.reshape(features, len(features), 1)
493
-
494
- def convert_name(ind):
495
- if type(ind) == str:
496
- nzinds = np.where(feature_names == ind)[0]
497
- if len(nzinds) == 0:
498
- print("Could not find feature named: " + ind)
499
- return None
500
- else:
501
- return nzinds[0]
502
- else:
503
- return ind
504
-
505
- ind = convert_name(ind)
506
-
507
- mpl_fig = pl.gcf()
508
- ax = mpl_fig.gca()
509
-
510
- # plotting SHAP interaction values
511
- if len(shap_values.shape) == 3 and len(ind) == 2:
512
- ind1 = convert_name(ind[0])
513
- ind2 = convert_name(ind[1])
514
- if ind1 == ind2:
515
- proj_shap_values = shap_values[:, ind2, :]
516
- else:
517
- proj_shap_values = shap_values[:, ind2, :] * 2 # off-diag values are split in half
518
-
519
- # TODO: remove recursion; generally the functions should be shorter for more maintainable code
520
- return shap_dependence_plot(
521
- ind1, proj_shap_values, features, feature_names=feature_names,
522
- interaction_index=ind2, display_features=display_features, show=False
523
- )
524
-
525
- assert shap_values.shape[0] == features.shape[0], \
526
- "'shap_values' and 'features' values must have the same number of rows!"
527
- assert shap_values.shape[1] == features.shape[1], \
528
- "'shap_values' must have the same number of columns as 'features'!"
529
-
530
- # get both the raw and display feature values
531
- xv = features[:, ind]
532
- xd = display_features[:, ind]
533
- s = shap_values[:, ind]
534
- if type(xd[0]) == str:
535
- name_map = {}
536
- for i in range(len(xv)):
537
- name_map[xd[i]] = xv[i]
538
- xnames = list(name_map.keys())
539
-
540
- # allow a single feature name to be passed alone
541
- if type(feature_names) == str:
542
- feature_names = [feature_names]
543
- name = feature_names[ind]
544
-
545
- # guess what other feature as the stongest interaction with the plotted feature
546
- if interaction_index == "auto":
547
- interaction_index = approx_interactions(ind, shap_values, features)[0]
548
- interaction_index = convert_name(interaction_index)
549
- categorical_interaction = False
550
-
551
- # get both the raw and display color values
552
- if interaction_index is not None:
553
- cv = features[:, interaction_index]
554
- cd = display_features[:, interaction_index]
555
- clow = np.nanpercentile(features[:, interaction_index].astype(np.float), 5)
556
- chigh = np.nanpercentile(features[:, interaction_index].astype(np.float), 95)
557
- if type(cd[0]) == str:
558
- cname_map = {}
559
- for i in range(len(cv)):
560
- cname_map[cd[i]] = cv[i]
561
- cnames = list(cname_map.keys())
562
- categorical_interaction = True
563
- elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:, interaction_index])) < 50:
564
- categorical_interaction = True
565
-
566
- # discritize colors for categorical features
567
- color_norm = None
568
- if categorical_interaction and clow != chigh:
569
- bounds = np.linspace(clow, chigh, chigh - clow + 2)
570
- color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)
571
-
572
- # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
573
- if interaction_index is not None:
574
- pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue,
575
- alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500)
576
- else:
577
- pl.scatter(xv, s, s=dot_size, linewidth=0, color="#1E88E5",
578
- alpha=alpha, rasterized=len(xv) > 500)
579
-
580
- if interaction_index != ind and interaction_index is not None:
581
- # draw the color bar
582
- if type(cd[0]) == str:
583
- tick_positions = [cname_map[n] for n in cnames]
584
- if len(tick_positions) == 2:
585
- tick_positions[0] -= 0.25
586
- tick_positions[1] += 0.25
587
- cb = pl.colorbar(ticks=tick_positions)
588
- cb.set_ticklabels(cnames)
589
- else:
590
- cb = pl.colorbar()
591
-
592
- cb.set_label(feature_names[interaction_index], size=13)
593
- cb.ax.tick_params(labelsize=11)
594
- if categorical_interaction:
595
- cb.ax.tick_params(length=0)
596
- cb.set_alpha(1)
597
- cb.outline.set_visible(False)
598
- bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
599
- cb.ax.set_aspect((bbox.height - 0.7) * 20)
600
-
601
- # make the plot more readable
602
- if interaction_index != ind:
603
- pl.gcf().set_size_inches(7.5, 5)
604
- else:
605
- pl.gcf().set_size_inches(6, 5)
606
- # pl.xlabel(name, color=axis_color, fontsize=13)
607
- # pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
608
- if title is not None:
609
- pl.title(title, color=axis_color, fontsize=13)
610
- pl.gca().xaxis.set_ticks_position('bottom')
611
- pl.gca().yaxis.set_ticks_position('left')
612
- pl.gca().spines['right'].set_visible(False)
613
- pl.gca().spines['top'].set_visible(False)
614
- pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
615
- for spine in pl.gca().spines.values():
616
- spine.set_edgecolor(axis_color)
617
- if type(xd[0]) == str:
618
- pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)
619
- # if show:
620
- # pl.show()
621
-
622
-
623
- if ind1 == ind2:
624
- pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1])
625
- else:
626
- pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2]))
627
-
628
- return mpl_fig, interaction_index
629
-
630
-
631
- # # if show:
632
- # # pl.show()
633
- # return
634
- # return mpl_fig
635
-
636
- # assert shap_values.shape[0] == features.shape[0], "'shap_values' and 'features' values must have the same number of rows!"
637
- # assert shap_values.shape[1] == features.shape[1] + 1, "'shap_values' must have one more column than 'features'!"
638
-
639
- # get both the raw and display feature values
640
- xv = features[:, ind]
641
- xd = display_features[:, ind]
642
- s = shap_values[:, ind]
643
- if type(xd[0]) == str:
644
- name_map = {}
645
- for i in range(len(xv)):
646
- name_map[xd[i]] = xv[i]
647
- xnames = list(name_map.keys())
648
-
649
- # allow a single feature name to be passed alone
650
- if type(feature_names) == str:
651
- feature_names = [feature_names]
652
- name = feature_names[ind]
653
-
654
- # guess what other feature as the stongest interaction with the plotted feature
655
- if interaction_index == "auto":
656
- interaction_index = approx_interactions(ind, shap_values, features)[0]
657
- interaction_index = convert_name(interaction_index)
658
- categorical_interaction = False
659
-
660
- # get both the raw and display color values
661
- if interaction_index is not None:
662
- cv = features[:, interaction_index]
663
- cd = display_features[:, interaction_index]
664
- clow = np.nanpercentile(features[:, interaction_index].astype(np.float), 5)
665
- chigh = np.nanpercentile(features[:, interaction_index].astype(np.float), 95)
666
- if type(cd[0]) == str:
667
- cname_map = {}
668
- for i in range(len(cv)):
669
- cname_map[cd[i]] = cv[i]
670
- cnames = list(cname_map.keys())
671
- categorical_interaction = True
672
- elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:, interaction_index])) < 50:
673
- categorical_interaction = True
674
-
675
- # discritize colors for categorical features
676
- color_norm = None
677
- if categorical_interaction and clow != chigh:
678
- bounds = np.linspace(clow, chigh, chigh - clow + 2)
679
- color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)
680
-
681
- # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
682
- if interaction_index is not None:
683
- pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue,
684
- alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500)
685
- else:
686
- pl.scatter(xv, s, s=dot_size, linewidth=0, color="#1E88E5",
687
- alpha=alpha, rasterized=len(xv) > 500)
688
-
689
- if interaction_index != ind and interaction_index is not None:
690
- # draw the color bar
691
- if type(cd[0]) == str:
692
- tick_positions = [cname_map[n] for n in cnames]
693
- if len(tick_positions) == 2:
694
- tick_positions[0] -= 0.25
695
- tick_positions[1] += 0.25
696
- cb = pl.colorbar(ticks=tick_positions)
697
- cb.set_ticklabels(cnames)
698
- else:
699
- cb = pl.colorbar()
700
-
701
- cb.set_label(feature_names[interaction_index], size=13)
702
- cb.ax.tick_params(labelsize=11)
703
- if categorical_interaction:
704
- cb.ax.tick_params(length=0)
705
- cb.set_alpha(1)
706
- cb.outline.set_visible(False)
707
- bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
708
- cb.ax.set_aspect((bbox.height - 0.7) * 20)
709
-
710
- # make the plot more readable
711
- if interaction_index != ind:
712
- pl.gcf().set_size_inches(7.5, 5)
713
- else:
714
- pl.gcf().set_size_inches(6, 5)
715
- pl.xlabel(name, color=axis_color, fontsize=13)
716
- pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
717
- if title is not None:
718
- pl.title(title, color=axis_color, fontsize=13)
719
- pl.gca().xaxis.set_ticks_position('bottom')
720
- pl.gca().yaxis.set_ticks_position('left')
721
- pl.gca().spines['right'].set_visible(False)
722
- pl.gca().spines['top'].set_visible(False)
723
- pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
724
- for spine in pl.gca().spines.values():
725
- spine.set_edgecolor(axis_color)
726
- if type(xd[0]) == str:
727
- pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)
728
- # if show:
729
- # pl.show()
730
  return mpl_fig, interaction_index
 
1
+ import warnings
2
+ import iml
3
+ import numpy as np
4
+ from iml import Instance, Model
5
+ from iml.datatypes import DenseData
6
+ from iml.explanations import AdditiveExplanation
7
+ from iml.links import IdentityLink
8
+ from scipy.stats import gaussian_kde
9
+ import matplotlib
10
+ try:
11
+ import matplotlib.pyplot as pl
12
+ from matplotlib.colors import LinearSegmentedColormap
13
+ from matplotlib.ticker import MaxNLocator
14
+
15
+ cdict1 = {
16
+ 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
17
+ (1.0, 0.9607843137254902, 0.9607843137254902)),
18
+
19
+ 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
20
+ (1.0, 0.15294117647058825, 0.15294117647058825)),
21
+
22
+ 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
23
+ (1.0, 0.3411764705882353, 0.3411764705882353)),
24
+
25
+ 'alpha': ((0.0, 1, 1),
26
+ (0.5, 0.3, 0.3),
27
+ (1.0, 1, 1))
28
+ } # #1E88E5 -> #ff0052
29
+ red_blue = LinearSegmentedColormap('RedBlue', cdict1)
30
+
31
+ cdict1 = {
32
+ 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
33
+ (1.0, 0.9607843137254902, 0.9607843137254902)),
34
+
35
+ 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
36
+ (1.0, 0.15294117647058825, 0.15294117647058825)),
37
+
38
+ 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
39
+ (1.0, 0.3411764705882353, 0.3411764705882353)),
40
+
41
+ 'alpha': ((0.0, 1, 1),
42
+ (0.5, 1, 1),
43
+ (1.0, 1, 1))
44
+ } # #1E88E5 -> #ff0052
45
+ red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
46
+ except ImportError:
47
+ pass
48
+
49
+ labels = {
50
+ 'MAIN_EFFECT': "SHAP main effect value for\n%s",
51
+ 'INTERACTION_VALUE': "SHAP interaction value",
52
+ 'INTERACTION_EFFECT': "SHAP interaction value for\n%s and %s",
53
+ 'VALUE': "SHAP value (impact on model output)",
54
+ 'VALUE_FOR': "SHAP value for\n%s",
55
+ 'PLOT_FOR': "SHAP plot for %s",
56
+ 'FEATURE': "Feature %s",
57
+ 'FEATURE_VALUE': "Feature value",
58
+ 'FEATURE_VALUE_LOW': "Low",
59
+ 'FEATURE_VALUE_HIGH': "High",
60
+ 'JOINT_VALUE': "Joint SHAP value"
61
+ }
62
+
63
def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
                      color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
                      color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
    """Create a SHAP summary plot, colored by feature values when they are provided.

    A new matplotlib figure is created on every call and returned; the figure
    is never shown by this function.

    Parameters
    ----------
    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features). The last column is
        excluded from ordering and plotting (the code slices with ``[:-1]``).
        A 3-D array is treated as SHAP interaction values and drawn as a row
        of panels, one per top feature.

    features : numpy.array or pandas.DataFrame or list
        Matrix of feature values (# samples x # features) or a feature_names list as shorthand.
        Required (as a matrix) for interaction plots and for ``plot_type="layered_violin"``.

    feature_names : list
        Names of the features (length # features)

    max_display : int
        How many top features to include in the plot (default is 20, or 7 for interaction plots)

    plot_type : "dot" (default), "violin" or "layered_violin"
        What type of summary plot to produce

    color : str or None
        Fallback color (or colormap name for "layered_violin"). Defaults to
        "coolwarm" for "layered_violin" and "#ff0052" otherwise.

    axis_color : str
        Color applied to the tick marks and tick labels.

    title, show
        Accepted for interface compatibility; not used by this function.

    alpha : float
        Opacity of the plotted markers / violin bodies.

    sort : bool
        Order features by the sum of absolute SHAP values when True.

    color_bar : bool
        Draw the Low/High feature-value color bar when features are available.

    auto_size_plot : bool
        Resize the figure to 8 inches wide and a height proportional to the
        number of displayed features.

    layered_violin_max_num_bins : int
        Maximum number of feature-value bins per feature for "layered_violin".

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn into.
    """

    assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."

    # default color:
    if color is None:
        color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"

    # convert from a DataFrame or other types (string check avoids importing pandas here)
    if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    elif isinstance(features, list):
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None

    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]

    is_interaction = len(shap_values.shape) == 3

    # Resolve max_display BEFORE sizing the figure: the figure size is derived
    # from it, and ``1.5 * None`` would raise a TypeError.
    if max_display is None:
        max_display = 7 if is_interaction else 20
    if is_interaction:
        # never ask for more panels than there are features to rank
        max_display = min(len(feature_names), max_display)

    mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))

    if not is_interaction:
        _draw_summary_axes(
            shap_values, features, feature_names, max_display, plot_type, color,
            axis_color, alpha, sort, color_bar, auto_size_plot, layered_violin_max_num_bins
        )
        return mpl_fig

    # ---- plotting SHAP interaction values: one dot-panel per top feature ----
    sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))
    # list / Index / array inputs all support fancy indexing once converted
    names = np.asarray(feature_names)
    # every panel keeps the last column in place after the re-ordered feature columns
    columns = np.hstack((sort_inds, len(sort_inds)))

    # fixed, symmetric x-limits shared by every panel
    slow = -0.2
    shigh = 0.2
    title_length_limit = 11

    for i in range(max_display):
        ind = sort_inds[i]
        pl.subplot(1, max_display, i + 1)
        # advanced indexing returns a copy, so the in-place scaling below does
        # not modify the caller's array
        proj_shap_values = shap_values[:, ind, columns]
        proj_shap_values *= 2
        proj_shap_values[:, i] /= 2  # because only off diag effects are split in half
        # only the first panel carries the feature names on its y axis
        panel_names = names[sort_inds] if i == 0 else ["" for _ in range(features.shape[1])]
        _draw_summary_axes(
            proj_shap_values, features[:, sort_inds], panel_names, max_display, "dot", "#ff0052",
            axis_color, alpha, False, False, False, layered_violin_max_num_bins
        )
        pl.xlim((slow, shigh))
        pl.xlabel("")
        if i != 0 and i == max_display // 2:
            pl.xlabel(labels['INTERACTION_VALUE'])
        # NOTE(review): shorten_text is not defined in the visible part of this
        # file; it is assumed to exist at module level - confirm.
        pl.title(shorten_text(names[ind], title_length_limit))
    pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
    pl.subplots_adjust(hspace=0, wspace=0.1)
    return mpl_fig


def _draw_summary_axes(shap_values, features, feature_names, max_display, plot_type, color,
                       axis_color, alpha, sort, color_bar, auto_size_plot, layered_violin_max_num_bins):
    """Draw one 2-D SHAP summary plot into the current matplotlib axes."""
    import warnings

    if sort:
        # order features by the sum of their effect magnitudes
        feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
        feature_order = feature_order[-min(max_display, len(feature_order)):]
    else:
        feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)

    row_height = 0.4
    if auto_size_plot:
        pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
    pl.axvline(x=0, color="#999999", zorder=-1)

    if plot_type == "dot":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
            shaps = shap_values[:, i]
            values = None if features is None else features[:, i]
            # shuffle so that the drawing order does not depend on the row order
            inds = np.arange(len(shaps))
            np.random.shuffle(inds)
            if values is not None:
                values = values[inds]
            shaps = shaps[inds]
            colored_feature = True
            try:
                values = np.array(values, dtype=np.float64)  # make sure this can be numeric
            except (TypeError, ValueError):
                colored_feature = False
            N = len(shaps)
            nbins = 100
            quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
            # tiny noise breaks ties randomly within a bin
            inds = np.argsort(quant + np.random.randn(N) * 1e-6)
            layer = 0
            last_bin = -1
            ys = np.zeros(N)
            # stack points that share a bin alternately above and below the row line
            for ind in inds:
                if quant[ind] != last_bin:
                    layer = 0
                ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
                layer += 1
                last_bin = quant[ind]
            ys *= 0.9 * (row_height / np.max(ys + 1))

            if features is not None and colored_feature:
                # trim the color range, but prevent the color range from collapsing
                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)

                assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
                nan_mask = np.isnan(values)
                present = np.invert(nan_mask)
                # missing feature values are drawn in plain grey (no color mapping)
                pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777",
                           s=16, alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
                pl.scatter(shaps[present], pos + ys[present],
                           cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
                           c=values[present], alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
            else:
                pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
                           color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)

    elif plot_type == "violin":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

        if features is not None:
            global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
            global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
            for pos, i in enumerate(feature_order):
                shaps = shap_values[:, i]
                shap_min, shap_max = np.min(shaps), np.max(shaps)
                rng = shap_max - shap_min
                xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100)
                if np.std(shaps) < (global_high - global_low) / 100:
                    # jitter near-constant columns so the density estimate has some spread
                    ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
                else:
                    ds = gaussian_kde(shaps)(xs)
                ds /= np.max(ds) * 3

                # running mean of the feature value over (at most) the 20 most
                # recent points in SHAP order, evaluated at each x position
                values = features[:, i]
                smooth_values = np.zeros(len(xs) - 1)
                sort_inds = np.argsort(shaps)
                trailing_pos = 0
                leading_pos = 0
                running_sum = 0
                back_fill = 0
                for j in range(len(xs) - 1):

                    while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
                        running_sum += values[sort_inds[leading_pos]]
                        leading_pos += 1
                        if leading_pos - trailing_pos > 20:
                            running_sum -= values[sort_inds[trailing_pos]]
                            trailing_pos += 1
                    if leading_pos - trailing_pos > 0:
                        smooth_values[j] = running_sum / (leading_pos - trailing_pos)
                        for k in range(back_fill):
                            smooth_values[j - k - 1] = smooth_values[j]
                    else:
                        back_fill += 1

                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)
                pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
                           c=values, alpha=alpha, linewidth=0, zorder=1)
                smooth_values -= vmin
                if vmax - vmin > 0:
                    smooth_values /= vmax - vmin
                for seg in range(len(xs) - 1):
                    if ds[seg] > 0.05 or ds[seg + 1] > 0.05:
                        pl.fill_between([xs[seg], xs[seg + 1]], [pos + ds[seg], pos + ds[seg + 1]],
                                        [pos - ds[seg], pos - ds[seg + 1]], color=red_blue_solid(smooth_values[seg]),
                                        zorder=2)

        else:
            parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
                                  widths=0.7,
                                  showmeans=False, showextrema=False, showmedians=False)

            for pc in parts['bodies']:
                pc.set_facecolor(color)
                pc.set_edgecolor('none')
                pc.set_alpha(alpha)

    elif plot_type == "layered_violin":  # courtesy of @kodonnell
        num_x_points = 200
        bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype(
            'int')  # the indices of the feature data corresponding to each bin
        shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
        x_points = np.linspace(shap_min, shap_max, num_x_points)

        # loop through each feature and plot:
        for pos, ind in enumerate(feature_order):
            # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value,
            # otherwise use bins/percentiles. To keep simpler code, in the case of uniques, we just
            # adjust the bins to align with the unique counts.
            feature = features[:, ind]
            unique, counts = np.unique(feature, return_counts=True)
            if unique.shape[0] <= layered_violin_max_num_bins:
                order = np.argsort(unique)
                thesebins = np.cumsum(counts[order])
                thesebins = np.insert(thesebins, 0, 0)
            else:
                thesebins = bins
            nbins = thesebins.shape[0] - 1
            # order the feature data so we can apply percentiling
            order = np.argsort(feature)
            # calculate kdes:
            ys = np.zeros((nbins, num_x_points))
            for i in range(nbins):
                # get shap values in this bin:
                shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
                # a kde cannot be estimated from a single element
                if shaps.shape[0] == 1:
                    warnings.warn(
                        "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
                        % (i, feature_names[ind]))
                    # to ignore it, just set it to the previous y-values (so the area between them
                    # will be zero). Note ys is already 0, so there's nothing to do if i == 0
                    if i > 0:
                        ys[i, :] = ys[i - 1, :]
                    continue
                # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
                ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
                # scale it up so that the 'size' of each y represents the size of the bin. For continuous
                # data this will do nothing, but when we've gone with the unique option, this will
                # matter - e.g. if 99% are male and 1% female, we want the 1% to appear a lot smaller.
                size = thesebins[i + 1] - thesebins[i]
                bin_size_if_even = features.shape[0] / nbins
                relative_bin_size = size / bin_size_if_even
                ys[i, :] *= relative_bin_size
            # now plot 'em. We don't plot the individual strips, as this can leave whitespace between
            # them. Instead, we plot the full kde, then remove outer strip and plot over it, etc., to
            # ensure no whitespace
            ys = np.cumsum(ys, axis=0)
            width = 0.8
            scale = ys.max() * 2 / width  # 2 is here as we plot both sides of x axis
            for i in range(nbins - 1, -1, -1):
                y = ys[i, :] / scale
                # if color is a cmap, use it, otherwise use a color;
                # max(..., 1) avoids a division by zero when there is a single bin
                c = pl.get_cmap(color)(i / max(nbins - 1, 1)) if color in pl.cm.datad else color
                pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
        pl.xlim(shap_min, shap_max)

    # draw the color bar
    if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
        import matplotlib.cm as cm
        m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
        m.set_array([0, 1])
        # the mappable is not attached to any axes, so the target axes must be named explicitly
        cb = pl.colorbar(m, ax=pl.gca(), ticks=[0, 1], aspect=1000)
        cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
        cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.9) * 20)

    axes = pl.gca()
    axes.xaxis.set_ticks_position('bottom')
    axes.yaxis.set_ticks_position('none')
    axes.spines['right'].set_visible(False)
    axes.spines['top'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.tick_params(color=axis_color, labelcolor=axis_color)
    pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
    axes.tick_params('y', length=20, width=0.5, which='major')
    axes.tick_params('x', labelsize=11)
    pl.ylim(-1, len(feature_order))
    pl.xlabel(labels['VALUE'], fontsize=13)
    pl.tight_layout()
403
+
404
+
405
+
406
+
407
+
408
+
409
def approx_interactions(index, shap_values, X):
    """Order other features by how much interaction they seem to have with the feature at the given index.

    This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction
    index values for SHAP see the interaction_contribs option implemented in XGBoost.

    Parameters
    ----------
    index : int
        Column of the reference feature.

    shap_values : numpy.array
        Matrix of SHAP values, indexed as ``shap_values[rows, index]``.

    X : numpy.array
        Matrix of feature values (# samples x # features).

    Returns
    -------
    numpy.array
        Feature indices sorted by decreasing interaction score. The reference
        feature itself and all-zero columns get a score of 0.
    """

    # cap the work at 10000 randomly chosen rows
    if X.shape[0] > 10000:
        a = np.arange(X.shape[0])
        np.random.shuffle(a)
        inds = a[:10000]
    else:
        inds = np.arange(X.shape[0])

    x = X[inds, index]
    srt = np.argsort(x)
    shap_ref = shap_values[inds, index]
    shap_ref = shap_ref[srt]
    # bin width: a tenth of the rows, clamped to the range [1, 50]
    inc = max(min(int(len(x) / 10.0), 50), 1)
    interactions = []
    for i in range(X.shape[1]):
        # np.float64 instead of the np.float alias, which NumPy >= 1.24 no longer provides
        val_other = X[inds, i][srt].astype(np.float64)
        v = 0.0
        if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
            for j in range(0, len(x), inc):
                other_bin = val_other[j:j + inc]
                shap_bin = shap_ref[j:j + inc]
                # skip constant bins: their correlation is undefined
                if np.std(other_bin) > 0 and np.std(shap_bin) > 0:
                    v += abs(np.corrcoef(shap_bin, other_bin)[0, 1])
        interactions.append(v)

    return np.argsort(-np.abs(interactions))
439
+
440
+
441
+
442
+
443
+
444
+
445
+
446
def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
                         interaction_index="auto", color="#1E88E5", axis_color="#333333",
                         dot_size=16, alpha=1, title=None, show=True):
    """
    Create a SHAP dependence plot, colored by an interaction feature.

    The plot is drawn into the current matplotlib figure; nothing is shown.

    Parameters
    ----------
    ind : int, str, or pair of those
        Index (or name) of the feature to plot. A pair is used together with
        3-D SHAP interaction values to plot one interaction effect.

    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features)

    features : numpy.array or pandas.DataFrame
        Matrix of feature values (# samples x # features)

    feature_names : list
        Names of the features (length # features)

    display_features : numpy.array or pandas.DataFrame
        Matrix of feature values for visual display (such as strings instead of coded values)

    interaction_index : "auto", None, int, or str
        The index of the feature used to color the plot. "auto" picks the
        top-ranked feature from approx_interactions; None disables coloring.

    color : str
        Marker color used when no interaction feature colors the plot.

    axis_color : str
        Color of the axis labels, ticks, spines and title.

    dot_size : float
        Marker size passed to the scatter plot.

    alpha : float
        Marker opacity.

    title : str or None
        Optional plot title.

    show : bool
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(figure, interaction_index)`` - the matplotlib figure drawn into and
        the resolved index of the coloring feature (None when uncolored).

    Raises
    ------
    ValueError
        If ``ind`` is a feature name that is not present in ``feature_names``.
    """
    # matplotlib is already a dependency of this module (pyplot is used as ``pl``);
    # the local import makes BoundaryNorm available without relying on a
    # module-level ``import matplotlib``.
    from matplotlib.colors import BoundaryNorm

    # convert from DataFrames if we got any
    if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = display_features.columns
        display_features = display_features.values

    # allow vectors to be passed (np.reshape takes the new shape as ONE tuple)
    if len(shap_values.shape) == 1:
        shap_values = np.reshape(shap_values, (len(shap_values), 1))
    if len(features.shape) == 1:
        features = np.reshape(features, (len(features), 1))
    if display_features is None:
        display_features = features
    elif len(display_features.shape) == 1:
        display_features = np.reshape(display_features, (len(display_features), 1))

    if feature_names is None:
        # one name per SHAP column, matching the column-count assertion below
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1])]

    def convert_name(candidate):
        """Map a feature name to its column index; non-string values pass through."""
        if not isinstance(candidate, str):
            return candidate
        # np.asarray makes the element-wise comparison work for plain lists too
        matches = np.where(np.asarray(feature_names) == candidate)[0]
        if len(matches) == 0:
            print("Could not find feature named: " + candidate)
            return None
        return matches[0]

    requested = ind
    ind = convert_name(ind)
    if ind is None:
        raise ValueError("Could not find feature named: %s" % (requested,))

    mpl_fig = pl.gcf()

    # plotting SHAP interaction values (np.ndim guards against len() of a scalar index)
    if len(shap_values.shape) == 3 and np.ndim(ind) == 1 and len(ind) == 2:
        ind1 = convert_name(ind[0])
        ind2 = convert_name(ind[1])
        if ind1 == ind2:
            proj_shap_values = shap_values[:, ind2, :]
        else:
            proj_shap_values = shap_values[:, ind2, :] * 2  # off-diag values are split in half

        # TODO: remove recursion; generally the functions should be shorter for more maintainable code
        result = shap_dependence_plot(
            ind1, proj_shap_values, features, feature_names=feature_names,
            interaction_index=ind2, display_features=display_features,
            color=color, axis_color=axis_color, dot_size=dot_size, alpha=alpha,
            title=title, show=False
        )
        # replace the generic y label with one naming the effect being shown
        if ind1 == ind2:
            pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1])
        else:
            pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2]))
        return result

    assert shap_values.shape[0] == features.shape[0], \
        "'shap_values' and 'features' values must have the same number of rows!"
    assert shap_values.shape[1] == features.shape[1], \
        "'shap_values' must have the same number of columns as 'features'!"

    # get both the raw and display feature values
    xv = features[:, ind]
    xd = display_features[:, ind]
    s = shap_values[:, ind]
    x_is_text = isinstance(xd[0], str)
    if x_is_text:
        # display string -> raw value, used for the x tick labels
        name_map = dict(zip(xd, xv))
        xnames = list(name_map.keys())

    # allow a single feature name to be passed alone
    if isinstance(feature_names, str):
        feature_names = [feature_names]
    name = feature_names[ind]

    # guess what other feature has the strongest interaction with the plotted feature
    if interaction_index == "auto":
        interaction_index = approx_interactions(ind, shap_values, features)[0]
    interaction_index = convert_name(interaction_index)
    categorical_interaction = False

    # get both the raw and display color values
    if interaction_index is not None:
        cv = features[:, interaction_index]
        cd = display_features[:, interaction_index]
        # np.float64 instead of the np.float alias, which NumPy >= 1.24 no longer provides
        cv_numeric = cv.astype(np.float64)
        clow = np.nanpercentile(cv_numeric, 5)
        chigh = np.nanpercentile(cv_numeric, 95)
        if isinstance(cd[0], str):
            cname_map = dict(zip(cd, cv))
            cnames = list(cname_map.keys())
            categorical_interaction = True
        elif clow % 1 == 0 and chigh % 1 == 0 and len(set(cv)) < 50:
            categorical_interaction = True

    # discretize colors for categorical features
    color_norm = None
    if categorical_interaction and clow != chigh:
        # linspace needs an integer number of points
        bounds = np.linspace(clow, chigh, int(chigh - clow + 2))
        color_norm = BoundaryNorm(bounds, red_blue.N)

    # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
    if interaction_index is not None:
        scatter_kwargs = dict(s=dot_size, linewidth=0, c=cv_numeric, cmap=red_blue,
                              alpha=alpha, rasterized=len(xv) > 500)
        # matplotlib rejects norm together with vmin/vmax; the norm's boundaries
        # already span clow..chigh
        if color_norm is None:
            scatter_kwargs.update(vmin=clow, vmax=chigh)
        else:
            scatter_kwargs.update(norm=color_norm)
        pl.scatter(xv, s, **scatter_kwargs)
    else:
        pl.scatter(xv, s, s=dot_size, linewidth=0, color=color,
                   alpha=alpha, rasterized=len(xv) > 500)

    if interaction_index != ind and interaction_index is not None:
        # draw the color bar
        if isinstance(cd[0], str):
            tick_positions = [cname_map[n] for n in cnames]
            if len(tick_positions) == 2:
                # nudge the two ticks apart from their raw positions
                tick_positions[0] -= 0.25
                tick_positions[1] += 0.25
            cb = pl.colorbar(ticks=tick_positions)
            cb.set_ticklabels(cnames)
        else:
            cb = pl.colorbar()

        cb.set_label(feature_names[interaction_index], size=13)
        cb.ax.tick_params(labelsize=11)
        if categorical_interaction:
            cb.ax.tick_params(length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.7) * 20)

    # make the plot more readable
    if interaction_index != ind:
        pl.gcf().set_size_inches(7.5, 5)
    else:
        pl.gcf().set_size_inches(6, 5)
    pl.xlabel(name, color=axis_color, fontsize=13)
    pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
    if title is not None:
        pl.title(title, color=axis_color, fontsize=13)
    axes = pl.gca()
    axes.xaxis.set_ticks_position('bottom')
    axes.yaxis.set_ticks_position('left')
    axes.spines['right'].set_visible(False)
    axes.spines['top'].set_visible(False)
    axes.tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
    for spine in axes.spines.values():
        spine.set_edgecolor(axis_color)
    if x_is_text:
        pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)

    return mpl_fig, interaction_index