if 'google.colab' in str(get_ipython()):
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    !pip install mirzai
else:
    pass  # running locally: mirzai assumed to be installed already
Training & validation (CNN)
Various utility functions to train and evaluate the Convolutional Neural Network (CNN) model
Model
Model (input_dim, in_channel=1, out_channel=16, is_classifier=False, dropout=0.4)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and their parameters will also be converted when you call .to(), etc.
Note: as per the example above, an __init__() call to the parent class must be made before assignment on the child.
Variables: training (bool): whether this module is in training or evaluation mode.
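Model is the CNN used throughout this page. As a quick smoke test you can instantiate it and push a random mini-batch through it; the input shape here is an assumption (the network presumably consumes single-channel spectra as (batch_size, in_channel, input_dim) tensors):

import torch

model = Model(input_dim=1764)   # 1764 wavenumbers, as in the data below
x = torch.randn(8, 1, 1764)     # assumed shape: (batch, channel, wavenumber)
print(model(x).shape)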
weights_init
weights_init (m)
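Judging by its use in section 5 below, weights_init is meant to be mapped over a model's submodules with nn.Module.apply, which calls it once per submodule m:

model = Model(input_dim=1764).apply(weights_init)  # returns the re-initialized model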
Learner
Learner (model, criterion=MSELoss(), opt=<class 'torch.optim.adam.Adam'>, n_epochs=50, scheduler=None, early_stopper=None, tax_lookup=range(0, 13), verbose=True)
Initialize self. See help(type(self)) for accurate signature.
Learners
Learners (model, tax_lookup, seeds=range(0, 20), device='cpu', verbose=True, split_ratio=0.1)
Initialize self. See help(type(self)) for accurate signature.
How to use the Model, Learner and Learners?
1. Load data
src_dir = 'test'
fnames = ['spectra-features-smp.npy', 'spectra-wavenumbers-smp.npy',
          'depth-order-smp.npy', 'target-smp.npy',
          'tax-order-lu-smp.pkl', 'spectra-id-smp.npy']
# Or real data
#src_dir = '../_data'
#fnames = ['spectra-features.npy', 'spectra-wavenumbers.npy',
# 'depth-order.npy', 'target.npy',
# 'tax-order-lu.pkl', 'spectra-id.npy']
X, X_names, depth_order, y, tax_lookup, X_id = load_kssl(src_dir, fnames=fnames)

transforms = [select_y, select_tax_order, select_X, log_transform_y]

data = X, y, X_id, depth_order
X, y, X_id, depth_order = compose(*transforms)(data)

print(X.shape)
(79, 1764)
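compose chains the transform functions into a single callable before it is applied to the data. Its actual definition is not shown on this page; a minimal sketch of a common implementation (an assumption), applying the functions left to right:

from functools import reduce

def compose(*fns):
    # compose(f, g)(x) == g(f(x)): each transform feeds the next one
    return reduce(lambda f, g: lambda x: g(f(x)), fns)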
2. Configure
# Is a GPU available?
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
print(f'Runtime is: {device}')
params_scheduler = {
    'base_lr': 3e-5,
    'max_lr': 1e-3,
    'step_size_up': 5,
    'mode': 'triangular',
    'cycle_momentum': False
}

n_epochs = 21
seeds = range(2)
Runtime is: cpu
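params_scheduler is passed to Learners.train as sc_kwargs in the next step; presumably Learners forwards it to torch.optim.lr_scheduler.CyclicLR, the scheduler built explicitly in section 5, along the lines of:

from torch.optim.lr_scheduler import CyclicLR

# opt would be the Adam optimizer of the model being trained
scheduler = CyclicLR(opt, **params_scheduler)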
3. Train
# Replace the following paths with yours
dest_dir_loss = Path('test/dumps-test/cnn/train_eval/all/losses')
dest_dir_model = Path('test/dumps-test/cnn/train_eval/all/models')

learners = Learners(Model, tax_lookup, seeds=seeds, split_ratio=0.1, device=device)

learners.train((X, y, depth_order[:, -1]),
               dest_dir_loss=dest_dir_loss,
               dest_dir_model=dest_dir_model,
               n_epochs=n_epochs,
               sc_kwargs=params_scheduler)
--------------------------------------------------------------------------------
Seed: 0
--------------------------------------------------------------------------------
------------------------------
Epoch: 0
Training loss: 0.2095176726579666 | Validation loss: 0.28754910826683044
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 1
Training loss: 0.20893988013267517 | Validation loss: 0.2862740755081177
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 2
Training loss: 0.20712845027446747 | Validation loss: 0.28408676385879517
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 3
Training loss: 0.20162831246852875 | Validation loss: 0.2809036374092102
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 4
Training loss: 0.1960785835981369 | Validation loss: 0.2767537534236908
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 5
Training loss: 0.18827631324529648 | Validation loss: 0.27165573835372925
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 6
Training loss: 0.18386302888393402 | Validation loss: 0.26757434010505676
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 7
Training loss: 0.17783822119235992 | Validation loss: 0.26450157165527344
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 8
Training loss: 0.17546673864126205 | Validation loss: 0.26242998242378235
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 9
Training loss: 0.17405737936496735 | Validation loss: 0.26132825016975403
Validation loss (ends of cycles): [0.28754911]
------------------------------
Epoch: 10
Training loss: 0.1731555312871933 | Validation loss: 0.2611456513404846
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 11
Training loss: 0.171500563621521 | Validation loss: 0.26005008816719055
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 12
Training loss: 0.17028532177209854 | Validation loss: 0.2580767273902893
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 13
Training loss: 0.16719596087932587 | Validation loss: 0.2550870180130005
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 14
Training loss: 0.16363266110420227 | Validation loss: 0.25126326084136963
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 15
Training loss: 0.15928955376148224 | Validation loss: 0.24669373035430908
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 16
Training loss: 0.1538691222667694 | Validation loss: 0.24314354360103607
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 17
Training loss: 0.1488434001803398 | Validation loss: 0.24059247970581055
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 18
Training loss: 0.14692240208387375 | Validation loss: 0.2389926314353943
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 19
Training loss: 0.14454391598701477 | Validation loss: 0.238305926322937
Validation loss (ends of cycles): [0.28754911 0.26114565]
------------------------------
Epoch: 20
Training loss: 0.14511099457740784 | Validation loss: 0.23853451013565063
Validation loss (ends of cycles): [0.28754911 0.26114565 0.23853451]
--------------------------------------------------------------------------------
Seed: 1
--------------------------------------------------------------------------------
------------------------------
Epoch: 0
Training loss: 0.25518064200878143 | Validation loss: 0.20745772123336792
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 1
Training loss: 0.2529909759759903 | Validation loss: 0.20660457015037537
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 2
Training loss: 0.25137291848659515 | Validation loss: 0.20521025359630585
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 3
Training loss: 0.2424357831478119 | Validation loss: 0.20297007262706757
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 4
Training loss: 0.23724160343408585 | Validation loss: 0.2000746876001358
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 5
Training loss: 0.22914429008960724 | Validation loss: 0.19651295244693756
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 6
Training loss: 0.22360220551490784 | Validation loss: 0.1936747431755066
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 7
Training loss: 0.21922896057367325 | Validation loss: 0.19150234758853912
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 8
Training loss: 0.2143341824412346 | Validation loss: 0.19003605842590332
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 9
Training loss: 0.2121415138244629 | Validation loss: 0.18925482034683228
Validation loss (ends of cycles): [0.20745772]
------------------------------
Epoch: 10
Training loss: 0.2118835374712944 | Validation loss: 0.18918269872665405
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 11
Training loss: 0.21091558784246445 | Validation loss: 0.18842312693595886
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 12
Training loss: 0.2112690731883049 | Validation loss: 0.1870431900024414
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 13
Training loss: 0.20744766294956207 | Validation loss: 0.1849963366985321
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 14
Training loss: 0.20407014340162277 | Validation loss: 0.18240754306316376
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 15
Training loss: 0.19959653913974762 | Validation loss: 0.1791967898607254
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 16
Training loss: 0.19397148489952087 | Validation loss: 0.17683245241641998
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 17
Training loss: 0.1890103444457054 | Validation loss: 0.17504656314849854
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 18
Training loss: 0.18793968856334686 | Validation loss: 0.1737639605998993
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 19
Training loss: 0.18513494729995728 | Validation loss: 0.17320512235164642
Validation loss (ends of cycles): [0.20745772 0.1891827 ]
------------------------------
Epoch: 20
Training loss: 0.18279500305652618 | Validation loss: 0.17318619787693024
Validation loss (ends of cycles): [0.20745772 0.1891827 0.1731862 ]
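A note on the "Validation loss (ends of cycles)" bookkeeping above: with step_size_up=5 and mode='triangular', one full learning-rate cycle spans 2 * 5 = 10 epochs, so cycle-end losses are recorded at epochs 0, 10 and 20, and n_epochs=21 covers exactly two full cycles after the initial record:

cycle_len = 2 * params_scheduler['step_size_up']  # 10 epochs per triangular cycle
print([e for e in range(n_epochs) if e % cycle_len == 0])  # [0, 10, 20]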
4. Evaluate
# Replace the following paths with yours
src_dir_model = Path('test/dumps-test/cnn/train_eval/all/models')

learners = Learners(Model, tax_lookup, seeds=seeds, device=device)

perfs_global_all, y_hats_all, y_trues_all, ns_all = learners.evaluate((X, y, depth_order[:, -1]),
                                                                      src_dir_model=src_dir_model)

learners.evaluate((X, y, depth_order[:, -1]), src_dir_model=src_dir_model)
(Empty DataFrame
Columns: []
Index: [],
Empty DataFrame
Columns: []
Index: [],
Empty DataFrame
Columns: []
Index: [],
Empty DataFrame
Columns: []
Index: [])
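On this 79-sample demo set the evaluation comes back empty; on the full KSSL data, perfs_global_all, y_hats_all, y_trues_all and ns_all presumably hold per-seed global performance metrics, predictions, ground truths and sample counts. A quick shape check:

print(perfs_global_all.shape)  # (0, 0) on the demo data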
5. Learning rate finder
split_ratio = 0.1

# Train/test split
X_train, X_test, y_train, y_test, tax_order_train, tax_order_test = train_test_split(
    X, y, depth_order[:, 1],
    test_size=split_ratio,
    random_state=42)

# Further train/valid split
X_train, X_valid, y_train, y_valid, tax_order_train, tax_order_valid = train_test_split(
    X_train, y_train, tax_order_train,
    test_size=split_ratio,
    random_state=42)

dls = DataLoaders((X_train, y_train, tax_order_train),
                  (X_valid, y_valid, tax_order_valid),
                  (X_test, y_test, tax_order_test),
                  transform=SNV_transform())

training_generator, validation_generator, test_generator = dls.loaders()
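To sanity-check the splits and the SNV transform, you can pull one mini-batch from the training generator; the exact batch structure is an assumption (the loaders presumably yield tensors for features, target and possibly taxonomic order):

batch = next(iter(training_generator))
print([b.shape for b in batch])  # one shape per tensor in the batch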
n_epochs = 2
step_size_up = 5
criterion = MSELoss()  # Mean Squared Error loss
base_lr, max_lr = 3e-5, 1e-3  # Based on learning rate finder
## LR finder
model = Model(X.shape[1], out_channel=16).to(device)

opt = Adam(model.parameters(), lr=1e-4)
model = model.apply(weights_init)

scheduler = CyclicLR(opt, base_lr=base_lr, max_lr=max_lr,
                     step_size_up=step_size_up, mode='triangular',
                     cycle_momentum=False)

learner = Learner(model, criterion, opt, n_epochs=n_epochs,
                  scheduler=scheduler, early_stopper=None,
                  tax_lookup=tax_lookup.values(), verbose=True)

lrs, losses = learner.lr_finder(training_generator, end=0.1, n_epochs=8)
Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
pd.DataFrame({'learning_rate': lrs, 'loss': losses}).plot(x='learning_rate', y='loss', logx=True);
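The resulting figure (rendered in the original page) plots loss against learning rate on a log scale. The usual heuristic is to pick base_lr near where the loss first starts to decrease and max_lr just before it diverges, which is how the base_lr, max_lr = 3e-5, 1e-3 pair above was chosen.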