if 'google.colab' in str(get_ipython()):
from google.colab import drive
'/content/drive', force_remount=False)
drive.mount(!pip install mirzai
else:
6.2. GradientShap values correlation
Visualizing the (Pearson) correlation betwee GradientShap values aggregated by Soil Taxonomy Orders
# Python utils
from pathlib import Path
import pickle
# mirzai utilities
from mirzai.data.loading import load_kssl
from mirzai.data.selection import (select_y, select_tax_order, select_X)
from mirzai.data.transform import log_transform_y
from mirzai.vis.core import (centimeter, PRIMARY_COLOR,
set_style, DEFAULT_STYLE)
from fastcore.transform import compose
# Data science stack
import numpy as np
# Data vis.
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import warnings
'ignore') warnings.filterwarnings(
Load and transform
= 'data'
src_dir = ['spectra-features.npy', 'spectra-wavenumbers.npy',
fnames 'depth-order.npy', 'target.npy',
'tax-order-lu.pkl', 'spectra-id.npy']
= load_kssl(src_dir, fnames=fnames)
X, X_names, depth_order, y, tax_lookup, X_id
= X, y, X_id, depth_order
data
= [select_y, select_tax_order, select_X, log_transform_y]
transforms = compose(*transforms)(data) X, y, X_id, depth_order
Setup
The required GradShap values are computed (then saved) as shown in GradientShap values notebook. Here we will simply load them.
= Path('dumps/cnn/shaps')
src_dir = pickle.load(open(src_dir/'shap_by_orders_02_09_2022.pickle', "rb")) shap_by_orders, _
We compute the Pearson correlation coefficient between the average GradientShap values of Soil Taxonomy Orders.
shap_by_orders
[OrderedDict([('shap',
array([ 1.8393685e-06, -3.4975756e-07, 1.0674637e-07, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'alfisols')]),
OrderedDict([('shap',
array([-5.8400028e-06, -1.0966338e-05, -1.0213902e-06, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'all')]),
OrderedDict([('shap',
array([ 2.3224013e-05, 1.5094614e-04, -2.2070333e-05, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'andisols')]),
OrderedDict([('shap',
array([2.0230336e-05, 9.7760840e-06, 9.9760737e-06, ..., 0.0000000e+00,
0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'aridisols')]),
OrderedDict([('shap',
array([-7.0780785e-05, -1.4203938e-04, 1.2610339e-05, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'entisols')]),
OrderedDict([('shap',
array([ 4.5924044e-06, 4.4373726e-05, -1.3388793e-05, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'gelisols')]),
OrderedDict([('shap',
array([ 1.8504910e-05, 4.4637756e-05, -1.5806931e-05, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'histosols')]),
OrderedDict([('shap',
array([ 2.0495969e-05, 3.4390621e-05, -4.1421335e-06, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'inceptisols')]),
OrderedDict([('shap',
array([1.8545827e-05, 1.5292404e-05, 3.3629021e-06, ..., 0.0000000e+00,
0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'mollisols')]),
OrderedDict([('shap',
array([ 9.6384401e-04, 2.6268021e-03, -3.9017643e-05, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'oxisols')]),
OrderedDict([('shap',
array([-4.8168564e-05, -2.9730183e-05, -3.7888144e-06, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'spodosols')]),
OrderedDict([('shap',
array([ 9.3750474e-05, 1.1947515e-04, -1.9311596e-05, ...,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'ultisols')]),
OrderedDict([('shap',
array([-6.108464e-06, -9.995457e-06, 8.304874e-07, ..., 0.000000e+00,
0.000000e+00, 0.000000e+00], dtype=float32)),
('order', 'undefined')]),
OrderedDict([('shap',
array([1.7481134e-05, 2.0790496e-05, 2.8002933e-07, ..., 0.0000000e+00,
0.0000000e+00, 0.0000000e+00], dtype=float32)),
('order', 'vertisols')])]
= ['all', 'undefined', 'mollisols', 'alfisols', 'inceptisols', 'ultisols',
tax_order_sorted 'entisols', 'aridisols', 'andisols','vertisols', 'histosols',
'spodosols', 'gelisols', 'oxisols']
= []
shap_by_orders_sorted for tax in tax_order_sorted:
list(filter(lambda x: x['order'] == tax, shap_by_orders))[0]) shap_by_orders_sorted.append(
# Correlation
= np.corrcoef(np.array([s['shap'] for s in shap_by_orders_sorted]))
shaps_corr = [s['order'].capitalize() for s in shap_by_orders_sorted] corr_labels
Plot
def plot_shaps_corr(shaps, labels, figsize=(6*centimeter, 6*centimeter), dpi=600):
# Styles
= plt.rcParams
p "axes.spines.bottom"] = False
p["axes.grid"] = False
p["xtick.major.size"] = 3
p["xtick.major.width"] = 0.5
p["xtick.minor.size"] = 1
p["xtick.minor.width"] = 0.25
p["xtick.minor.visible"] = False
p["xtick.labelsize"] = 4
p["ytick.labelsize"] = 4
p["ytick.labelleft"] = True
p["ytick.direction"] = "out"
p["ytick.major.size"] = 3
p["ytick.major.width"] = 0.5
p["ytick.minor.size"] = 1
p["ytick.minor.width"] = 0.25
p["ytick.minor.visible"] = False
p[
# Layout
= plt.subplots(ncols=1, nrows=1, figsize=figsize, dpi=dpi)
fig, ax
= mcolors.TwoSlopeNorm(vmin=-1, vcenter=0., vmax=1)
norm = dict(cmap='RdBu_r', norm=norm, interpolation=None)
props
= np.tril(shaps, k=0)
shaps == 0.0] = np.nan
shaps[shaps
= ax.imshow(shaps, **props)
im len(labels)), labels=labels)
ax.set_xticks(np.arange(len(labels)), labels=labels)
ax.set_yticks(np.arange(=45, ha="right",
plt.setp(ax.get_xticklabels(), rotation="anchor")
rotation_mode
for i in range(len(shaps)):
for j in range(len(shaps)):
if np.isnan(shaps[i, j]):
pass
else:
= 'w' if np.abs(shaps[i, j]) > 0.3 else '#333'
color = ax.text(j, i, "{:.2f}".format(shaps[i, j]),
text ="center", va="center", color=color, size=3)
ha
= fig.add_axes([0.5, 0.9, 0.3, 0.02])
cax = plt.colorbar(im, cax=cax, orientation='horizontal')
clb 'bottom')
clb.ax.xaxis.set_ticks_position(=4)
clb.ax.tick_params(labelsizeFalse)
clb.outline.set_visible('Correlation Coefficient', size=4)
clb.ax.set_title(
plt.tight_layout()
#FIG_PATH = Path('nameofyourfolder')
= Path('images/')
FIG_PATH
set_style(DEFAULT_STYLE)
plot_shaps_corr(shaps_corr, corr_labels)
# To save/export it
/'gradshap-order-corr.png', dpi=600, transparent=True, format='png') plt.savefig(FIG_PATH