hlnicholls commited on
Commit
f8a37b1
1 Parent(s): 8a6cf88

feat: supervised clustering

Browse files
__pycache__/dynamic_shap_plots.cpython-38.pyc CHANGED
Binary files a/__pycache__/dynamic_shap_plots.cpython-38.pyc and b/__pycache__/dynamic_shap_plots.cpython-38.pyc differ
 
__pycache__/shap_plots.cpython-38.pyc CHANGED
Binary files a/__pycache__/shap_plots.cpython-38.pyc and b/__pycache__/shap_plots.cpython-38.pyc differ
 
app.py CHANGED
@@ -6,13 +6,15 @@ import pickle
6
  import sklearn
7
  import catboost
8
  import shap
9
- from shap_plots import shap_summary_plot
10
- from dynamic_shap_plots import matplotlib_to_plotly, summary_plot_plotly_fig
11
  import plotly.tools as tls
12
  from dash import dcc
 
 
 
 
13
  import matplotlib.pyplot as plt
14
  import plotly.graph_objs as go
15
-
16
  try:
17
  import matplotlib.pyplot as pl
18
  from matplotlib.colors import LinearSegmentedColormap
@@ -139,6 +141,10 @@ if tab == "Gene Prioritisation":
139
  else:
140
  pass
141
 
 
 
 
 
142
  st.markdown("""
143
  ### Total Gene Prioritisation Results for All Genes:
144
  """)
@@ -150,7 +156,6 @@ if tab == "Gene Prioritisation":
150
  csv = convert_df(df_total_output)
151
  st.download_button("Download Gene Prioritisation", csv, "all_genes_bp_prioritisation.csv", "text/csv", key='download-all-csv')
152
 
153
- # Page 2: Interactive SHAP Plot
154
 
155
  elif tab == "Interactive SHAP Plot":
156
  st.title("Interactive SHAP Plot")
@@ -170,15 +175,124 @@ elif tab == "Interactive SHAP Plot":
170
  df_shap = df.drop(columns=probability_columns + ['Gene'])
171
  shap_values = explainer.shap_values(df_shap)
172
 
