6.1. GradientShap values

Calculating feature (wavenumber) importance for the CNN

if 'google.colab' in str(get_ipython()):
    from google.colab import drive
    drive.mount('/content/drive',  force_remount=False)
    !pip install mirzai
else:
Mounted at /content/drive
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mirzai
  Downloading mirzai-0.2.10-py3-none-any.whl (25 kB)
Collecting matplotlib>=3.5.1
  Downloading matplotlib-3.5.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (11.2 MB)
     |████████████████████████████████| 11.2 MB 12.5 MB/s 
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from mirzai) (4.64.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.21.6)
Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from mirzai) (0.13.1+cu113)
Collecting captum
  Downloading captum-0.5.0-py3-none-any.whl (1.4 MB)
     |████████████████████████████████| 1.4 MB 81.6 MB/s 
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.3.5)
Requirement already satisfied: fastcore in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.5.22)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.0.2)
Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.12.1+cu113)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.7.3)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (7.1.2)
Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (3.0.9)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (2.8.2)
Collecting fonttools>=4.22.0
  Downloading fonttools-4.37.1-py3-none-any.whl (957 kB)
     |████████████████████████████████| 957 kB 48.6 MB/s 
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (0.11.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (1.4.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (21.3)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib>=3.5.1->mirzai) (4.1.1)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7->matplotlib>=3.5.1->mirzai) (1.15.0)
Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (from fastcore->mirzai) (21.1.3)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->mirzai) (2022.2.1)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->mirzai) (3.1.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->mirzai) (1.1.0)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision->mirzai) (2.23.0)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->mirzai) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->mirzai) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->mirzai) (2022.6.15)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->mirzai) (2.10)
Installing collected packages: fonttools, matplotlib, captum, mirzai
  Attempting uninstall: matplotlib
    Found existing installation: matplotlib 3.2.2
    Uninstalling matplotlib-3.2.2:
      Successfully uninstalled matplotlib-3.2.2
Successfully installed captum-0.5.0 fonttools-4.37.1 matplotlib-3.5.3 mirzai-0.2.10
Unable to display output for mime type(s): application/vnd.colab-display-data+json
# Python utils
import math
from collections import OrderedDict
from tqdm.auto import tqdm
from pathlib import Path
import pickle

# mirzai utilities
from mirzai.data.loading import load_kssl
from mirzai.data.selection import (select_y, select_tax_order, select_X)
from mirzai.data.transform import (log_transform_y, SNV)
from mirzai.data.selection import get_y_by_order
from mirzai.training.cnn import Model
from mirzai.vis.core import (centimeter, set_style, DEFAULT_STYLE)

from fastcore.transform import compose

# Data science stack
import pandas as pd
import numpy as np
from numpy.random import randint
from sklearn.model_selection import train_test_split

# Deep Learning stack
import torch

# Data vis.
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from matplotlib import ticker

# Interpretability
import captum
from captum.attr import GradientShap

import warnings
warnings.filterwarnings('ignore')

Load and transform

# Location of the KSSL potassium dataset on the mounted Google Drive.
src_dir = '/content/drive/MyDrive/research/predict-k-mirs-dl/data/potassium'

# Expected dump files: MIR spectra, their wavenumbers, depth/taxonomic order,
# target values, the taxonomic-order lookup table, and sample ids.
fnames = ['spectra-features.npy', 'spectra-wavenumbers.npy', 
          'depth-order.npy', 'target.npy', 
          'tax-order-lu.pkl', 'spectra-id.npy']

X, X_names, depth_order, y, tax_lookup, X_id = load_kssl(src_dir, fnames=fnames)

# Bundle the arrays so they can be piped through the selection/transform chain.
data = X, y, X_id, depth_order

# Apply the selectors then log-transform the target, in order.
transforms = [select_y, select_tax_order, select_X, log_transform_y]
X, y, X_id, depth_order = compose(*transforms)(data)

Experiment (GPU required)

Utilities

# Sentinel so the default preprocessor is created fresh on every call,
# instead of one shared SNV() instance evaluated once at definition time
# (mutable-default-argument anti-pattern).
_DEFAULT_PREPROCESSING = object()

