4.1 CNN learning rate finder

Implementing the learning rate finder as described in Smith, L.N., 2017. Cyclical learning rates for training neural networks, in: 2017 IEEE Winter Conference on Applications of Computer Vision (WACV). IEEE, pp. 464–472. https://arxiv.org/abs/1506.01186

if 'google.colab' in str(get_ipython()):
    from google.colab import drive
    drive.mount('/content/drive',  force_remount=False)
    !pip install mirzai
else:
Mounted at /content/drive
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mirzai
  Downloading mirzai-0.3.0-py3-none-any.whl (26 kB)
Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.12.1+cu113)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.3.5)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.7.3)
Collecting matplotlib>=3.5.1
  Downloading matplotlib-3.5.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (11.2 MB)
     |████████████████████████████████| 11.2 MB 10.7 MB/s 
Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from mirzai) (0.13.1+cu113)
Requirement already satisfied: fastcore in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.5.25)
Collecting captum
  Downloading captum-0.5.0-py3-none-any.whl (1.4 MB)
     |████████████████████████████████| 1.4 MB 64.4 MB/s 
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.0.2)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mirzai) (1.21.6)
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from mirzai) (4.64.1)
Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (3.0.9)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (1.4.4)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (7.1.2)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (0.11.0)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (21.3)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.5.1->mirzai) (2.8.2)
Collecting fonttools>=4.22.0
  Downloading fonttools-4.37.1-py3-none-any.whl (957 kB)
     |████████████████████████████████| 957 kB 65.0 MB/s 
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib>=3.5.1->mirzai) (4.1.1)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7->matplotlib>=3.5.1->mirzai) (1.15.0)
Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (from fastcore->mirzai) (21.1.3)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->mirzai) (2022.2.1)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->mirzai) (3.1.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->mirzai) (1.1.0)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision->mirzai) (2.23.0)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->mirzai) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->mirzai) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->mirzai) (2.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->mirzai) (2022.6.15)
Installing collected packages: fonttools, matplotlib, captum, mirzai
  Attempting uninstall: matplotlib
    Found existing installation: matplotlib 3.2.2
    Uninstalling matplotlib-3.2.2:
      Successfully uninstalled matplotlib-3.2.2
Successfully installed captum-0.5.0 fonttools-4.37.1 matplotlib-3.5.3 mirzai-0.3.0
Unable to display output for mime type(s): application/vnd.colab-display-data+json
from pathlib import Path
import pickle

# import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

from mirzai.data.loading import load_kssl
from mirzai.data.selection import (select_y, select_tax_order, select_X)
from mirzai.data.transform import log_transform_y
from mirzai.data.torch import DataLoaders, SNV_transform
from mirzai.training.cnn import (Model, weights_init)
from mirzai.training.cnn import Learner

# Deep Learning stack
import torch
from torch.optim import Adam
from torch.nn import MSELoss

from fastcore.transform import compose

import warnings
warnings.filterwarnings('ignore')

1. Load and transform

Load data

# Folder on the mounted Drive that holds the potassium dataset dumps
src_dir = '/content/drive/MyDrive/research/predict-k-mirs-dl/data/potassium'

# Files handed to `load_kssl`, in the order given here
fnames = [
    'spectra-features.npy',
    'spectra-wavenumbers.npy',
    'depth-order.npy',
    'target.npy',
    'tax-order-lu.pkl',
    'spectra-id.npy',
]

X, X_names, depth_order, y, tax_lookup, X_id = load_kssl(src_dir, fnames=fnames)

# Selection / transformation steps, chained into a single callable.
# Each step receives and returns the (X, y, X_id, depth_order) tuple
# -- presumably applied in the listed order; confirm against fastcore `compose`.
transforms = [select_y, select_tax_order, select_X, log_transform_y]
pipeline = compose(*transforms)

data = (X, y, X_id, depth_order)
X, y, X_id, depth_order = pipeline(data)

# Sanity check on what is left after selection
print(X.shape)
(40132, 1764)

Create data loaders

# Fraction held out at each of the two successive splits below
split_ratio = 0.1

# First split: carve the test set out of the full data.
# Column 1 of `depth_order` is split alongside X and y
# (presumably the taxonomic order code -- confirm against `load_kssl`).
(X_train, X_test,
 y_train, y_test,
 tax_order_train, tax_order_test) = train_test_split(
    X,
    y,
    depth_order[:, 1],
    test_size=split_ratio,
    random_state=42,
)

# Second split: carve the validation set out of the remaining training data
(X_train, X_valid,
 y_train, y_valid,
 tax_order_train, tax_order_valid) = train_test_split(
    X_train,
    y_train,
    tax_order_train,
    test_size=split_ratio,
    random_state=42,
)

# Wrap the three subsets; the SNV transform is passed to DataLoaders as-is
dls = DataLoaders(
    (X_train, y_train, tax_order_train),
    (X_valid, y_valid, tax_order_valid),
    (X_test, y_test, tax_order_test),
    transform=SNV_transform(),
)
training_generator, validation_generator, test_generator = dls.loaders()

2. Setup

# Select the compute device: first CUDA GPU when one is visible, CPU otherwise
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Runtime is: {device}')

# Training hyper-parameters shared by the cells below
n_epochs = 20
step_size_up = 5

# Mean Squared Error loss
criterion = MSELoss()
Runtime is: cuda:0

4. Train

## LR finder
# Build the network with one input per column of X and move it to the
# selected device.
n_features = X.shape[1]
model = Model(n_features, out_channel=16).to(device)

# Optimizer is created first, then `weights_init` is applied to every
# sub-module (same order as before; `apply` returns the module itself).
opt = Adam(model.parameters(), lr=1e-4)
model = model.apply(weights_init)

# No scheduler and no early stopping for the learning-rate sweep
learner = Learner(
    model,
    criterion,
    opt,
    n_epochs=n_epochs,
    scheduler=None,
    early_stopper=None,
    tax_lookup=tax_lookup.values(),
    verbose=True,
)

# Sweep over the training batches for 8 epochs; `end=0.1` is presumably the
# upper learning-rate bound -- confirm against `Learner.lr_finder`.
lrs, losses = learner.lr_finder(training_generator, end=0.1, n_epochs=8)
Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7

5. Save & load

# Persist the learning-rate sweep (learning rates and matching losses) so the
# curve can be re-plotted without re-running the finder.
dest_dir = Path('/content/drive/MyDrive/research/predict-k-mirs-dl/dumps/cnn/lr_finder')

# `open(..., 'wb')` raises FileNotFoundError when the folder is missing, so
# make sure it exists first (no-op when it is already there).
dest_dir.mkdir(parents=True, exist_ok=True)

with open(dest_dir/'cnn-lr.pickle', 'wb') as f:
    pickle.dump([lrs, losses], f)
# Reload the saved learning-rate sweep and plot loss against learning rate.
# NOTE: this rebinds `src_dir`, which an earlier cell set to the data folder.
src_dir = Path('/content/drive/MyDrive/research/predict-k-mirs-dl/dumps/cnn/lr_finder')

# NOTE(security): pickle.load executes arbitrary code from the file; only
# load dumps produced by this notebook or another trusted source.
with open(src_dir/'cnn-lr.pickle', 'rb') as f:
    lrs, losses = pickle.load(f)

# Log-scaled x axis; the trailing semicolon suppresses the notebook's echo of
# the returned Axes object.
pd.DataFrame({'learning_rate': lrs, 'loss': losses}).plot(x='learning_rate', y='loss', logx=True);