173
- # Use shap's summary_plot function for interactivity
174
- # summary_plot = shap.summary_plot(shap_values[0], df_shap, plot_type='interactive', max_display=10)
175
- summary_plot = summary_plot_plotly_fig(df_shap, shap_values[0], max_display=10)
176
- st.pyplot(summary_plot)
177
- st.caption("SHAP Summary Plot of All Input Genes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
 
 
179
 
180
- # Page 3: Supervised SHAP Clustering
181
  elif tab == "Supervised SHAP Clustering":
182
  st.title("Supervised SHAP Clustering")
183
- # Add your code here to implement supervised SHAP clustering
184
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import sklearn
7
  import catboost
8
  import shap
 
 
9
  import plotly.tools as tls
10
  from dash import dcc
11
+ from sklearn.cluster import KMeans
12
+ from sklearn.decomposition import PCA
13
+ from sklearn.metrics import silhouette_score
14
+ import plotly.express as px
15
  import matplotlib.pyplot as plt
16
  import plotly.graph_objs as go
17
+ import plotly.graph_objects as go
18
  try:
19
  import matplotlib.pyplot as pl
20
  from matplotlib.colors import LinearSegmentedColormap
 
141
  else:
142
  pass
143
 
144
+ url = f"https://astrazeneca-cgr-publications.github.io/DrugnomeAI/geneview.html?gene={input_gene}"
145
+ markdown_link = f"[{input_gene} druggability in DrugnomeAI]({url})"
146
+ st.markdown(markdown_link, unsafe_allow_html=True)
147
+
148
  st.markdown("""
149
  ### Total Gene Prioritisation Results for All Genes:
150
  """)
 
156
  csv = convert_df(df_total_output)
157
  st.download_button("Download Gene Prioritisation", csv, "all_genes_bp_prioritisation.csv", "text/csv", key='download-all-csv')
158
 
 
159
 
160
  elif tab == "Interactive SHAP Plot":
161
  st.title("Interactive SHAP Plot")
 
175
  df_shap = df.drop(columns=probability_columns + ['Gene'])
176
  shap_values = explainer.shap_values(df_shap)
177
 
178
+ shap_values_first_class = shap_values[0]
179
+ feature_importance = np.abs(shap_values_first_class).mean(axis=0)
180
+ top_features_indices = np.argsort(feature_importance)[-20:]
181
+ features_top = df_shap.columns[top_features_indices][::-1]
182
+ shap_values_top = shap_values_first_class[:, top_features_indices][..., ::-1]
183
+
184
+ # Prepare data for a single trace
185
+ x_values = []
186
+ y_values = []
187
+ hover_texts = []
188
+ for i, feature_name in enumerate(features_top):
189
+ for gene, value in zip(df['Gene'], shap_values_top[:, i]):
190
+ x_values.append(value)
191
+ y_values.append(feature_name)
192
+ hover_texts.append(f'{gene}: {value:.3f}')
193
+
194
+ # Create a single trace for the plot
195
+ fig = go.Figure(data=go.Scatter(
196
+ x=x_values,
197
+ y=y_values,
198
+ mode='markers',
199
+ marker=dict(
200
+ color=x_values, # Set color to SHAP values
201
+ colorbar=dict(title="SHAP Value"),
202
+ colorscale=[(0, "blue"), (1, "red")], # Blue to Red color scale
203
+ ),
204
+ text=hover_texts, # Set hover text
205
+ hoverinfo="text+x" # Display hover text and x-value (SHAP value)
206
+ ))
207
+
208
+ fig.update_layout(
209
+ title="SHAP Summary Plot - Top 20 Features",
210
+ xaxis_title="SHAP Value",
211
+ yaxis=dict(autorange="reversed", title="Feature"),
212
+ showlegend=False,
213
+ )
214
 
215
+ st.plotly_chart(fig, use_container_width=True)
216
+ st.caption("SHAP Summary Plot of All Input Genes - Top 20 Features")
217
 
 
218
  elif tab == "Supervised SHAP Clustering":
219
  st.title("Supervised SHAP Clustering")
220
+ training_genes = pd.read_csv("training_cleaned.csv")
221
+ training_genes = training_genes[training_genes['BPlabel_encoded'] == 0]
222
+ training_genes.set_index('Gene', inplace=True)
223
+
224
+ # Calculate SHAP values for the full dataset
225
+ shap_values_full = explainer.shap_values(annotations)
226
+ shap_values_full_array = np.array(shap_values_full[0])
227
+
228
+ # Apply PCA to reduce dimensionality for visualization
229
+ pca = PCA(n_components=2)
230
+ shap_values_pca = pca.fit_transform(shap_values_full_array)
231
+
232
+ # Apply clustering on the PCA-reduced SHAP values
233
+ kmeans = KMeans(n_clusters=3, random_state=0).fit(shap_values_pca)
234
+
235
+ # Get cluster labels for each point in the dataset
236
+ labels = kmeans.labels_
237
+
238
+ # Prepare a DataFrame for visualization
239
+ df_for_plot = pd.DataFrame({
240
+ 'PCA_1': shap_values_pca[:, 0],
241
+ 'PCA_2': shap_values_pca[:, 1],
242
+ 'Cluster': labels.astype(str),
243
+ 'Gene': annotations.index,
244
+ 'Type': 'Clustered Gene'
245
+ })
246
+
247
+ # Add a new column for marking the special groups
248
+ df_for_plot['SpecialGroup'] = 'None'
249
+ df_for_plot.loc[df_for_plot['Gene'].isin(training_genes.index), 'SpecialGroup'] = 'Most Likely Training Gene'
250
+ if gene_list:
251
+ df_for_plot.loc[df_for_plot['Gene'].isin(gene_list), 'SpecialGroup'] = 'User Input Gene'
252
+
253
+ # Initialize an empty figure
254
+ fig = go.Figure()
255
+
256
+ # Plot clustered genes based on PCA components
257
+ for cluster in df_for_plot['Cluster'].unique():
258
+ filtered_df = df_for_plot[(df_for_plot['Cluster'] == cluster) & (df_for_plot['SpecialGroup'] == 'None')]
259
+ fig.add_trace(go.Scatter(
260
+ x=filtered_df['PCA_1'], y=filtered_df['PCA_2'],
261
+ mode='markers',
262
+ name=f'Cluster {cluster}',
263
+ text=filtered_df['Gene'],
264
+ hoverinfo="text+x+y",
265
+ ))
266
+
267
+ # Overlay "Most Likely Training Gene"
268
+ filtered_df = df_for_plot[df_for_plot['SpecialGroup'] == 'Most Likely Training Gene']
269
+ fig.add_trace(go.Scatter(
270
+ x=filtered_df['PCA_1'], y=filtered_df['PCA_2'],
271
+ mode='markers',
272
+ name='Most Likely Training Gene',
273
+ text=filtered_df['Gene'],
274
+ marker=dict(color='rgba(255, 0, 0, .9)'),
275
+ hoverinfo="text+x+y",
276
+ ))
277
+
278
+ # Overlay "User Input Gene"
279
+ filtered_df = df_for_plot[df_for_plot['SpecialGroup'] == 'User Input Gene']
280
+ fig.add_trace(go.Scatter(
281
+ x=filtered_df['PCA_1'], y=filtered_df['PCA_2'],
282
+ mode='markers',
283
+ name='User Input Gene',
284
+ text=filtered_df['Gene'],
285
+ marker=dict(color='rgba(0, 255, 0, .9)'),
286
+ hoverinfo="text+x+y",
287
+ ))
288
+
289
+ # Customize layout
290
+ fig.update_layout(
291
+ title='Supervised SHAP Clustering with PCA',
292
+ xaxis_title='First Principal Component',
293
+ yaxis_title='Second Principal Component',
294
+ showlegend=True,
295
+ legend_title_text='Gene Category',
296
+ )
297
+
298
+ st.plotly_chart(fig, use_container_width=True)
dynamic_shap_plots.py DELETED
@@ -1,346 +0,0 @@
1
- from shap_plots import shap_summary_plot, shap_dependence_plot
2
- import plotly.tools as tls
3
- import dash_core_components as dcc
4
- import pandas as pd
5
- from sklearn.model_selection import train_test_split
6
- import numpy as np
7
- import xgboost
8
- import shap
9
- import matplotlib
10
- import plotly.graph_objs as go
11
- try:
12
- import matplotlib.pyplot as pl
13
- from matplotlib.colors import LinearSegmentedColormap
14
- from matplotlib.ticker import MaxNLocator
15
- except ImportError:
16
- pass
17
- from sklearn import preprocessing
18
-
19
- cdict1 = {
20
- 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
21
- (1.0, 0.9607843137254902, 0.9607843137254902)),
22
-
23
- 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
24
- (1.0, 0.15294117647058825, 0.15294117647058825)),
25
-
26
- 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
27
- (1.0, 0.3411764705882353, 0.3411764705882353)),
28
-
29
- 'alpha': ((0.0, 1, 1),
30
- (0.5, 1, 1),
31
- (1.0, 1, 1))
32
- } # #1E88E5 -> #ff0052
33
- red_blue = LinearSegmentedColormap('RedBlue', cdict1)
34
-
35
- def matplotlib_to_plotly(cmap, pl_entries):
36
- h = 1.0/(pl_entries-1)
37
- pl_colorscale = []
38
-
39
- for k in range(pl_entries):
40
- C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255))
41
- pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])
42
-
43
- return pl_colorscale
44
-
45
- red_blue = matplotlib_to_plotly(red_blue, 255)
46
-
47
- def summary_plot_plotly_fig(dataset, shap_values, target='target column', max_display = 20):
48
- feature_names=dataset.columns
49
- mpl_fig = shap_summary_plot(shap_values, dataset, feature_names=feature_names, max_display=20)
50
-
51
- plotly_fig = tls.mpl_to_plotly(mpl_fig)
52
-
53
- plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
54
-
55
- feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
56
- feature_order = feature_order[-min(max_display, len(feature_order)):]
57
- text = [feature_names[i] for i in feature_order]
58
- text = iter(text)
59
-
60
- for i in range(1, len(plotly_fig['data']), 2):
61
- t = text.__next__()
62
- plotly_fig['data'][i]['name'] = ''
63
- plotly_fig['data'][i]['text'] = t
64
- plotly_fig['data'][i]['hoverinfo'] = 'text'
65
-
66
- colorbar_trace = go.Scatter(x=[None],
67
- y=[None],
68
- mode='markers',
69
- marker=dict(
70
- colorscale=red_blue,
71
- showscale=True,
72
- cmin=-5,
73
- cmax=5,
74
- colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
75
- ),
76
- hoverinfo='none'
77
- )
78
-
79
- plotly_fig['layout']['showlegend'] = False
80
- plotly_fig['layout']['hovermode'] = 'closest'
81
- plotly_fig['layout']['height']=600
82
- plotly_fig['layout']['width']=500
83
-
84
- plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
85
- plotly_fig['layout']['yaxis'].update(dict(visible=False))
86
- plotly_fig.add_trace(colorbar_trace)
87
- plotly_fig.layout.update(
88
- annotations=[dict(
89
- x=1.18,
90
- align="right",
91
- valign="top",
92
- text='Feature value',
93
- showarrow=False,
94
- xref="paper",
95
- yref="paper",
96
- xanchor="right",
97
- yanchor="middle",
98
- textangle=-90,
99
- font=dict(family='Calibri', size=14)
100
- )
101
- ],
102
- margin=dict(t=20)
103
- )
104
- return plotly_fig
105
-
106
- def train_model_and_return_shap_values(X, y, target):
107
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
108
-
109
- X_train.fillna((-999), inplace=True)
110
- X_test.fillna((-999), inplace=True)
111
-
112
- # Some of values are float or integer and some object. This is why we need to cast them:
113
- for f in X_train.columns:
114
- if X_train[f].dtype=='object':
115
- lbl = preprocessing.LabelEncoder()
116
- lbl.fit(list(X_train[f].values))
117
- X_train[f] = lbl.transform(list(X_train[f].values))
118
-
119
- for f in X_test.columns:
120
- if X_test[f].dtype=='object':
121
- lbl = preprocessing.LabelEncoder()
122
- lbl.fit(list(X_test[f].values))
123
- X_test[f] = lbl.transform(list(X_test[f].values))
124
-
125
- X_train=np.array(X_train)
126
- X_test=np.array(X_test)
127
- X_train = X_train.astype(float)
128
- X_test = X_test.astype(float)
129
-
130
- d_train = xgboost.DMatrix(X_train, label=y_train, feature_names=list(X))
131
- d_test = xgboost.DMatrix(X_test, label=y_test, feature_names=list(X))
132
-
133
- # train the model
134
- params = {
135
- "eta": 0.01,
136
- "subsample": 0.5,
137
- "base_score": np.mean(y_train),
138
- "silent": 1
139
- }
140
-
141
- model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=None, early_stopping_rounds=50)
142
- feature_names = model.feature_names
143
- shap_values = shap.TreeExplainer(model).shap_values(pd.DataFrame(X_train, columns=X.columns))
144
- return model, shap_values, feature_names
145
-
146
- def dependence_plot_to_plotly_fig(dataset, target='target column', max_display=10):
147
- data = pd.read_csv(dataset, encoding="ISO-8859-1")
148
- X = data.drop(['target column'], axis=1)
149
- y = data[target]
150
- y = y/max(y)
151
-
152
- xgb_full = xgboost.DMatrix(X, label=y)
153
-
154
- # create a train/test split
155
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
156
- xgb_train = xgboost.DMatrix(X_train, label=y_train)
157
- xgb_test = xgboost.DMatrix(X_test, label=y_test)
158
-
159
- # use validation set to choose # of trees
160
- params = {
161
- # "eta": 0.002,
162
- # "max_depth": 3,
163
- # "subsample": 0.5,
164
- "silent": 1
165
- }
166
- model_train = xgboost.train(params, xgb_train, 3000, evals = [(xgb_test, "test")], verbose_eval=None)
167
-
168
- # train final model on the full data set
169
- params = {
170
- # "eta": 0.002,
171
- # "max_depth": 3,
172
- # "subsample": 0.5,
173
- "silent": 1
174
- }
175
- model = xgboost.train(params, xgb_full, 1500, evals = [(xgb_full, "test")], verbose_eval=None)
176
- features = model.feature_names
177
- shap_values = shap.TreeExplainer(model).shap_values(X)
178
-
179
- feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
180
- feature_order = feature_order[-min(max_display, len(feature_order)):]
181
- features = [features[i] for i in feature_order[::-1]]
182
-
183
- lis = []
184
- for i in features:
185
- mpl_fig, interaction_index = shap_dependence_plot(i, shap_values, X)
186
- plotly_fig = tls.mpl_to_plotly(mpl_fig)
187
-
188
- # The x-tick labels start by default from 0, which is not necessarily the min value of the feature.
189
- # So, we need to increment the x-tick labels by 1. But while doing so, the y-axis gets shifted.
190
- # To prevent that, we need to manually control the x-axis range from r_min to r_max
191
- new_x = []
192
- for j in plotly_fig['data'][0]['x']:
193
- new_x.append(j)
194
-
195
- r_min = min(plotly_fig['data'][0]['x'])
196
- r_max = max(plotly_fig['data'][0]['x'])
197
-
198
- plotly_fig['layout']['xaxis'].update(range=[r_min-1, r_max+1])
199
- plotly_fig['data'][0]['x'] = tuple(new_x)
200
-
201
- # Define the colorbar
202
- colorbar_trace = go.Scatter(x=[None],
203
- y=[None],
204
- mode='markers',
205
- marker=dict(
206
- colorscale=red_blue,
207
- showscale=True,
208
- colorbar=dict(thickness=5, outlinewidth=0),
209
- color=[min(X[X.columns[interaction_index]]), max(X[X.columns[interaction_index]])],
210
- ),
211
- hoverinfo='none'
212
- )
213
-
214
- plotly_fig['layout']['showlegend'] = False
215
- plotly_fig['layout']['hovermode'] = 'closest'
216
- plotly_fig['layout']['height']=380
217
- plotly_fig['layout']['width']=450
218
- plotly_fig['layout']['xaxis'].update(zeroline=True,
219
- showline=True,
220
- ticklen=4,
221
- showgrid=False,
222
- tickmode='linear')
223
- title = plotly_fig['layout']['yaxis']['title']
224
- plotly_fig['layout']['yaxis'].update(title=title.split(' -')[0])
225
-
226
- plotly_fig.add_trace(colorbar_trace)
227
- plotly_fig.layout.update(
228
- annotations=[dict(
229
- x=1.23,
230
- align="right",
231
- valign="top",
232
- text=X.columns[interaction_index],
233
- showarrow=False,
234
- xref="paper",
235
- yref="paper",
236
- xanchor="right",
237
- yanchor="middle",
238
- textangle=-90,
239
- font=dict(family='Calibri', size=14)
240
- )
241
- ],
242
- margin=dict(t=50, b=50, l=50, r=80)
243
- )
244
- lis.append(plotly_fig)
245
- return lis, features
246
-
247
- def interaction_plot_to_plotly_fig(dataset, target_col='target column', max_display=10):
248
- data = pd.read_csv(dataset, encoding="ISO-8859-1")
249
- X = data.drop(['target column'], axis=1)
250
- y = data[target_col]
251
- y = y/max(y)
252
-
253
- xgb_full = xgboost.DMatrix(X, label=y)
254
-
255
- # create a train/test split
256
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
257
- xgb_train = xgboost.DMatrix(X_train, label=y_train)
258
- xgb_test = xgboost.DMatrix(X_test, label=y_test)
259
-
260
- # use validation set to choose # of trees
261
- params = {
262
- # "eta": 0.002,
263
- # "max_depth": 3,
264
- # "subsample": 0.5,
265
- "silent": 1
266
- }
267
- model_train = xgboost.train(params, xgb_train, 3000, evals = [(xgb_test, "test")], verbose_eval=None)
268
-
269
- # train final model on the full data set
270
- params = {
271
- # "eta": 0.002,
272
- # "max_depth": 3,
273
- # "subsample": 0.5,
274
- "silent": 1
275
- }
276
- model = xgboost.train(params, xgb_full, 1500, evals = [(xgb_full, "test")], verbose_eval=None)
277
- features = model.feature_names
278
- shap_values = shap.TreeExplainer(model).shap_values(X)
279
-
280
- feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
281
- feature_order = feature_order[-min(max_display, len(feature_order)):]
282
- features = [features[i] for i in feature_order[::-1]]
283
-
284
- shap_interaction_values = shap.TreeExplainer(model).shap_interaction_values(X)
285
-
286
- lis = []
287
- for i in features:
288
- for j in features:
289
- mpl_fig = pl.figure()
290
- ax = mpl_fig.add_subplot(111)
291
- _, interaction_index = shap_dependence_plot ( (i, j), shap_interaction_values, X.iloc[:2000,:] )
292
- plotly_fig = tls.mpl_to_plotly(mpl_fig)
293
-
294
- r_min = min(plotly_fig['data'][0]['x'])
295
- r_max = max(plotly_fig['data'][0]['x'])
296
-
297
- plotly_fig['layout']['xaxis'].update(range=[r_min-1, r_max+1])
298
- plotly_fig['layout']['showlegend'] = False
299
- plotly_fig['layout']['hovermode'] = 'closest'
300
- plotly_fig['layout']['height']=380
301
- plotly_fig['layout']['width']=450
302
- plotly_fig['layout']['xaxis'].update(zeroline=True,
303
- showline=True,
304
- ticklen=4,
305
- showgrid=False,
306
- tickmode='linear')
307
- plotly_fig['layout']['yaxis'].update(showline=True)
308
-
309
- if i!=j:
310
- # plotly_fig['layout']['height']=380
311
- plotly_fig['layout']['width']=480
312
- plotly_fig['layout']['yaxis']['title'] = "SHAP interaction value for {} and {}".format(i.split('-')[0], j.split('-')[0])
313
- # Define the colorbar
314
- colorbar_trace = go.Scatter(x=[None],
315
- y=[None],
316
- mode='markers',
317
- marker=dict(
318
- colorscale=red_blue,
319
- showscale=True,
320
- colorbar=dict(thickness=5, outlinewidth=0),
321
- color=[min(X[X.columns[interaction_index]]), max(X[X.columns[interaction_index]])],
322
- ),
323
- hoverinfo='none'
324
- )
325
- plotly_fig.add_trace(colorbar_trace)
326
- plotly_fig.layout.update(
327
- annotations=[dict(
328
- x=1.23,
329
- align="right",
330
- valign="top",
331
- text=X.columns[interaction_index],
332
- showarrow=False,
333
- xref="paper",
334
- yref="paper",
335
- xanchor="right",
336
- yanchor="middle",
337
- textangle=-90,
338
- font=dict(family='Calibri', size=14)
339
- )
340
- ],
341
- margin=dict(t=30, b=30, l=60, r=80)
342
- )
343
- else:
344
- plotly_fig['layout']['yaxis']['title'] = "SHAP main effect value for {}".format(i.split('-')[0])
345
- lis.append(plotly_fig)
346
- return lis, features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shap_plots.py DELETED
@@ -1,730 +0,0 @@
1
- import warnings
2
- import iml
3
- import numpy as np
4
- from iml import Instance, Model
5
- from iml.datatypes import DenseData
6
- from iml.explanations import AdditiveExplanation
7
- from iml.links import IdentityLink
8
- from scipy.stats import gaussian_kde
9
- import matplotlib
10
- try:
11
- import matplotlib.pyplot as pl
12
- from matplotlib.colors import LinearSegmentedColormap
13
- from matplotlib.ticker import MaxNLocator
14
-
15
- cdict1 = {
16
- 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
17
- (1.0, 0.9607843137254902, 0.9607843137254902)),
18
-
19
- 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
20
- (1.0, 0.15294117647058825, 0.15294117647058825)),
21
-
22
- 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
23
- (1.0, 0.3411764705882353, 0.3411764705882353)),
24
-
25
- 'alpha': ((0.0, 1, 1),
26
- (0.5, 0.3, 0.3),
27
- (1.0, 1, 1))
28
- } # #1E88E5 -> #ff0052
29
- red_blue = LinearSegmentedColormap('RedBlue', cdict1)
30
-
31
- cdict1 = {
32
- 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
33
- (1.0, 0.9607843137254902, 0.9607843137254902)),
34
-
35
- 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
36
- (1.0, 0.15294117647058825, 0.15294117647058825)),
37
-
38
- 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
39
- (1.0, 0.3411764705882353, 0.3411764705882353)),
40
-
41
- 'alpha': ((0.0, 1, 1),
42
- (0.5, 1, 1),
43
- (1.0, 1, 1))
44
- } # #1E88E5 -> #ff0052
45
- red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
46
- except ImportError:
47
- pass
48
-
49
- labels = {
50
- 'MAIN_EFFECT': "SHAP main effect value for\n%s",
51
- 'INTERACTION_VALUE': "SHAP interaction value",
52
- 'INTERACTION_EFFECT': "SHAP interaction value for\n%s and %s",
53
- 'VALUE': "SHAP value (impact on model output)",
54
- 'VALUE_FOR': "SHAP value for\n%s",
55
- 'PLOT_FOR': "SHAP plot for %s",
56
- 'FEATURE': "Feature %s",
57
- 'FEATURE_VALUE': "Feature value",
58
- 'FEATURE_VALUE_LOW': "Low",
59
- 'FEATURE_VALUE_HIGH': "High",
60
- 'JOINT_VALUE': "Joint SHAP value"
61
- }
62
-
63
- def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
64
- color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
65
- color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
66
- """Create a SHAP summary plot, colored by feature values when they are provided.
67
-
68
- Parameters
69
- ----------
70
- shap_values : numpy.array
71
- Matrix of SHAP values (# samples x # features)
72
-
73
- features : numpy.array or pandas.DataFrame or list
74
- Matrix of feature values (# samples x # features) or a feature_names list as shorthand
75
-
76
- feature_names : list
77
- Names of the features (length # features)
78
-
79
- max_display : int
80
- How many top features to include in the plot (default is 20, or 7 for interaction plots)
81
-
82
- plot_type : "dot" (default) or "violin"
83
- What type of summary plot to produce
84
- """
85
-
86
- assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."
87
-
88
- # default color:
89
- if color is None:
90
- color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"
91
-
92
- # convert from a DataFrame or other types
93
- if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
94
- if feature_names is None:
95
- feature_names = features.columns
96
- features = features.values
97
- elif str(type(features)) == "<class 'list'>":
98
- if feature_names is None:
99
- feature_names = features
100
- features = None
101
- elif (features is not None) and len(features.shape) == 1 and feature_names is None:
102
- feature_names = features
103
- features = None
104
-
105
- if feature_names is None:
106
- feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]
107
-
108
- mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
109
-
110
- # plotting SHAP interaction values
111
- if len(shap_values.shape) == 3:
112
- if max_display is None:
113
- max_display = 7
114
- else:
115
- max_display = min(len(feature_names), max_display)
116
-
117
- sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))
118
-
119
- # get plotting limits
120
- delta = 1.0 / (shap_values.shape[1] ** 2)
121
- slow = np.nanpercentile(shap_values, delta)
122
- shigh = np.nanpercentile(shap_values, 100 - delta)
123
- v = max(abs(slow), abs(shigh))
124
- slow = -0.2
125
- shigh = 0.2
126
-
127
- # mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
128
- ax = mpl_fig.subplot(1, max_display, 1)
129
- proj_shap_values = shap_values[:, sort_inds[0], np.hstack((sort_inds, len(sort_inds)))]
130
- proj_shap_values[:, 1:] *= 2 # because off diag effects are split in half
131
- shap_summary_plot(
132
- proj_shap_values, features[:, sort_inds],
133
- feature_names=feature_names[sort_inds],
134
- sort=False, show=False, color_bar=False,
135
- auto_size_plot=False,
136
- max_display=max_display
137
- )
138
- pl.xlim((slow, shigh))
139
- pl.xlabel("")
140
- title_length_limit = 11
141
- pl.title(shorten_text(feature_names[sort_inds[0]], title_length_limit))
142
- for i in range(1, max_display):
143
- ind = sort_inds[i]
144
- pl.subplot(1, max_display, i + 1)
145
- proj_shap_values = shap_values[:, ind, np.hstack((sort_inds, len(sort_inds)))]
146
- proj_shap_values *= 2
147
- proj_shap_values[:, i] /= 2 # because only off diag effects are split in half
148
- shap_summary_plot(
149
- proj_shap_values, features[:, sort_inds],
150
- sort=False,
151
- feature_names=["" for i in range(features.shape[1])],
152
- show=False,
153
- color_bar=False,
154
- auto_size_plot=False,
155
- max_display=max_display
156
- )
157
- pl.xlim((slow, shigh))
158
- pl.xlabel("")
159
- if i == max_display // 2:
160
- pl.xlabel(labels['INTERACTION_VALUE'])
161
- pl.title(shorten_text(feature_names[ind], title_length_limit))
162
- pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
163
- pl.subplots_adjust(hspace=0, wspace=0.1)
164
- # if show:
165
- # # pl.show()
166
- return mpl_fig
167
-
168
- if max_display is None:
169
- max_display = 20
170
-
171
- if sort:
172
- # order features by the sum of their effect magnitudes
173
- feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
174
- feature_order = feature_order[-min(max_display, len(feature_order)):]
175
- else:
176
- feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)
177
-
178
- row_height = 0.4
179
- if auto_size_plot:
180
- pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
181
- pl.axvline(x=0, color="#999999", zorder=-1)
182
-
183
- if plot_type == "dot":
184
- for pos, i in enumerate(feature_order):
185
- pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
186
- shaps = shap_values[:, i]
187
- values = None if features is None else features[:, i]
188
- inds = np.arange(len(shaps))
189
- np.random.shuffle(inds)
190
- if values is not None:
191
- values = values[inds]
192
- shaps = shaps[inds]
193
- colored_feature = True
194
- try:
195
- values = np.array(values, dtype=np.float64) # make sure this can be numeric
196
- except:
197
- colored_feature = False
198
- N = len(shaps)
199
- # hspacing = (np.max(shaps) - np.min(shaps)) / 200
200
- # curr_bin = []
201
- nbins = 100
202
- quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
203
- inds = np.argsort(quant + np.random.randn(N) * 1e-6)
204
- layer = 0
205
- last_bin = -1
206
- ys = np.zeros(N)
207
- for ind in inds:
208
- if quant[ind] != last_bin:
209
- layer = 0
210
- ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
211
- layer += 1
212
- last_bin = quant[ind]
213
- ys *= 0.9 * (row_height / np.max(ys + 1))
214
-
215
- if features is not None and colored_feature:
216
- # trim the color range, but prevent the color range from collapsing
217
- vmin = np.nanpercentile(values, 5)
218
- vmax = np.nanpercentile(values, 95)
219
- if vmin == vmax:
220
- vmin = np.nanpercentile(values, 1)
221
- vmax = np.nanpercentile(values, 99)
222
- if vmin == vmax:
223
- vmin = np.min(values)
224
- vmax = np.max(values)
225
-
226
- assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
227
- nan_mask = np.isnan(values)
228
- pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin,
229
- vmax=vmax, s=16, alpha=alpha, linewidth=0,
230
- zorder=3, rasterized=len(shaps) > 500)
231
- pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
232
- cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
233
- c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
234
- zorder=3, rasterized=len(shaps) > 500)
235
- else:
236
-
237
- pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
238
- color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)
239
-
240
- elif plot_type == "violin":
241
- for pos, i in enumerate(feature_order):
242
- pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
243
-
244
- if features is not None:
245
- global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
246
- global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
247
- for pos, i in enumerate(feature_order):
248
- shaps = shap_values[:, i]
249
- shap_min, shap_max = np.min(shaps), np.max(shaps)
250
- rng = shap_max - shap_min
251
- xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100)
252
- if np.std(shaps) < (global_high - global_low) / 100:
253
- ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
254
- else:
255
- ds = gaussian_kde(shaps)(xs)
256
- ds /= np.max(ds) * 3
257
-
258
- values = features[:, i]
259
- window_size = max(10, len(values) // 20)
260
- smooth_values = np.zeros(len(xs) - 1)
261
- sort_inds = np.argsort(shaps)
262
- trailing_pos = 0
263
- leading_pos = 0
264
- running_sum = 0
265
- back_fill = 0
266
- for j in range(len(xs) - 1):
267
-
268
- while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
269
- running_sum += values[sort_inds[leading_pos]]
270
- leading_pos += 1
271
- if leading_pos - trailing_pos > 20:
272
- running_sum -= values[sort_inds[trailing_pos]]
273
- trailing_pos += 1
274
- if leading_pos - trailing_pos > 0:
275
- smooth_values[j] = running_sum / (leading_pos - trailing_pos)
276
- for k in range(back_fill):
277
- smooth_values[j - k - 1] = smooth_values[j]
278
- else:
279
- back_fill += 1
280
-
281
- vmin = np.nanpercentile(values, 5)
282
- vmax = np.nanpercentile(values, 95)
283
- if vmin == vmax:
284
- vmin = np.nanpercentile(values, 1)
285
- vmax = np.nanpercentile(values, 99)
286
- if vmin == vmax:
287
- vmin = np.min(values)
288
- vmax = np.max(values)
289
- pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
290
- c=values, alpha=alpha, linewidth=0, zorder=1)
291
- # smooth_values -= nxp.nanpercentile(smooth_values, 5)
292
- # smooth_values /= np.nanpercentile(smooth_values, 95)
293
- smooth_values -= vmin
294
- if vmax - vmin > 0:
295
- smooth_values /= vmax - vmin
296
- for i in range(len(xs) - 1):
297
- if ds[i] > 0.05 or ds[i + 1] > 0.05:
298
- pl.fill_between([xs[i], xs[i + 1]], [pos + ds[i], pos + ds[i + 1]],
299
- [pos - ds[i], pos - ds[i + 1]], color=red_blue_solid(smooth_values[i]),
300
- zorder=2)
301
-
302
- else:
303
- parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
304
- widths=0.7,
305
- showmeans=False, showextrema=False, showmedians=False)
306
-
307
- for pc in parts['bodies']:
308
- pc.set_facecolor(color)
309
- pc.set_edgecolor('none')
310
- pc.set_alpha(alpha)
311
-
312
- elif plot_type == "layered_violin": # courtesy of @kodonnell
313
- num_x_points = 200
314
- bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype(
315
- 'int') # the indices of the feature data corresponding to each bin
316
- shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
317
- x_points = np.linspace(shap_min, shap_max, num_x_points)
318
-
319
- # loop through each feature and plot:
320
- for pos, ind in enumerate(feature_order):
321
- # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value, otherwise use bins/percentiles.
322
- # to keep simpler code, in the case of uniques, we just adjust the bins to align with the unique counts.
323
- feature = features[:, ind]
324
- unique, counts = np.unique(feature, return_counts=True)
325
- if unique.shape[0] <= layered_violin_max_num_bins:
326
- order = np.argsort(unique)
327
- thesebins = np.cumsum(counts[order])
328
- thesebins = np.insert(thesebins, 0, 0)
329
- else:
330
- thesebins = bins
331
- nbins = thesebins.shape[0] - 1
332
- # order the feature data so we can apply percentiling
333
- order = np.argsort(feature)
334
- # x axis is located at y0 = pos, with pos being there for offset
335
- y0 = np.ones(num_x_points) * pos
336
- # calculate kdes:
337
- ys = np.zeros((nbins, num_x_points))
338
- for i in range(nbins):
339
- # get shap values in this bin:
340
- shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
341
- # if there's only one element, then we can't
342
- if shaps.shape[0] == 1:
343
- warnings.warn(
344
- "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
345
- % (i, feature_names[ind]))
346
- # to ignore it, just set it to the previous y-values (so the area between them will be zero). Not ys is already 0, so there's
347
- # nothing to do if i == 0
348
- if i > 0:
349
- ys[i, :] = ys[i - 1, :]
350
- continue
351
- # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
352
- ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
353
- # scale it up so that the 'size' of each y represents the size of the bin. For continuous data this will
354
- # do nothing, but when we've gone with the unqique option, this will matter - e.g. if 99% are male and 1%
355
- # female, we want the 1% to appear a lot smaller.
356
- size = thesebins[i + 1] - thesebins[i]
357
- bin_size_if_even = features.shape[0] / nbins
358
- relative_bin_size = size / bin_size_if_even
359
- ys[i, :] *= relative_bin_size
360
- # now plot 'em. We don't plot the individual strips, as this can leave whitespace between them.
361
- # instead, we plot the full kde, then remove outer strip and plot over it, etc., to ensure no
362
- # whitespace
363
- ys = np.cumsum(ys, axis=0)
364
- width = 0.8
365
- scale = ys.max() * 2 / width # 2 is here as we plot both sides of x axis
366
- for i in range(nbins - 1, -1, -1):
367
- y = ys[i, :] / scale
368
- c = pl.get_cmap(color)(i / (
369
- nbins - 1)) if color in pl.cm.datad else color # if color is a cmap, use it, otherwise use a color
370
- pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
371
- pl.xlim(shap_min, shap_max)
372
-
373
- # draw the color bar
374
- if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
375
- import matplotlib.cm as cm
376
- m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
377
- m.set_array([0, 1])
378
- cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
379
- cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
380
- cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
381
- cb.ax.tick_params(labelsize=11, length=0)
382
- cb.set_alpha(1)
383
- cb.outline.set_visible(False)
384
- bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
385
- cb.ax.set_aspect((bbox.height - 0.9) * 20)
386
- # cb.draw_all()
387
-
388
- pl.gca().xaxis.set_ticks_position('bottom')
389
- pl.gca().yaxis.set_ticks_position('none')
390
- pl.gca().spines['right'].set_visible(False)
391
- pl.gca().spines['top'].set_visible(False)
392
- pl.gca().spines['left'].set_visible(False)
393
- pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
394
- pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
395
- pl.gca().tick_params('y', length=20, width=0.5, which='major')
396
- pl.gca().tick_params('x', labelsize=11)
397
- pl.ylim(-1, len(feature_order))
398
- pl.xlabel(labels['VALUE'], fontsize=13)
399
- pl.tight_layout()
400
- # if show:
401
- # pl.show()
402
- return mpl_fig
403
-
404
-
405
-
406
-
407
-
408
-
409
- def approx_interactions(index, shap_values, X):
410
- """ Order other features by how much interaction they seem to have with the feature at the given index.
411
-
412
- This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction
413
- index values for SHAP see the interaction_contribs option implemented in XGBoost.
414
- """
415
-
416
- if X.shape[0] > 10000:
417
- a = np.arange(X.shape[0])
418
- np.random.shuffle(a)
419
- inds = a[:10000]
420
- else:
421
- inds = np.arange(X.shape[0])
422
-
423
- x = X[inds, index]
424
- srt = np.argsort(x)
425
- shap_ref = shap_values[inds, index]
426
- shap_ref = shap_ref[srt]
427
- inc = max(min(int(len(x) / 10.0), 50), 1)
428
- interactions = []
429
- for i in range(X.shape[1]):
430
- val_other = X[inds, i][srt].astype(np.float)
431
- v = 0.0
432
- if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
433
- for j in range(0, len(x), inc):
434
- if np.std(val_other[j:j + inc]) > 0 and np.std(shap_ref[j:j + inc]) > 0:
435
- v += abs(np.corrcoef(shap_ref[j:j + inc], val_other[j:j + inc])[0, 1])
436
- interactions.append(v)
437
-
438
- return np.argsort(-np.abs(interactions))
439
-
440
-
441
-
442
-
443
-
444
-
445
-
446
- def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
447
- interaction_index="auto", color="#1E88E5", axis_color="#333333",
448
- dot_size=16, alpha=1, title=None, show=True):
449
- """
450
- Create a SHAP dependence plot, colored by an interaction feature.
451
-
452
- Parameters
453
- ----------
454
- ind : int
455
- Index of the feature to plot.
456
-
457
- shap_values : numpy.array
458
- Matrix of SHAP values (# samples x # features)
459
-
460
- features : numpy.array or pandas.DataFrame
461
- Matrix of feature values (# samples x # features)
462
-
463
- feature_names : list
464
- Names of the features (length # features)
465
-
466
- display_features : numpy.array or pandas.DataFrame
467
- Matrix of feature values for visual display (such as strings instead of coded values)
468
-
469
- interaction_index : "auto", None, or int
470
- The index of the feature used to color the plot.
471
- """
472
-
473
- # convert from DataFrames if we got any
474
- if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
475
- if feature_names is None:
476
- feature_names = features.columns
477
- features = features.values
478
- if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"):
479
- if feature_names is None:
480
- feature_names = display_features.columns
481
- display_features = display_features.values
482
- elif display_features is None:
483
- display_features = features
484
-
485
- if feature_names is None:
486
- feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]
487
-
488
- # allow vectors to be passed
489
- if len(shap_values.shape) == 1:
490
- shap_values = np.reshape(shap_values, len(shap_values), 1)
491
- if len(features.shape) == 1:
492
- features = np.reshape(features, len(features), 1)
493
-
494
- def convert_name(ind):
495
- if type(ind) == str:
496
- nzinds = np.where(feature_names == ind)[0]
497
- if len(nzinds) == 0:
498
- print("Could not find feature named: " + ind)
499
- return None
500
- else:
501
- return nzinds[0]
502
- else:
503
- return ind
504
-
505
- ind = convert_name(ind)
506
-
507
- mpl_fig = pl.gcf()
508
- ax = mpl_fig.gca()
509
-
510
- # plotting SHAP interaction values
511
- if len(shap_values.shape) == 3 and len(ind) == 2:
512
- ind1 = convert_name(ind[0])
513
- ind2 = convert_name(ind[1])
514
- if ind1 == ind2:
515
- proj_shap_values = shap_values[:, ind2, :]
516
- else:
517
- proj_shap_values = shap_values[:, ind2, :] * 2 # off-diag values are split in half
518
-
519
- # TODO: remove recursion; generally the functions should be shorter for more maintainable code
520
- return shap_dependence_plot(
521
- ind1, proj_shap_values, features, feature_names=feature_names,
522
- interaction_index=ind2, display_features=display_features, show=False
523
- )
524
-
525
- assert shap_values.shape[0] == features.shape[0], \
526
- "'shap_values' and 'features' values must have the same number of rows!"
527
- assert shap_values.shape[1] == features.shape[1], \
528
- "'shap_values' must have the same number of columns as 'features'!"
529
-
530
- # get both the raw and display feature values
531
- xv = features[:, ind]
532
- xd = display_features[:, ind]
533
- s = shap_values[:, ind]
534
- if type(xd[0]) == str:
535
- name_map = {}
536
- for i in range(len(xv)):
537
- name_map[xd[i]] = xv[i]
538
- xnames = list(name_map.keys())
539
-
540
- # allow a single feature name to be passed alone
541
- if type(feature_names) == str:
542
- feature_names = [feature_names]
543
- name = feature_names[ind]
544
-
545
- # guess what other feature as the stongest interaction with the plotted feature
546
- if interaction_index == "auto":
547
- interaction_index = approx_interactions(ind, shap_values, features)[0]
548
- interaction_index = convert_name(interaction_index)
549
- categorical_interaction = False
550
-
551
- # get both the raw and display color values
552
- if interaction_index is not None:
553
- cv = features[:, interaction_index]
554
- cd = display_features[:, interaction_index]
555
- clow = np.nanpercentile(features[:, interaction_index].astype(np.float), 5)
556
- chigh = np.nanpercentile(features[:, interaction_index].astype(np.float), 95)
557
- if type(cd[0]) == str:
558
- cname_map = {}
559
- for i in range(len(cv)):
560
- cname_map[cd[i]] = cv[i]
561
- cnames = list(cname_map.keys())
562
- categorical_interaction = True
563
- elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:, interaction_index])) < 50:
564
- categorical_interaction = True
565
-
566
- # discritize colors for categorical features
567
- color_norm = None
568
- if categorical_interaction and clow != chigh:
569
- bounds = np.linspace(clow, chigh, chigh - clow + 2)
570
- color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)
571
-
572
- # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
573
- if interaction_index is not None:
574
- pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue,
575
- alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500)
576
- else:
577
- pl.scatter(xv, s, s=dot_size, linewidth=0, color="#1E88E5",
578
- alpha=alpha, rasterized=len(xv) > 500)
579
-
580
- if interaction_index != ind and interaction_index is not None:
581
- # draw the color bar
582
- if type(cd[0]) == str:
583
- tick_positions = [cname_map[n] for n in cnames]
584
- if len(tick_positions) == 2:
585
- tick_positions[0] -= 0.25
586
- tick_positions[1] += 0.25
587
- cb = pl.colorbar(ticks=tick_positions)
588
- cb.set_ticklabels(cnames)
589
- else:
590
- cb = pl.colorbar()
591
-
592
- cb.set_label(feature_names[interaction_index], size=13)
593
- cb.ax.tick_params(labelsize=11)
594
- if categorical_interaction:
595
- cb.ax.tick_params(length=0)
596
- cb.set_alpha(1)
597
- cb.outline.set_visible(False)
598
- bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
599
- cb.ax.set_aspect((bbox.height - 0.7) * 20)
600
-
601
- # make the plot more readable
602
- if interaction_index != ind:
603
- pl.gcf().set_size_inches(7.5, 5)
604
- else:
605
- pl.gcf().set_size_inches(6, 5)
606
- # pl.xlabel(name, color=axis_color, fontsize=13)
607
- # pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
608
- if title is not None:
609
- pl.title(title, color=axis_color, fontsize=13)
610
- pl.gca().xaxis.set_ticks_position('bottom')
611
- pl.gca().yaxis.set_ticks_position('left')
612
- pl.gca().spines['right'].set_visible(False)
613
- pl.gca().spines['top'].set_visible(False)
614
- pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
615
- for spine in pl.gca().spines.values():
616
- spine.set_edgecolor(axis_color)
617
- if type(xd[0]) == str:
618
- pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)
619
- # if show:
620
- # pl.show()
621
-
622
-
623
- if ind1 == ind2:
624
- pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1])
625
- else:
626
- pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2]))
627
-
628
- return mpl_fig, interaction_index
629
-
630
-
631
- # # if show:
632
- # # pl.show()
633
- # return
634
- # return mpl_fig
635
-
636
- # assert shap_values.shape[0] == features.shape[0], "'shap_values' and 'features' values must have the same number of rows!"
637
- # assert shap_values.shape[1] == features.shape[1] + 1, "'shap_values' must have one more column than 'features'!"
638
-
639
- # get both the raw and display feature values
640
- xv = features[:, ind]
641
- xd = display_features[:, ind]
642
- s = shap_values[:, ind]
643
- if type(xd[0]) == str:
644
- name_map = {}
645
- for i in range(len(xv)):
646
- name_map[xd[i]] = xv[i]
647
- xnames = list(name_map.keys())
648
-
649
- # allow a single feature name to be passed alone
650
- if type(feature_names) == str:
651
- feature_names = [feature_names]
652
- name = feature_names[ind]
653
-
654
- # guess what other feature as the stongest interaction with the plotted feature
655
- if interaction_index == "auto":
656
- interaction_index = approx_interactions(ind, shap_values, features)[0]
657
- interaction_index = convert_name(interaction_index)
658
- categorical_interaction = False
659
-
660
- # get both the raw and display color values
661
- if interaction_index is not None:
662
- cv = features[:, interaction_index]
663
- cd = display_features[:, interaction_index]
664
- clow = np.nanpercentile(features[:, interaction_index].astype(np.float), 5)
665
- chigh = np.nanpercentile(features[:, interaction_index].astype(np.float), 95)
666
- if type(cd[0]) == str:
667
- cname_map = {}
668
- for i in range(len(cv)):
669
- cname_map[cd[i]] = cv[i]
670
- cnames = list(cname_map.keys())
671
- categorical_interaction = True
672
- elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:, interaction_index])) < 50:
673
- categorical_interaction = True
674
-
675
- # discritize colors for categorical features
676
- color_norm = None
677
- if categorical_interaction and clow != chigh:
678
- bounds = np.linspace(clow, chigh, chigh - clow + 2)
679
- color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)
680
-
681
- # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
682
- if interaction_index is not None:
683
- pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue,
684
- alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500)
685
- else:
686
- pl.scatter(xv, s, s=dot_size, linewidth=0, color="#1E88E5",
687
- alpha=alpha, rasterized=len(xv) > 500)
688
-
689
- if interaction_index != ind and interaction_index is not None:
690
- # draw the color bar
691
- if type(cd[0]) == str:
692
- tick_positions = [cname_map[n] for n in cnames]
693
- if len(tick_positions) == 2:
694
- tick_positions[0] -= 0.25
695
- tick_positions[1] += 0.25
696
- cb = pl.colorbar(ticks=tick_positions)
697
- cb.set_ticklabels(cnames)
698
- else:
699
- cb = pl.colorbar()
700
-
701
- cb.set_label(feature_names[interaction_index], size=13)
702
- cb.ax.tick_params(labelsize=11)
703
- if categorical_interaction:
704
- cb.ax.tick_params(length=0)
705
- cb.set_alpha(1)
706
- cb.outline.set_visible(False)
707
- bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
708
- cb.ax.set_aspect((bbox.height - 0.7) * 20)
709
-
710
- # make the plot more readable
711
- if interaction_index != ind:
712
- pl.gcf().set_size_inches(7.5, 5)
713
- else:
714
- pl.gcf().set_size_inches(6, 5)
715
- pl.xlabel(name, color=axis_color, fontsize=13)
716
- pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
717
- if title is not None:
718
- pl.title(title, color=axis_color, fontsize=13)
719
- pl.gca().xaxis.set_ticks_position('bottom')
720
- pl.gca().yaxis.set_ticks_position('left')
721
- pl.gca().spines['right'].set_visible(False)
722
- pl.gca().spines['top'].set_visible(False)
723
- pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
724
- for spine in pl.gca().spines.values():
725
- spine.set_edgecolor(axis_color)
726
- if type(xd[0]) == str:
727
- pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)
728
- # if show:
729
- # pl.show()
730
- return mpl_fig, interaction_index