def X_to_torch(X, preprocessing_fn=_DEFAULT_PREPROCESSING, device='cuda:0'):
    """Preprocess a 2-D spectra array and move it to *device* as a 3-D tensor.

    Args:
        X: 2-D array of spectra, shape (n_samples, n_wavenumbers).
        preprocessing_fn: transformer with a ``fit_transform`` method applied
            to ``X`` first; defaults to a fresh ``SNV()`` per call; pass a
            falsy value (e.g. ``None``) to skip preprocessing.
        device: torch device string for the returned tensor.

    Returns:
        torch.Tensor of shape (n_samples, 1, n_wavenumbers) — the singleton
        channel axis is what the 1-D CNN expects.
    """
    if preprocessing_fn is _DEFAULT_PREPROCESSING:
        preprocessing_fn = SNV()
    if preprocessing_fn:
        X = preprocessing_fn.fit_transform(X)
    # Insert a channel dimension: (n, w) -> (n, 1, w).
    X = X.reshape(X.shape[-2], 1, -1)
    return torch.tensor(X).to(device)

def gradShap(X_baseline, X_inspect, model, n_baseline=100, 
             kwargs=None):
    """Compute GradientShap attributions for each spectrum in *X_inspect*.

    Args:
        X_baseline: 2-D pool of spectra to sample baselines from
            (typically the training set).
        X_inspect: 2-D array of spectra to attribute, one at a time.
        model: trained torch model to explain.
        n_baseline: number of baseline spectra sampled (with replacement).
        kwargs: extra keyword arguments forwarded to ``gs.attribute``;
            defaults to ``{'n_samples': 5, 'return_convergence_delta': True}``.
            Note: ``return_convergence_delta`` must stay truthy because the
            result is unpacked into ``(attributions, delta)`` below.

    Returns:
        np.ndarray of shape (len(X_inspect), n_wavenumbers) of SHAP values.
    """
    # Build the default per call — a shared mutable default dict could be
    # mutated by one caller and silently affect all later calls.
    if kwargs is None:
        kwargs = {'n_samples': 5, 'return_convergence_delta': True}

    # Random subsample of the baseline pool (sampling with replacement).
    idx_baseline = randint(X_baseline.shape[0], size=n_baseline)
    X_baseline = X_baseline[idx_baseline, :]

    gs = GradientShap(model, multiply_by_inputs=True)
    shaps = []
    # Attribute one spectrum at a time to bound GPU memory usage.
    for i in tqdm(range(len(X_inspect))):
        gs_attr_test, delta = gs.attribute(X_to_torch(X_inspect[[i],:]), 
                                           baselines=X_to_torch(X_baseline), 
                                           **kwargs)
        shaps.append(gs_attr_test.cpu().detach().numpy().ravel())

    return np.array(shaps)
def reduce(dfs, colname='shap'):
    """Aggregate per-seed frames into one median profile per taxonomic order.

    Args:
        dfs: list of DataFrames indexed by wavenumber, each with columns
            *colname* and 'order'.
        colname: name of the value column to aggregate.

    Returns:
        List of OrderedDicts, one per order, each holding the median values
        sorted by descending wavenumber under *colname* plus the order name.
    """
    combined = pd.concat(dfs)
    combined.index.name = 'wn'
    # Median across seeds for every (order, wavenumber) pair.
    medians = (combined.reset_index()
                       .groupby(['order', 'wn'])
                       .median()
                       .reset_index())
    return [
        OrderedDict({
            colname: grp.sort_values(by='wn', ascending=False)[colname].to_numpy(),
            'order': order_name,
        })
        for order_name, grp in medians.groupby('order')
    ]

Setup

# Is a GPU available?
use_cuda = torch.cuda.is_available()
# Fall back to CPU so the notebook still runs (slowly) without a GPU.
device = torch.device('cuda:0' if use_cuda else 'cpu')
print(f'Runtime is: {device}')
Runtime is: cuda:0

Run

Computes GradientShap values by Soil Taxonomy order.

# One model was trained per seed; replay the same splits to recover each
# model's validation set, then attribute it with GradientShap.
seeds = range(20)
split_ratio = 0.1
src_dir = Path('/content/drive/MyDrive/research/predict-k-mirs-dl/dumps/cnn/train_eval/all/models')

# Accumulated across all seeds and orders; reduced to medians afterwards.
shap_by_order = []
X_mean_by_order = []

