Compute mean GradientShap values by Soil Taxonomy order (plus an 'all' group covering every validation sample) for each of the 20 seed models.
# Compute mean GradientShap values per Soil Taxonomy order for each of the 20
# saved models (seeds 0..19). Results accumulate in two lists of DataFrames
# indexed by X_names: `shap_by_order` (column 'shap') and `X_mean_by_order`
# (column 'X'), each row tagged with its 'order' label ('all' or an order name).
seeds = range(20)
split_ratio = 0.1
src_dir = Path('/content/drive/MyDrive/research/predict-k-mirs-dl/dumps/cnn/train_eval/all/models')
shap_by_order = []
X_mean_by_order = []

# Loop-invariant: 'all' (whole validation set) followed by every taxonomy order.
orders_label = ['all'] + list(tax_lookup.keys())

for seed in seeds:
    print(80 * '-')
    print(f'Seed: {seed}')
    print(80 * '-')

    # Re-create the train/test split, then the train/valid split, using the same
    # seed; presumably this mirrors the splits used when the saved models were
    # trained — confirm against the training notebook.
    data = train_test_split(X, y, depth_order, test_size=split_ratio, random_state=seed)
    X_train, X_test, y_train, y_test, depth_order_train, depth_order_test = data
    # Further Train/Valid split
    data = train_test_split(X_train, y_train, depth_order_train,
                            test_size=split_ratio, random_state=seed)
    X_train, X_valid, y_train, y_valid, depth_order_train, depth_order_valid = data

    # Load this seed's weights. map_location=device places every tensor on the
    # target device whether the checkpoint was saved from CPU or GPU.
    model = Model(X.shape[1], out_channel=16).to(device)
    fname = f'model-seed-{seed}.pt'
    model.load_state_dict(torch.load(src_dir / fname, map_location=device))
    # eval() switches Dropout (and BatchNorm, if present) to inference behaviour.
    # It does NOT freeze parameters or disable autograd; a gradient-based
    # attribution such as GradientShap still needs gradients w.r.t. the inputs.
    model.eval()

    # Mean GradientShap value per feature for the whole validation set and for
    # each taxonomy order (column 1 of depth_order is compared to the order index).
    for order in tqdm(orders_label):
        if order == 'all':
            X_inspect = X_valid
        else:
            idx = tax_lookup[order]
            mask = depth_order_valid[:, 1] == idx
            X_inspect = X_valid[mask, :]
        # NOTE(review): an order absent from this seed's validation split yields an
        # empty X_inspect, so the means below would be NaN — confirm gradShap copes.
        X_mean = np.mean(X_inspect, axis=0)
        # X_train is passed to gradShap alongside the inspected samples;
        # presumably it is the pool n_baseline baselines are drawn from — confirm.
        shaps = gradShap(X_train, X_inspect, model, n_baseline=100)
        shap_by_order.append(pd.DataFrame({'shap': np.mean(shaps, axis=0), 'order': order},
                                          index=X_names))
        X_mean_by_order.append(pd.DataFrame({'X': X_mean, 'order': order},
                                            index=X_names))
--------------------------------------------------------------------------------
Seed: 0
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 1
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 2
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 3
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 4
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 5
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 6
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 7
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 8
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 9
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 10
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 11
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 12
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 13
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 14
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 15
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 16
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 17
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 18
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Seed: 19
--------------------------------------------------------------------------------