for seed in seeds:
    print(80*'-')
    print(f'Seed: {seed}')
    print(80*'-')

    # Train/test split
    data = train_test_split(X, y, depth_order, test_size=split_ratio, random_state=seed)
    X_train, X_test, y_train, y_test, depth_order_train, depth_order_test = data

    # Further Train/Valid split 
    data = train_test_split(X_train, y_train, depth_order_train, 
                            test_size=split_ratio, random_state=seed)
    X_train, X_valid, y_train, y_valid, depth_order_train, depth_order_valid = data    

    # load model
    model = Model(X.shape[1], out_channel=16).to(device)
    fname = f'model-seed-{seed}.pt'
    # map_location is needed to load CUDA-saved weights on a CPU runtime.
    if device.type == 'cpu':
        model.load_state_dict(torch.load(src_dir/fname, map_location=torch.device('cpu')))
    else:
        model.load_state_dict(torch.load(src_dir/fname))
    # Params are not learnable in "eval" mode & Dropout is disabled
    model.eval()

    # 'all' aggregates every order; the rest come from the lookup table.
    orders_label = ['all'] + list(tax_lookup.keys())

    # Compute mean GradientShap value by orders
    for order in tqdm(orders_label):
        if order == 'all': 
            X_inspect = X_valid
        else:
            # Column 1 of depth_order holds the taxonomic-order index.
            idx = tax_lookup[order]
            mask = depth_order_valid[:, 1] == idx
            X_inspect = X_valid[mask, :]
        
        # Mean spectrum of the inspected subset (plotted alongside SHAPs).
        X_mean = np.mean(X_inspect, axis=0)
        shaps = gradShap(X_train, X_inspect, model, n_baseline=100)
        shap_by_order.append(pd.DataFrame({'shap': np.mean(shaps, axis=0), 'order': order}, 
                                          index=X_names))
        X_mean_by_order.append(pd.DataFrame({'X': X_mean, 'order': order}, 
                                         index=X_names))
--------------------------------------------------------------------------------
Seed: 0
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 1
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 2
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 3
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 4
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 5
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 6
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 7
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 8
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 9
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 10
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 11
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 12
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 13
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 14
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 15
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 16
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 17
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 18
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 19
--------------------------------------------------------------------------------
# Collapse the per-seed frames to one median profile per taxonomic order.
shap_by_order, X_mean_by_order = reduce(shap_by_order), reduce(X_mean_by_order, colname='X')
# Save it if required
#dest_dir = Path('your_dumps_path')
dest_dir = Path('/content/drive/MyDrive/research/predict-k-mirs-dl/dumps/cnn/shaps')
fname = 'shap_by_orders_02_09_2022.pickle'
with open(dest_dir/fname, 'wb') as f: 
    pickle.dump((shap_by_order, X_mean_by_order), f)

Plot

Utilities

def prettify_label(label, tax_lut):
    """Return the display name for *label* from *tax_lut*, capitalized."""
    pretty_name = tax_lut[label]
    return pretty_name.capitalize()

Reload

# Reload the dumped SHAP profiles (local path is overridden by the Drive one).
src_dir = Path('./files/dumps')
src_dir = Path('/content/drive/MyDrive/research/predict-k-mirs-dl/dumps/cnn/shaps')
fname = 'shap_by_orders_02_09_2022.pickle'
shap_by_order, X_mean_by_order  = pickle.load(open(src_dir/fname, "rb"))
# Abbreviated display names per taxonomic order; insertion order drives the
# top-to-bottom subplot order in the figure.
tax_pretty_lut = OrderedDict({'all': 'all', 
                              'undefined': 'undefined', 
                              'mollisols': 'molli.', 
                              'alfisols': 'alfi.', 
                              'inceptisols': 'incepti.', 
                              'ultisols': 'ulti.', 
                              'entisols': 'enti.', 
                              'aridisols': 'aridi.',   
                              'andisols': 'andi.',
                              'vertisols': 'verti.',
                              'histosols': 'histo.',
                              'spodosols': 'spodo.',
                              'gelisols': 'geli.', 
                              'oxisols': 'oxi.'})
def plot_shaps_by_orders(attr_values, X, X_names, tax_pretty, diverging=False,
                         annotate=True, figsize=(16*centimeter,6*centimeter), dpi=600):
    """Plot one SHAP-bar subplot per taxonomic order, with the mean spectrum
    overlaid on a twin axis.

    Args:
        attr_values: list of dicts with keys 'shap' (array) and 'order' (str),
            as produced by ``reduce``.
        X: list of dicts with keys 'X' (mean spectrum array) and 'order'.
        X_names: wavenumbers (x axis), plotted in decreasing order.
        tax_pretty: ordered mapping of order label -> display name; its
            insertion order fixes the subplot order.
        diverging: if True, plot signed SHAPs in two colors; otherwise plot
            absolute values in black.
        annotate: currently unused; kept for interface compatibility.
        figsize, dpi: forwarded to ``plt.subplots``.
    """
    # Styles
    p = plt.rcParams
    p["axes.spines.bottom"] = False
    p["axes.grid"] = False
    p["xtick.labelsize"] = 6
    p["xtick.direction"] = "in"
    p["xtick.major.size"] = 3
    p["xtick.major.width"] = 0.5
    p["xtick.minor.size"] = 1
    p["xtick.minor.width"] = 0.25
    p["ytick.left"] = False
    p["ytick.labelleft"] = False
    p["ytick.labelright"] = False
    p["ytick.major.size"] = 3
    p["ytick.major.width"] = 0.5
    p["ytick.minor.size"] = 1
    p["ytick.minor.width"] = 0.25
    p["ytick.minor.visible"] = False

    # Layout: one row per entry of attr_values.
    fig, axes = plt.subplots(ncols=1, nrows=len(attr_values),
                             sharey=False, figsize=figsize, dpi=dpi) 
    
    # One subplot per taxonomic order, in tax_pretty's insertion order.
    for i, label in enumerate(tax_pretty):
        shap = list(filter(lambda x: x['order'] == label, attr_values))[0]['shap']
        if not diverging:
            shap = np.absolute(shap)
        # Decreasing wavenumbers left-to-right (spectroscopy convention).
        axes[i].set_xlim(np.max(X_names), np.min(X_names))
        # BUG FIX: use the tax_pretty parameter, not the global tax_pretty_lut,
        # so the passed-in mapping is actually honored.
        title = prettify_label(label, tax_pretty)

        if diverging:
            # Split positive/negative attributions into two colored bar series.
            mask_pos = shap > 0
            attr_pos = np.copy(shap)
            attr_neg = np.copy(shap)
            attr_pos[~mask_pos] = 0   
            attr_neg[mask_pos] = 0   
            axes[i].bar(X_names, attr_pos, width=3, color='#0571b0', label='GradientShap > 0')
            axes[i].bar(X_names, attr_neg, width=3, color='#ca0020', label='GradientShap < 0')
        else:
            axes[i].bar(X_names, shap, width=2, color='black')
        
        axes[i].yaxis.set_major_formatter(FormatStrFormatter('%.2f')) 
        axes[i].set_ylabel(f'{title}')
        axes[i].get_yaxis().set_ticks([])

        # Overlay the mean spectrum on an independent y scale.
        ax_twin = axes[i].twinx()
        X_mean = list(filter(lambda x: x['order'] == label, X))[0]['X']
        ax_twin.plot(X_names, X_mean, c='#555', 
                     alpha=1, lw=0.5, ls='--', zorder=-1, 
                     label='Mean spectrum (Absorbance)') 
        ax_twin.get_yaxis().set_ticks([])

        axes[i].xaxis.set_major_locator(ticker.MaxNLocator(20))
        axes[i].xaxis.set_minor_locator(ticker.MaxNLocator(80))

    # Collect legend entries from the bar axis and the last twin axis.
    handles_all = []
    labels_all = []
    for ax in [axes[0], ax_twin]:
        handles, labs = ax.get_legend_handles_labels()
        labels_all += labs
        handles_all += handles
    
    fig.legend(handles_all, labels_all, 
               frameon=False, ncol=5, loc='upper center',  borderaxespad=0.1) 

    # Ornaments
    axes.flat[-1].set_xlabel('Wavenumber ($cm^{-1}$) →', loc='right')
    plt.tight_layout()
#FIG_PATH = Path('nameofyourfolder')
FIG_PATH = Path('/content/drive/MyDrive/research/predict-k-mirs-dl/img')
fname = 'gradshap-order-02092022.png'

# Apply the shared paper style, render the figure, and export it as PNG.
set_style(DEFAULT_STYLE)
plot_shaps_by_orders(shap_by_order, X_mean_by_order, X_names, tax_pretty_lut,
                     figsize=(16*centimeter, 20*centimeter), diverging=True)


plt.savefig(FIG_PATH/fname, dpi=600, transparent=True, format='png')