# setting up pod and pip install uhina
# accessing a pod terminal
# 1. To get the pod IP address: runpodctl get pod -a
# 2. ssh into the pod: ssh root@<ip-address> -p 58871 -i ~/.ssh/id_ed25519
# git clone https://github.com/franckalbinet/uhina.git
# pip install uhina
# runpodctl send im-bw
# runpodctl send ossl-tfm.csv
Ringtrial
Evaluation of a ResNet-18 pre-trained on the OSSL dataset, applied to the ring-trial data.
To do:
- for each lab, fine-tune and eval on test set
- use TTA: at what n does the score start leveling off?
- increase progressively train, valid, test set sizes
- normalize on ringtrial statistics
Runpod setup
Loading data
import pandas as pd
from pathlib import Path
import fastcore.all as fc
from fastai.data.all import *
from fastai.vision.all import *
from multiprocessing import cpu_count
from sklearn.metrics import r2_score
from uhina.augment import Quantize
import warnings
# Silence library warnings and let pandas display up to 100 rows.
warnings.filterwarnings('ignore')
pd.set_option('display.max_rows', 100)

# Lookup table pairing each image file name with its potassium target.
src = '../../_data/ringtrial-tfm/im-targets-lut.csv'
df = pd.read_csv(src)
# File names look like '<lab>-rt-<nn>.png'; the text before the first '-rt'
# is kept as the lab identifier.
df['lab'] = df['fname'].str.split('-rt', n=1).str[0]
df.head()
fname | potassium_cmolkg | lab | |
---|---|---|---|
0 | agrocares-rt-01.png | 0.26906 | agrocares |
1 | agrocares-rt-02.png | 0.23349 | agrocares |
2 | agrocares-rt-03.png | 0.29109 | agrocares |
3 | agrocares-rt-04.png | 0.49925 | agrocares |
4 | agrocares-rt-05.png | 0.59977 | agrocares |
# Replace the target column with log(1 + x) of its values.
df['potassium_cmolkg'] = df['potassium_cmolkg'].apply(np.log1p)
Fine-tuning on ringtrial
class OrderedQuantize(Quantize):
    """`Quantize` transform with an explicit `order` so it is applied first."""

    # fastai/fastcore pipelines sort their transforms by `order`
    # (lowest first) -- see the fastcore Pipeline docs.
    order = 0  # Apply first
class OrderedRatioResize(RatioResize):
    """`RatioResize` transform with an explicit `order` so it is applied second."""

    # fastai/fastcore pipelines sort their transforms by `order`
    # (lowest first) -- see the fastcore Pipeline docs.
    order = 1  # Apply second
# learn = load_learner('./models/650-4000-epoch-25-lr-3e-3.pkl', cpu=True)
# learn = load_learner('./models/unfrozen-epoch-30-lr-1.5e-3-12102024.pkl', cpu=True)
# learn = load_learner('./models/unfrozen-epoch-30-lr-1.5e-3-12102024.pkl', cpu=True)
# Load the exported learner on CPU (the "frozen" checkpoint is the one in use).
learn = load_learner('./models/frozen-epoch-30-lr-1.5e-3-12102024.pkl', cpu=True)
df.lab.unique()
array(['agrocares', 'argonne', 'csu-il', 'eth-alpha-1', 'eth-alpha-2',
'eth-vertex', 'iaea-aug2022', 'kssl', 'landcare', 'lesotho', 'msu',
'osu', 'rothamsted', 'scion', 'ughent', 'uiuc', 'usp',
'uwisc-fine', 'woodwell-alpha', 'woodwell-vertex'], dtype=object)
# np.expm1(np.log1p(2))
# Keep only the rows of a single lab. `.copy()` makes the selection an
# independent frame, so later in-place edits do not act on a slice of `df`.
df_selected = df[df.lab == 'kssl'].copy()
df_selected.head()
fname | potassium_cmolkg | lab | |
---|---|---|---|
483 | kssl-rt-01.png | 0.238276 | kssl |
484 | kssl-rt-02.png | 0.209848 | kssl |
485 | kssl-rt-03.png | 0.255487 | kssl |
486 | kssl-rt-04.png | 0.404965 | kssl |
487 | kssl-rt-05.png | 0.469860 | kssl |
# def splitter(items): return [idx_train, idx_valid]
# Optional sanity check: score the loaded learner, without fine-tuning, on
# every row of `df_selected`. Disabled by default.
eval_on_pretrained = False
if eval_on_pretrained:
    dblock = DataBlock(
        blocks=(ImageBlock, RegressionBlock),
        get_x=ColReader(0, pref='../../_data/ringtrial-tfm/im/'),
        get_y=ColReader(1),
        # splitter=splitter,
        # valid_pct=0 puts every row in the training split, which is why
        # predictions are taken from `dls.train` below.
        splitter=RandomSplitter(valid_pct=0, seed=41),
        item_tfms=[OrderedQuantize(n_valid=len(df_selected))],
        batch_tfms=[
            OrderedRatioResize(224),
            Normalize.from_stats(*imagenet_stats)
        ]
    )
    # A single batch holding the whole selection.
    dls = dblock.dataloaders(df_selected, bs=len(df_selected))
    val_preds, val_targets = learn.get_preds(dl=dls.train)
    r2 = r2_score(val_targets, val_preds)
    print(r2)
# Eval on pre-trained model
# eval_on_pretrained = True
# if eval_on_pretrained:
# dblock = DataBlock(
# blocks=(ImageBlock, RegressionBlock),
# get_x=ColReader(0, pref='../../_data/ringtrial-tfm/im/'),
# get_y=ColReader(1),
# splitter=RandomSplitter(valid_pct=0, seed=41),
# batch_tfms=[RatioResize(224)],
# item_tfms=[Quantize(n_valid=len(df_selected))])
# dls = dblock.dataloaders(df_selected, bs=len(df_selected))
# val_preds, val_targets = learn.get_preds(dl=dls.train)
# r2 = r2_score(val_targets, val_preds)
# print(r2)
# NOTE(review): `df['potassium_cmolkg']` was already log1p-transformed before
# `df_selected` was taken from it, so this applies log1p a second time (the
# displayed 0.213720 equals log1p(0.238276)). Confirm this is intended.
df_selected.loc[:, 'potassium_cmolkg'] = df_selected['potassium_cmolkg'].apply(np.log1p)
df_selected.head()
fname | potassium_cmolkg | lab | |
---|---|---|---|
483 | kssl-rt-01.png | 0.213720 | kssl |
484 | kssl-rt-02.png | 0.190494 | kssl |
485 | kssl-rt-03.png | 0.227523 | kssl |
486 | kssl-rt-04.png | 0.340012 | kssl |
487 | kssl-rt-05.png | 0.385167 | kssl |
df_selected
fname | potassium_cmolkg | lab | |
---|---|---|---|
483 | kssl-rt-01.png | 0.213720 | kssl |
484 | kssl-rt-02.png | 0.190494 | kssl |
485 | kssl-rt-03.png | 0.227523 | kssl |
486 | kssl-rt-04.png | 0.340012 | kssl |
487 | kssl-rt-05.png | 0.385167 | kssl |
488 | kssl-rt-06.png | 0.402441 | kssl |
489 | kssl-rt-07.png | 0.331974 | kssl |
490 | kssl-rt-08.png | 0.101317 | kssl |
491 | kssl-rt-09.png | 0.681530 | kssl |
492 | kssl-rt-10.png | 0.274231 | kssl |
493 | kssl-rt-11.png | 0.289119 | kssl |
494 | kssl-rt-12.png | 0.256386 | kssl |
495 | kssl-rt-13.png | 0.369230 | kssl |
496 | kssl-rt-14.png | 0.191284 | kssl |
497 | kssl-rt-15.png | 0.393472 | kssl |
498 | kssl-rt-16.png | 0.508054 | kssl |
499 | kssl-rt-17.png | 0.467364 | kssl |
500 | kssl-rt-18.png | 0.308044 | kssl |
501 | kssl-rt-19.png | 0.292841 | kssl |
502 | kssl-rt-20.png | 0.294769 | kssl |
503 | kssl-rt-21.png | 0.335060 | kssl |
504 | kssl-rt-22.png | 0.079189 | kssl |
505 | kssl-rt-23.png | 0.272119 | kssl |
506 | kssl-rt-24.png | 0.160069 | kssl |
507 | kssl-rt-25.png | 0.158562 | kssl |
508 | kssl-rt-26.png | 0.144258 | kssl |
509 | kssl-rt-27.png | 0.137756 | kssl |
510 | kssl-rt-28.png | 0.516908 | kssl |
511 | kssl-rt-29.png | 0.267374 | kssl |
512 | kssl-rt-30.png | 0.301326 | kssl |
513 | kssl-rt-31.png | 0.306298 | kssl |
514 | kssl-rt-32.png | 0.291591 | kssl |
515 | kssl-rt-33.png | 0.373274 | kssl |
516 | kssl-rt-34.png | 0.100724 | kssl |
517 | kssl-rt-35.png | 0.082217 | kssl |
518 | kssl-rt-36.png | 0.142215 | kssl |
519 | kssl-rt-37.png | 0.126304 | kssl |
520 | kssl-rt-38.png | 0.045243 | kssl |
521 | kssl-rt-39.png | 0.334103 | kssl |
522 | kssl-rt-40.png | 0.515146 | kssl |
523 | kssl-rt-41.png | 0.531581 | kssl |
524 | kssl-rt-42.png | 0.351257 | kssl |
525 | kssl-rt-43.png | 0.233260 | kssl |
526 | kssl-rt-45.png | 0.271288 | kssl |
527 | kssl-rt-46.png | 0.307444 | kssl |
528 | kssl-rt-47.png | 0.256559 | kssl |
529 | kssl-rt-48.png | 0.544772 | kssl |
530 | kssl-rt-49.png | 0.366122 | kssl |
531 | kssl-rt-50.png | 0.542601 | kssl |
532 | kssl-rt-51.png | 0.547084 | kssl |
533 | kssl-rt-52.png | 0.283886 | kssl |
534 | kssl-rt-53.png | 0.337035 | kssl |
535 | kssl-rt-54.png | 0.284184 | kssl |
536 | kssl-rt-55.png | 0.306663 | kssl |
537 | kssl-rt-56.png | 0.541332 | kssl |
538 | kssl-rt-57.png | 0.548911 | kssl |
539 | kssl-rt-58.png | 0.551454 | kssl |
540 | kssl-rt-59.png | 0.380625 | kssl |
541 | kssl-rt-60.png | 0.462484 | kssl |
542 | kssl-rt-61.png | 0.379143 | kssl |
543 | kssl-rt-62.png | 0.223065 | kssl |
544 | kssl-rt-63.png | 0.223065 | kssl |
545 | kssl-rt-64.png | 0.131506 | kssl |
546 | kssl-rt-65.png | 0.226245 | kssl |
547 | kssl-rt-66.png | 0.852376 | kssl |
548 | kssl-rt-67.png | 0.619430 | kssl |
549 | kssl-rt-68.png | 0.486798 | kssl |
550 | kssl-rt-69.png | 0.519743 | kssl |
551 | kssl-rt-70.png | 0.551429 | kssl |
# Distribution of the target values for the selected lab.
df_selected['potassium_cmolkg'].hist()
Train/valid/test split
Using Kennard-Stone
df.lab.unique()
array(['agrocares', 'argonne', 'csu-il', 'eth-alpha-1', 'eth-alpha-2',
'eth-vertex', 'iaea-aug2022', 'kssl', 'landcare', 'lesotho', 'msu',
'osu', 'rothamsted', 'scion', 'ughent', 'uiuc', 'usp',
'uwisc-fine', 'woodwell-alpha', 'woodwell-vertex'], dtype=object)
# Re-select the lab from `df`. `.copy()` gives an independent frame, because
# its index is reset in place further down.
df_selected = df[df.lab == 'kssl'].copy()
len(df_selected)
69
from uhina.loading import LoaderFactory
# Load the ring-trial data, with potassium as the analyte, via uhina's loader.
src = Path.home() / 'pro/data/woodwell-ringtrial/drive-download-20231013T123706Z-001'
loader = LoaderFactory.get_loader(src, 'ringtrial')
data = loader.load_data(analytes='potassium_cmolkg')
print(f'X shape: {data.X.shape}')
X shape: (1400, 1676)
# 44 is missing
# Sample identifiers: the image file names with the '.png' part removed.
indices_df = [name.split('.png')[0] for name in df_selected.fname]
indices_df
['kssl-rt-01',
'kssl-rt-02',
'kssl-rt-03',
'kssl-rt-04',
'kssl-rt-05',
'kssl-rt-06',
'kssl-rt-07',
'kssl-rt-08',
'kssl-rt-09',
'kssl-rt-10',
'kssl-rt-11',
'kssl-rt-12',
'kssl-rt-13',
'kssl-rt-14',
'kssl-rt-15',
'kssl-rt-16',
'kssl-rt-17',
'kssl-rt-18',
'kssl-rt-19',
'kssl-rt-20',
'kssl-rt-21',
'kssl-rt-22',
'kssl-rt-23',
'kssl-rt-24',
'kssl-rt-25',
'kssl-rt-26',
'kssl-rt-27',
'kssl-rt-28',
'kssl-rt-29',
'kssl-rt-30',
'kssl-rt-31',
'kssl-rt-32',
'kssl-rt-33',
'kssl-rt-34',
'kssl-rt-35',
'kssl-rt-36',
'kssl-rt-37',
'kssl-rt-38',
'kssl-rt-39',
'kssl-rt-40',
'kssl-rt-41',
'kssl-rt-42',
'kssl-rt-43',
'kssl-rt-45',
'kssl-rt-46',
'kssl-rt-47',
'kssl-rt-48',
'kssl-rt-49',
'kssl-rt-50',
'kssl-rt-51',
'kssl-rt-52',
'kssl-rt-53',
'kssl-rt-54',
'kssl-rt-55',
'kssl-rt-56',
'kssl-rt-57',
'kssl-rt-58',
'kssl-rt-59',
'kssl-rt-60',
'kssl-rt-61',
'kssl-rt-62',
'kssl-rt-63',
'kssl-rt-64',
'kssl-rt-65',
'kssl-rt-66',
'kssl-rt-67',
'kssl-rt-68',
'kssl-rt-69',
'kssl-rt-70']
# Boolean mask over the loaded samples: True where the sample id is one of the
# ids present in `df_selected`.
mask = np.isin(data.sample_indices, np.array(indices_df))
data.sample_indices[mask]
array(['kssl-rt-01', 'kssl-rt-02', 'kssl-rt-03', 'kssl-rt-04',
'kssl-rt-05', 'kssl-rt-06', 'kssl-rt-07', 'kssl-rt-08',
'kssl-rt-09', 'kssl-rt-10', 'kssl-rt-11', 'kssl-rt-12',
'kssl-rt-13', 'kssl-rt-14', 'kssl-rt-15', 'kssl-rt-16',
'kssl-rt-17', 'kssl-rt-18', 'kssl-rt-19', 'kssl-rt-20',
'kssl-rt-21', 'kssl-rt-22', 'kssl-rt-23', 'kssl-rt-24',
'kssl-rt-25', 'kssl-rt-26', 'kssl-rt-27', 'kssl-rt-28',
'kssl-rt-29', 'kssl-rt-30', 'kssl-rt-31', 'kssl-rt-32',
'kssl-rt-33', 'kssl-rt-34', 'kssl-rt-35', 'kssl-rt-36',
'kssl-rt-37', 'kssl-rt-38', 'kssl-rt-39', 'kssl-rt-40',
'kssl-rt-41', 'kssl-rt-42', 'kssl-rt-43', 'kssl-rt-45',
'kssl-rt-46', 'kssl-rt-47', 'kssl-rt-48', 'kssl-rt-49',
'kssl-rt-50', 'kssl-rt-51', 'kssl-rt-52', 'kssl-rt-53',
'kssl-rt-54', 'kssl-rt-55', 'kssl-rt-56', 'kssl-rt-57',
'kssl-rt-58', 'kssl-rt-59', 'kssl-rt-60', 'kssl-rt-61',
'kssl-rt-62', 'kssl-rt-63', 'kssl-rt-64', 'kssl-rt-65',
'kssl-rt-66', 'kssl-rt-67', 'kssl-rt-68', 'kssl-rt-69',
'kssl-rt-70'], dtype=object)
data.sample_indices[mask]
array(['kssl-rt-01', 'kssl-rt-02', 'kssl-rt-03', 'kssl-rt-04',
'kssl-rt-05', 'kssl-rt-06', 'kssl-rt-07', 'kssl-rt-08',
'kssl-rt-09', 'kssl-rt-10', 'kssl-rt-11', 'kssl-rt-12',
'kssl-rt-13', 'kssl-rt-14', 'kssl-rt-15', 'kssl-rt-16',
'kssl-rt-17', 'kssl-rt-18', 'kssl-rt-19', 'kssl-rt-20',
'kssl-rt-21', 'kssl-rt-22', 'kssl-rt-23', 'kssl-rt-24',
'kssl-rt-25', 'kssl-rt-26', 'kssl-rt-27', 'kssl-rt-28',
'kssl-rt-29', 'kssl-rt-30', 'kssl-rt-31', 'kssl-rt-32',
'kssl-rt-33', 'kssl-rt-34', 'kssl-rt-35', 'kssl-rt-36',
'kssl-rt-37', 'kssl-rt-38', 'kssl-rt-39', 'kssl-rt-40',
'kssl-rt-41', 'kssl-rt-42', 'kssl-rt-43', 'kssl-rt-45',
'kssl-rt-46', 'kssl-rt-47', 'kssl-rt-48', 'kssl-rt-49',
'kssl-rt-50', 'kssl-rt-51', 'kssl-rt-52', 'kssl-rt-53',
'kssl-rt-54', 'kssl-rt-55', 'kssl-rt-56', 'kssl-rt-57',
'kssl-rt-58', 'kssl-rt-59', 'kssl-rt-60', 'kssl-rt-61',
'kssl-rt-62', 'kssl-rt-63', 'kssl-rt-64', 'kssl-rt-65',
'kssl-rt-66', 'kssl-rt-67', 'kssl-rt-68', 'kssl-rt-69',
'kssl-rt-70'], dtype=object)
# Renumber rows 0..n-1 so that the integer positions returned by the split
# further down can be used directly as `.loc` labels.
df_selected.reset_index(inplace=True, drop=True)
data.sample_indices[mask]
array(['kssl-rt-01', 'kssl-rt-02', 'kssl-rt-03', 'kssl-rt-04',
'kssl-rt-05', 'kssl-rt-06', 'kssl-rt-07', 'kssl-rt-08',
'kssl-rt-09', 'kssl-rt-10', 'kssl-rt-11', 'kssl-rt-12',
'kssl-rt-13', 'kssl-rt-14', 'kssl-rt-15', 'kssl-rt-16',
'kssl-rt-17', 'kssl-rt-18', 'kssl-rt-19', 'kssl-rt-20',
'kssl-rt-21', 'kssl-rt-22', 'kssl-rt-23', 'kssl-rt-24',
'kssl-rt-25', 'kssl-rt-26', 'kssl-rt-27', 'kssl-rt-28',
'kssl-rt-29', 'kssl-rt-30', 'kssl-rt-31', 'kssl-rt-32',
'kssl-rt-33', 'kssl-rt-34', 'kssl-rt-35', 'kssl-rt-36',
'kssl-rt-37', 'kssl-rt-38', 'kssl-rt-39', 'kssl-rt-40',
'kssl-rt-41', 'kssl-rt-42', 'kssl-rt-43', 'kssl-rt-45',
'kssl-rt-46', 'kssl-rt-47', 'kssl-rt-48', 'kssl-rt-49',
'kssl-rt-50', 'kssl-rt-51', 'kssl-rt-52', 'kssl-rt-53',
'kssl-rt-54', 'kssl-rt-55', 'kssl-rt-56', 'kssl-rt-57',
'kssl-rt-58', 'kssl-rt-59', 'kssl-rt-60', 'kssl-rt-61',
'kssl-rt-62', 'kssl-rt-63', 'kssl-rt-64', 'kssl-rt-65',
'kssl-rt-66', 'kssl-rt-67', 'kssl-rt-68', 'kssl-rt-69',
'kssl-rt-70'], dtype=object)
# mask = np.char.find(data.sample_indices.astype(str), 'kssl') != -1
# Features and log1p-transformed targets of the selected lab.
# NOTE(review): rows follow the order of `data.sample_indices`; this is
# presumed to match the row order of `df_selected` -- confirm.
X_lab, y_lab = data.X[mask], np.log1p(data.y[mask])
data.sample_indices[mask]
array(['kssl-rt-01', 'kssl-rt-02', 'kssl-rt-03', 'kssl-rt-04',
'kssl-rt-05', 'kssl-rt-06', 'kssl-rt-07', 'kssl-rt-08',
'kssl-rt-09', 'kssl-rt-10', 'kssl-rt-11', 'kssl-rt-12',
'kssl-rt-13', 'kssl-rt-14', 'kssl-rt-15', 'kssl-rt-16',
'kssl-rt-17', 'kssl-rt-18', 'kssl-rt-19', 'kssl-rt-20',
'kssl-rt-21', 'kssl-rt-22', 'kssl-rt-23', 'kssl-rt-24',
'kssl-rt-25', 'kssl-rt-26', 'kssl-rt-27', 'kssl-rt-28',
'kssl-rt-29', 'kssl-rt-30', 'kssl-rt-31', 'kssl-rt-32',
'kssl-rt-33', 'kssl-rt-34', 'kssl-rt-35', 'kssl-rt-36',
'kssl-rt-37', 'kssl-rt-38', 'kssl-rt-39', 'kssl-rt-40',
'kssl-rt-41', 'kssl-rt-42', 'kssl-rt-43', 'kssl-rt-45',
'kssl-rt-46', 'kssl-rt-47', 'kssl-rt-48', 'kssl-rt-49',
'kssl-rt-50', 'kssl-rt-51', 'kssl-rt-52', 'kssl-rt-53',
'kssl-rt-54', 'kssl-rt-55', 'kssl-rt-56', 'kssl-rt-57',
'kssl-rt-58', 'kssl-rt-59', 'kssl-rt-60', 'kssl-rt-61',
'kssl-rt-62', 'kssl-rt-63', 'kssl-rt-64', 'kssl-rt-65',
'kssl-rt-66', 'kssl-rt-67', 'kssl-rt-68', 'kssl-rt-69',
'kssl-rt-70'], dtype=object)
df_selected
fname | potassium_cmolkg | lab | |
---|---|---|---|
0 | kssl-rt-01.png | 0.238276 | kssl |
1 | kssl-rt-02.png | 0.209848 | kssl |
2 | kssl-rt-03.png | 0.255487 | kssl |
3 | kssl-rt-04.png | 0.404965 | kssl |
4 | kssl-rt-05.png | 0.469860 | kssl |
5 | kssl-rt-06.png | 0.495470 | kssl |
6 | kssl-rt-07.png | 0.393716 | kssl |
7 | kssl-rt-08.png | 0.106628 | kssl |
8 | kssl-rt-09.png | 0.976900 | kssl |
9 | kssl-rt-10.png | 0.315519 | kssl |
10 | kssl-rt-11.png | 0.335250 | kssl |
11 | kssl-rt-12.png | 0.292252 | kssl |
12 | kssl-rt-13.png | 0.446620 | kssl |
13 | kssl-rt-14.png | 0.210804 | kssl |
14 | kssl-rt-15.png | 0.482117 | kssl |
15 | kssl-rt-16.png | 0.662054 | kssl |
16 | kssl-rt-17.png | 0.595782 | kssl |
17 | kssl-rt-18.png | 0.360761 | kssl |
18 | kssl-rt-19.png | 0.340229 | kssl |
19 | kssl-rt-20.png | 0.342816 | kssl |
20 | kssl-rt-21.png | 0.398024 | kssl |
21 | kssl-rt-22.png | 0.082409 | kssl |
22 | kssl-rt-23.png | 0.312743 | kssl |
23 | kssl-rt-24.png | 0.173592 | kssl |
24 | kssl-rt-25.png | 0.171825 | kssl |
25 | kssl-rt-26.png | 0.155182 | kssl |
26 | kssl-rt-27.png | 0.147696 | kssl |
27 | kssl-rt-28.png | 0.676835 | kssl |
28 | kssl-rt-29.png | 0.306528 | kssl |
29 | kssl-rt-30.png | 0.351649 | kssl |
30 | kssl-rt-31.png | 0.358387 | kssl |
31 | kssl-rt-32.png | 0.338556 | kssl |
32 | kssl-rt-33.png | 0.452482 | kssl |
33 | kssl-rt-34.png | 0.105971 | kssl |
34 | kssl-rt-35.png | 0.085691 | kssl |
35 | kssl-rt-36.png | 0.152824 | kssl |
36 | kssl-rt-37.png | 0.134627 | kssl |
37 | kssl-rt-38.png | 0.046282 | kssl |
38 | kssl-rt-39.png | 0.396687 | kssl |
39 | kssl-rt-40.png | 0.673883 | kssl |
40 | kssl-rt-41.png | 0.701621 | kssl |
41 | kssl-rt-42.png | 0.420853 | kssl |
42 | kssl-rt-43.png | 0.262710 | kssl |
43 | kssl-rt-45.png | 0.311652 | kssl |
44 | kssl-rt-46.png | 0.359945 | kssl |
45 | kssl-rt-47.png | 0.292476 | kssl |
46 | kssl-rt-48.png | 0.724215 | kssl |
47 | kssl-rt-49.png | 0.442131 | kssl |
48 | kssl-rt-50.png | 0.720475 | kssl |
49 | kssl-rt-51.png | 0.728205 | kssl |
50 | kssl-rt-52.png | 0.328282 | kssl |
51 | kssl-rt-53.png | 0.400788 | kssl |
52 | kssl-rt-54.png | 0.328678 | kssl |
53 | kssl-rt-55.png | 0.358884 | kssl |
54 | kssl-rt-56.png | 0.718293 | kssl |
55 | kssl-rt-57.png | 0.731367 | kssl |
56 | kssl-rt-58.png | 0.735776 | kssl |
57 | kssl-rt-59.png | 0.463199 | kssl |
58 | kssl-rt-60.png | 0.588014 | kssl |
59 | kssl-rt-61.png | 0.461032 | kssl |
60 | kssl-rt-62.png | 0.249902 | kssl |
61 | kssl-rt-63.png | 0.249902 | kssl |
62 | kssl-rt-64.png | 0.140544 | kssl |
63 | kssl-rt-65.png | 0.253882 | kssl |
64 | kssl-rt-66.png | 1.345212 | kssl |
65 | kssl-rt-67.png | 0.857869 | kssl |
66 | kssl-rt-68.png | 0.627098 | kssl |
67 | kssl-rt-69.png | 0.681596 | kssl |
68 | kssl-rt-70.png | 0.735732 | kssl |
X_lab
array([[1.2708 , 1.26602, 1.26191, ..., 0.15597, 0.15574, 0.15549],
[1.68078, 1.69329, 1.70438, ..., 0.22922, 0.22891, 0.22859],
[1.69767, 1.69935, 1.70112, ..., 0.38133, 0.38056, 0.3798 ],
...,
[1.65483, 1.65777, 1.6626 , ..., 0.22134, 0.22078, 0.2202 ],
[1.86684, 1.86213, 1.85727, ..., 0.14837, 0.14783, 0.14725],
[1.62302, 1.62296, 1.62328, ..., 0.2393 , 0.23909, 0.23888]])
from uhina.preprocessing import SNV, TakeDerivative
from sklearn.pipeline import Pipeline
# Preprocess the features with uhina's SNV step followed by its derivative step.
pipe = Pipeline([
    ('SNV', SNV()),
    ('Derivative', TakeDerivative())
])

X_lab_trans = pipe.fit_transform(X_lab)
df_selected
fname | potassium_cmolkg | lab | |
---|---|---|---|
0 | kssl-rt-01.png | 0.238276 | kssl |
1 | kssl-rt-02.png | 0.209848 | kssl |
2 | kssl-rt-03.png | 0.255487 | kssl |
3 | kssl-rt-04.png | 0.404965 | kssl |
4 | kssl-rt-05.png | 0.469860 | kssl |
5 | kssl-rt-06.png | 0.495470 | kssl |
6 | kssl-rt-07.png | 0.393716 | kssl |
7 | kssl-rt-08.png | 0.106628 | kssl |
8 | kssl-rt-09.png | 0.976900 | kssl |
9 | kssl-rt-10.png | 0.315519 | kssl |
10 | kssl-rt-11.png | 0.335250 | kssl |
11 | kssl-rt-12.png | 0.292252 | kssl |
12 | kssl-rt-13.png | 0.446620 | kssl |
13 | kssl-rt-14.png | 0.210804 | kssl |
14 | kssl-rt-15.png | 0.482117 | kssl |
15 | kssl-rt-16.png | 0.662054 | kssl |
16 | kssl-rt-17.png | 0.595782 | kssl |
17 | kssl-rt-18.png | 0.360761 | kssl |
18 | kssl-rt-19.png | 0.340229 | kssl |
19 | kssl-rt-20.png | 0.342816 | kssl |
20 | kssl-rt-21.png | 0.398024 | kssl |
21 | kssl-rt-22.png | 0.082409 | kssl |
22 | kssl-rt-23.png | 0.312743 | kssl |
23 | kssl-rt-24.png | 0.173592 | kssl |
24 | kssl-rt-25.png | 0.171825 | kssl |
25 | kssl-rt-26.png | 0.155182 | kssl |
26 | kssl-rt-27.png | 0.147696 | kssl |
27 | kssl-rt-28.png | 0.676835 | kssl |
28 | kssl-rt-29.png | 0.306528 | kssl |
29 | kssl-rt-30.png | 0.351649 | kssl |
30 | kssl-rt-31.png | 0.358387 | kssl |
31 | kssl-rt-32.png | 0.338556 | kssl |
32 | kssl-rt-33.png | 0.452482 | kssl |
33 | kssl-rt-34.png | 0.105971 | kssl |
34 | kssl-rt-35.png | 0.085691 | kssl |
35 | kssl-rt-36.png | 0.152824 | kssl |
36 | kssl-rt-37.png | 0.134627 | kssl |
37 | kssl-rt-38.png | 0.046282 | kssl |
38 | kssl-rt-39.png | 0.396687 | kssl |
39 | kssl-rt-40.png | 0.673883 | kssl |
40 | kssl-rt-41.png | 0.701621 | kssl |
41 | kssl-rt-42.png | 0.420853 | kssl |
42 | kssl-rt-43.png | 0.262710 | kssl |
43 | kssl-rt-45.png | 0.311652 | kssl |
44 | kssl-rt-46.png | 0.359945 | kssl |
45 | kssl-rt-47.png | 0.292476 | kssl |
46 | kssl-rt-48.png | 0.724215 | kssl |
47 | kssl-rt-49.png | 0.442131 | kssl |
48 | kssl-rt-50.png | 0.720475 | kssl |
49 | kssl-rt-51.png | 0.728205 | kssl |
50 | kssl-rt-52.png | 0.328282 | kssl |
51 | kssl-rt-53.png | 0.400788 | kssl |
52 | kssl-rt-54.png | 0.328678 | kssl |
53 | kssl-rt-55.png | 0.358884 | kssl |
54 | kssl-rt-56.png | 0.718293 | kssl |
55 | kssl-rt-57.png | 0.731367 | kssl |
56 | kssl-rt-58.png | 0.735776 | kssl |
57 | kssl-rt-59.png | 0.463199 | kssl |
58 | kssl-rt-60.png | 0.588014 | kssl |
59 | kssl-rt-61.png | 0.461032 | kssl |
60 | kssl-rt-62.png | 0.249902 | kssl |
61 | kssl-rt-63.png | 0.249902 | kssl |
62 | kssl-rt-64.png | 0.140544 | kssl |
63 | kssl-rt-65.png | 0.253882 | kssl |
64 | kssl-rt-66.png | 1.345212 | kssl |
65 | kssl-rt-67.png | 0.857869 | kssl |
66 | kssl-rt-68.png | 0.627098 | kssl |
67 | kssl-rt-69.png | 0.681596 | kssl |
68 | kssl-rt-70.png | 0.735732 | kssl |
X_lab_trans.shape
(69, 1676)
import kennard_stone as ks
# train_idx, valid_idx, X_train, X_valid = ks.train_test_split(np.array(range(len(X_lab_trans))).reshape(-1, 1),
# X_lab_trans, test_size = 0.2)
# train_idx = train_idx.ravel()
# valid_idx = valid_idx.ravel()
# Kennard-Stone split on the preprocessed features. Splitting the row
# positions alongside them yields index lists usable on `df_selected`.
X_train, X_valid, train_idx, valid_idx = ks.train_test_split(X_lab_trans,
                                                              range(len(X_lab_trans)),
                                                              test_size=0.2)
Calculating pairwise distances using scikit-learn.
Calculating pairwise distances using scikit-learn.
valid_idx
[35, 6, 24, 12, 3, 31, 30, 5, 2, 25, 49, 28, 44, 40]
# Target distribution in the training split, then in the validation split.
df_selected.loc[train_idx, :]['potassium_cmolkg'].hist()
df_selected.loc[valid_idx, :]['potassium_cmolkg'].hist()
from sklearn.cross_decomposition import PLSRegression
# Sweep the number of PLS components and record the validation R2 of each fit.
scores = []
for n in range(1, 20):
    pls = PLSRegression(n_components=n)
    pls.fit(X_lab_trans[train_idx], y_lab[train_idx])
    y_predicted = pls.predict(X_lab_trans[valid_idx])
    # r2_score expects (y_true, y_pred); R2 is not symmetric, so the order
    # matters. The figures printed below this cell came from a run with the
    # two arguments swapped and will change on re-execution.
    score = r2_score(y_lab[valid_idx], y_predicted)
    print(n, score)
    scores.append(score)
plt.plot(range(1, 20), scores)
1 0.9167139331659132
2 0.8169940843443368
3 0.8201480463134329
4 0.8330463354402591
5 0.8336164418584839
6 0.8066819139016866
7 0.8063689082131712
8 0.8476483270390328
9 0.8442719648387733
10 0.8550880328401899
11 0.85550155049048
12 0.8773457853046793
13 0.8640833529184554
14 0.8467865332449903
15 0.8598890265246808
16 0.860227416062531
17 0.8890976617393181
18 0.8975066928044428
19 0.9022963785020354
# Fit a one-component PLS model and plot predicted against observed targets
# (both on the log1p scale) for the validation split.
pls = PLSRegression(n_components=1)

pls.fit(X_lab_trans[train_idx], np.log1p(data.y[mask][train_idx]))
y_predicted = pls.predict(X_lab_trans[valid_idx])
x, y = y_predicted, np.log1p(data.y[mask][valid_idx])
plt.plot(x, y, '.')
# Add the diagonal line
min_val = min(y.min(), x.min())
max_val = max(y.max(), x.max())
plt.plot([min_val, max_val], [min_val, max_val], 'k--', lw=1)
# dblock = DataBlock(blocks=(ImageBlock, RegressionBlock),
# get_x=ColReader(0, pref='../../_data/ringtrial-tfm/im/'),
# get_y=ColReader(1),
# splitter=RandomSplitter(valid_pct=0, seed=41),
# batch_tfms=[RatioResize(224)],
# item_tfms=[Quantize()])
# class ModelEvaluator:
# def __init__(self, model_path, dblock):
# self.learn = load_learner(model_path, cpu=True)
# self.dblock = dblock
# def evaluate(self, df_selected, batch_size=16, use_tta=False, tta_n=4):
# dls = self.dblock.dataloaders(df_selected, bs=batch_size)
# if use_tta:
# val_preds, val_targets = self.learn.tta(dl=dls.train, n=tta_n)
# else:
# val_preds, val_targets = self.learn.get_preds(dl=dls.train)
# r2 = r2_score(val_targets, val_preds)
# return val_preds, val_targets, r2
# model_path = './models/650-4000-epoch-25-lr-3e-3.pkl'
# evaluator = ModelEvaluator(model_path, dblock)
len(train_idx), len(valid_idx)
(55, 14)
def has_common_elements(list1, list2):
    """Return True when the two iterables share at least one element."""
    shared = set(list1).intersection(list2)
    return len(shared) > 0
# The training and validation index lists must not overlap.
fc.test_eq(has_common_elements(train_idx, valid_idx), False)
df_selected
fname | potassium_cmolkg | lab | |
---|---|---|---|
0 | eth-alpha-1-rt-01.png | 0.238276 | eth-alpha-1 |
1 | eth-alpha-1-rt-02.png | 0.209848 | eth-alpha-1 |
2 | eth-alpha-1-rt-03.png | 0.255487 | eth-alpha-1 |
3 | eth-alpha-1-rt-04.png | 0.404965 | eth-alpha-1 |
4 | eth-alpha-1-rt-05.png | 0.469860 | eth-alpha-1 |
5 | eth-alpha-1-rt-06.png | 0.495470 | eth-alpha-1 |
6 | eth-alpha-1-rt-07.png | 0.393716 | eth-alpha-1 |
7 | eth-alpha-1-rt-08.png | 0.106628 | eth-alpha-1 |
8 | eth-alpha-1-rt-09.png | 0.976900 | eth-alpha-1 |
9 | eth-alpha-1-rt-10.png | 0.315519 | eth-alpha-1 |
10 | eth-alpha-1-rt-11.png | 0.335250 | eth-alpha-1 |
11 | eth-alpha-1-rt-12.png | 0.292252 | eth-alpha-1 |
12 | eth-alpha-1-rt-13.png | 0.446620 | eth-alpha-1 |
13 | eth-alpha-1-rt-14.png | 0.210804 | eth-alpha-1 |
14 | eth-alpha-1-rt-15.png | 0.482117 | eth-alpha-1 |
15 | eth-alpha-1-rt-16.png | 0.662054 | eth-alpha-1 |
16 | eth-alpha-1-rt-17.png | 0.595782 | eth-alpha-1 |
17 | eth-alpha-1-rt-18.png | 0.360761 | eth-alpha-1 |
18 | eth-alpha-1-rt-19.png | 0.340229 | eth-alpha-1 |
19 | eth-alpha-1-rt-20.png | 0.342816 | eth-alpha-1 |
20 | eth-alpha-1-rt-21.png | 0.398024 | eth-alpha-1 |
21 | eth-alpha-1-rt-22.png | 0.082409 | eth-alpha-1 |
22 | eth-alpha-1-rt-23.png | 0.312743 | eth-alpha-1 |
23 | eth-alpha-1-rt-24.png | 0.173592 | eth-alpha-1 |
24 | eth-alpha-1-rt-25.png | 0.171825 | eth-alpha-1 |
25 | eth-alpha-1-rt-26.png | 0.155182 | eth-alpha-1 |
26 | eth-alpha-1-rt-27.png | 0.147696 | eth-alpha-1 |
27 | eth-alpha-1-rt-28.png | 0.676835 | eth-alpha-1 |
28 | eth-alpha-1-rt-29.png | 0.306528 | eth-alpha-1 |
29 | eth-alpha-1-rt-30.png | 0.351649 | eth-alpha-1 |
30 | eth-alpha-1-rt-31.png | 0.358387 | eth-alpha-1 |
31 | eth-alpha-1-rt-32.png | 0.338556 | eth-alpha-1 |
32 | eth-alpha-1-rt-33.png | 0.452482 | eth-alpha-1 |
33 | eth-alpha-1-rt-34.png | 0.105971 | eth-alpha-1 |
34 | eth-alpha-1-rt-35.png | 0.085691 | eth-alpha-1 |
35 | eth-alpha-1-rt-36.png | 0.152824 | eth-alpha-1 |
36 | eth-alpha-1-rt-37.png | 0.134627 | eth-alpha-1 |
37 | eth-alpha-1-rt-38.png | 0.046282 | eth-alpha-1 |
38 | eth-alpha-1-rt-39.png | 0.396687 | eth-alpha-1 |
39 | eth-alpha-1-rt-40.png | 0.673883 | eth-alpha-1 |
40 | eth-alpha-1-rt-41.png | 0.701621 | eth-alpha-1 |
41 | eth-alpha-1-rt-42.png | 0.420853 | eth-alpha-1 |
42 | eth-alpha-1-rt-43.png | 0.262710 | eth-alpha-1 |
43 | eth-alpha-1-rt-45.png | 0.311652 | eth-alpha-1 |
44 | eth-alpha-1-rt-46.png | 0.359945 | eth-alpha-1 |
45 | eth-alpha-1-rt-47.png | 0.292476 | eth-alpha-1 |
46 | eth-alpha-1-rt-48.png | 0.724215 | eth-alpha-1 |
47 | eth-alpha-1-rt-49.png | 0.442131 | eth-alpha-1 |
48 | eth-alpha-1-rt-50.png | 0.720475 | eth-alpha-1 |
49 | eth-alpha-1-rt-51.png | 0.728205 | eth-alpha-1 |
50 | eth-alpha-1-rt-52.png | 0.328282 | eth-alpha-1 |
51 | eth-alpha-1-rt-53.png | 0.400788 | eth-alpha-1 |
52 | eth-alpha-1-rt-54.png | 0.328678 | eth-alpha-1 |
53 | eth-alpha-1-rt-55.png | 0.358884 | eth-alpha-1 |
54 | eth-alpha-1-rt-56.png | 0.718293 | eth-alpha-1 |
55 | eth-alpha-1-rt-57.png | 0.731367 | eth-alpha-1 |
56 | eth-alpha-1-rt-58.png | 0.735776 | eth-alpha-1 |
57 | eth-alpha-1-rt-59.png | 0.463199 | eth-alpha-1 |
58 | eth-alpha-1-rt-60.png | 0.588014 | eth-alpha-1 |
59 | eth-alpha-1-rt-61.png | 0.461032 | eth-alpha-1 |
60 | eth-alpha-1-rt-62.png | 0.249902 | eth-alpha-1 |
61 | eth-alpha-1-rt-63.png | 0.249902 | eth-alpha-1 |
62 | eth-alpha-1-rt-64.png | 0.140544 | eth-alpha-1 |
63 | eth-alpha-1-rt-65.png | 0.253882 | eth-alpha-1 |
64 | eth-alpha-1-rt-66.png | 1.345212 | eth-alpha-1 |
65 | eth-alpha-1-rt-67.png | 0.857869 | eth-alpha-1 |
66 | eth-alpha-1-rt-68.png | 0.627098 | eth-alpha-1 |
67 | eth-alpha-1-rt-69.png | 0.681596 | eth-alpha-1 |
68 | eth-alpha-1-rt-70.png | 0.735732 | eth-alpha-1 |
df_selected.loc[train_idx, :]
fname | potassium_cmolkg | lab | |
---|---|---|---|
15 | eth-alpha-1-rt-16.png | 0.662054 | eth-alpha-1 |
16 | eth-alpha-1-rt-17.png | 0.595782 | eth-alpha-1 |
46 | eth-alpha-1-rt-48.png | 0.724215 | eth-alpha-1 |
49 | eth-alpha-1-rt-51.png | 0.728205 | eth-alpha-1 |
48 | eth-alpha-1-rt-50.png | 0.720475 | eth-alpha-1 |
59 | eth-alpha-1-rt-61.png | 0.461032 | eth-alpha-1 |
47 | eth-alpha-1-rt-49.png | 0.442131 | eth-alpha-1 |
54 | eth-alpha-1-rt-56.png | 0.718293 | eth-alpha-1 |
39 | eth-alpha-1-rt-40.png | 0.673883 | eth-alpha-1 |
36 | eth-alpha-1-rt-37.png | 0.134627 | eth-alpha-1 |
55 | eth-alpha-1-rt-57.png | 0.731367 | eth-alpha-1 |
53 | eth-alpha-1-rt-55.png | 0.358884 | eth-alpha-1 |
56 | eth-alpha-1-rt-58.png | 0.735776 | eth-alpha-1 |
7 | eth-alpha-1-rt-08.png | 0.106628 | eth-alpha-1 |
30 | eth-alpha-1-rt-31.png | 0.358387 | eth-alpha-1 |
61 | eth-alpha-1-rt-63.png | 0.249902 | eth-alpha-1 |
45 | eth-alpha-1-rt-47.png | 0.292476 | eth-alpha-1 |
52 | eth-alpha-1-rt-54.png | 0.328678 | eth-alpha-1 |
37 | eth-alpha-1-rt-38.png | 0.046282 | eth-alpha-1 |
51 | eth-alpha-1-rt-53.png | 0.400788 | eth-alpha-1 |
44 | eth-alpha-1-rt-46.png | 0.359945 | eth-alpha-1 |
5 | eth-alpha-1-rt-06.png | 0.495470 | eth-alpha-1 |
65 | eth-alpha-1-rt-67.png | 0.857869 | eth-alpha-1 |
11 | eth-alpha-1-rt-12.png | 0.292252 | eth-alpha-1 |
58 | eth-alpha-1-rt-60.png | 0.588014 | eth-alpha-1 |
64 | eth-alpha-1-rt-66.png | 1.345212 | eth-alpha-1 |
57 | eth-alpha-1-rt-59.png | 0.463199 | eth-alpha-1 |
68 | eth-alpha-1-rt-70.png | 0.735732 | eth-alpha-1 |
0 | eth-alpha-1-rt-01.png | 0.238276 | eth-alpha-1 |
13 | eth-alpha-1-rt-14.png | 0.210804 | eth-alpha-1 |
66 | eth-alpha-1-rt-68.png | 0.627098 | eth-alpha-1 |
63 | eth-alpha-1-rt-65.png | 0.253882 | eth-alpha-1 |
38 | eth-alpha-1-rt-39.png | 0.396687 | eth-alpha-1 |
43 | eth-alpha-1-rt-45.png | 0.311652 | eth-alpha-1 |
33 | eth-alpha-1-rt-34.png | 0.105971 | eth-alpha-1 |
25 | eth-alpha-1-rt-26.png | 0.155182 | eth-alpha-1 |
19 | eth-alpha-1-rt-20.png | 0.342816 | eth-alpha-1 |
27 | eth-alpha-1-rt-28.png | 0.676835 | eth-alpha-1 |
60 | eth-alpha-1-rt-62.png | 0.249902 | eth-alpha-1 |
10 | eth-alpha-1-rt-11.png | 0.335250 | eth-alpha-1 |
34 | eth-alpha-1-rt-35.png | 0.085691 | eth-alpha-1 |
40 | eth-alpha-1-rt-41.png | 0.701621 | eth-alpha-1 |
67 | eth-alpha-1-rt-69.png | 0.681596 | eth-alpha-1 |
20 | eth-alpha-1-rt-21.png | 0.398024 | eth-alpha-1 |
1 | eth-alpha-1-rt-02.png | 0.209848 | eth-alpha-1 |
24 | eth-alpha-1-rt-25.png | 0.171825 | eth-alpha-1 |
8 | eth-alpha-1-rt-09.png | 0.976900 | eth-alpha-1 |
32 | eth-alpha-1-rt-33.png | 0.452482 | eth-alpha-1 |
2 | eth-alpha-1-rt-03.png | 0.255487 | eth-alpha-1 |
21 | eth-alpha-1-rt-22.png | 0.082409 | eth-alpha-1 |
42 | eth-alpha-1-rt-43.png | 0.262710 | eth-alpha-1 |
22 | eth-alpha-1-rt-23.png | 0.312743 | eth-alpha-1 |
62 | eth-alpha-1-rt-64.png | 0.140544 | eth-alpha-1 |
26 | eth-alpha-1-rt-27.png | 0.147696 | eth-alpha-1 |
41 | eth-alpha-1-rt-42.png | 0.420853 | eth-alpha-1 |
def ks_splitter(items):
    """Splitter for `DataBlock`: ignore `items` and return the precomputed
    notebook-level `train_idx` / `valid_idx` index lists."""
    split = [train_idx, valid_idx]
    return split
# Image -> regression pipeline using the precomputed train/valid split.
dblock = DataBlock(
    blocks=(ImageBlock, RegressionBlock),
    get_x=ColReader(0, pref='../../_data/ringtrial-tfm/im/'),
    get_y=ColReader(1),
    splitter=ks_splitter,
    item_tfms=[OrderedQuantize(n_valid=len(valid_idx))],
    batch_tfms=[
        OrderedRatioResize(224),
        Normalize.from_stats(*imagenet_stats)
    ]
)
dls = dblock.dataloaders(df_selected, bs=16)

# Reload the exported learner, point it at the new dataloaders and freeze it
# before fine-tuning.
learn = load_learner('./models/frozen-epoch-30-lr-1.5e-3-12102024.pkl', cpu=True)
learn.dls = dls
learn.freeze()
learn.lr_find()
SuggestedLRs(valley=0.001737800776027143)
# Fine-tune for 10 epochs with the one-cycle schedule and a maximum learning
# rate of 1.5e-3.
learn.fit_one_cycle(10, 1.5e-3)
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.062678 | 0.070659 | -7.250715 | 00:02 |
1 | 0.060883 | 0.032091 | -2.747149 | 00:01 |
2 | 0.064407 | 0.013675 | -0.596789 | 00:01 |
3 | 0.052951 | 0.005511 | 0.356464 | 00:01 |
4 | 0.048459 | 0.005013 | 0.414687 | 00:01 |
5 | 0.043682 | 0.005526 | 0.354736 | 00:01 |
6 | 0.039910 | 0.006038 | 0.294921 | 00:02 |
7 | 0.036788 | 0.006592 | 0.230292 | 00:01 |
8 | 0.034137 | 0.006361 | 0.257186 | 00:01 |
9 | 0.032073 | 0.006753 | 0.211423 | 00:01 |
# val_preds, val_targets = learn.get_preds(dl=dls.valid)
# Predict on the validation split with test-time augmentation (n=30), then
# score against the targets.
val_preds, val_targets = learn.tta(dl=dls.valid, n=30)
r2_score(val_targets, val_preds)
0.8524203732220106
# Scatter of predictions (x) against targets (y) for the validation split.
x, y = val_preds, val_targets
plt.plot(x, y, '.')
# Add the diagonal line
min_val = min(y.min(), x.min())
max_val = max(y.max(), x.max())
plt.plot([min_val, max_val], [min_val, max_val], 'k--', lw=1)
Using cross-validation
def cross_validation(df, target, valid_size=0.2,
                     num_bins=2, epochs=1, lr=1.5e-3,
                     n_tta=10, seed=31, bs=16,
                     img_dir='../../_data/ringtrial-tfm/im/',
                     model_path='./models/frozen-epoch-30-lr-1.5e-3-12102024.pkl'):
    """Fine-tune the exported learner on one stratified hold-out split and score it.

    Despite its name, each call evaluates a single train/validation split;
    repeat the call with different `seed` values to get several estimates.

    Args:
        df: DataFrame whose column 0 holds image file names and column 1 the
            regression target (the columns are read by position).
        target: Name of the column whose quantile bins stratify the split.
        valid_size: `test_size` passed to `train_test_split`.
        num_bins: Number of quantile bins (`q` of `pd.qcut`) used for stratification.
        epochs: Number of epochs passed to `fit_one_cycle`.
        lr: Maximum learning rate passed to `fit_one_cycle`.
        n_tta: `n` passed to `learn.tta`.
        seed: `random_state` of the split.
        bs: Batch size of the dataloaders.
        img_dir: Prefix prepended to the file names in column 0.
        model_path: Path of the exported learner to load.

    Returns:
        The R2 score of the test-time-augmented predictions on the validation split.
    """
    from sklearn.model_selection import train_test_split

    # Work on a re-indexed copy so that the index labels of the two splits are
    # also valid row positions for the DataBlock splitter.
    df = df.copy()
    df.reset_index(inplace=True, drop=True)
    train_df, valid_df = train_test_split(
        df, test_size=valid_size,
        stratify=pd.qcut(df[target], q=num_bins, labels=False),
        random_state=seed)

    train_idx, valid_idx = train_df.index, valid_df.index

    def stratified_splitter(items): return [train_idx, valid_idx]

    dblock = DataBlock(
        blocks=(ImageBlock, RegressionBlock),
        get_x=ColReader(0, pref=img_dir),
        get_y=ColReader(1),
        splitter=stratified_splitter,
        item_tfms=[OrderedQuantize(n_valid=len(valid_idx))],
        batch_tfms=[
            OrderedRatioResize(224),
            Normalize.from_stats(*imagenet_stats)
        ]
    )
    dls = dblock.dataloaders(df, bs=bs)
    learn = load_learner(model_path, cpu=True)
    learn.dls = dls

    learn.freeze()
    learn.fit_one_cycle(epochs, lr)
    # val_preds, val_targets = learn.get_preds(dl=dls.valid)
    val_preds, val_targets = learn.tta(dl=dls.valid, n=n_tta)
    return r2_score(val_targets, val_preds)
df_selected = df[df.lab == 'kssl']

# Repeat the stratified hold-out evaluation over several seeds.
scores = []
for seed in range(1, 10):
    score = cross_validation(df_selected, 'potassium_cmolkg',
                             valid_size=0.2, num_bins=4,
                             epochs=3, seed=seed)
    scores.append(score)
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.080839 | 0.318125 | 0.302815 | 00:02 |
1 | 0.073392 | 0.298006 | 0.346907 | 00:01 |
2 | 0.065199 | 0.292256 | 0.359508 | 00:01 |
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.078712 | 0.058278 | 0.567448 | 00:01 |
1 | 0.103297 | 0.037975 | 0.718140 | 00:01 |
2 | 0.106955 | 0.029388 | 0.781878 | 00:01 |
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.040163 | 0.132989 | 0.170330 | 00:02 |
1 | 0.085821 | 0.117681 | 0.265833 | 00:02 |
2 | 0.085760 | 0.110825 | 0.308603 | 00:02 |
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.073621 | 0.045193 | 0.553941 | 00:02 |
1 | 0.106387 | 0.029305 | 0.710762 | 00:03 |
2 | 0.082117 | 0.021391 | 0.788873 | 00:02 |
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.077322 | 0.299385 | 0.362233 | 00:02 |
1 | 0.071211 | 0.292716 | 0.376440 | 00:01 |
2 | 0.061952 | 0.281709 | 0.399887 | 00:01 |
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.155731 | 0.052831 | 0.586981 | 00:02 |
1 | 0.147657 | 0.045921 | 0.641001 | 00:03 |
2 | 0.127009 | 0.043383 | 0.660838 | 00:02 |
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.130341 | 0.135900 | 0.309171 | 00:01 |
1 | 0.119079 | 0.105886 | 0.461744 | 00:01 |
2 | 0.102005 | 0.097154 | 0.506130 | 00:01 |
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.159416 | 0.036709 | 0.643722 | 00:02 |
1 | 0.147401 | 0.028988 | 0.718660 | 00:02 |
2 | 0.133312 | 0.023979 | 0.767270 | 00:01 |
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.186754 | 0.042557 | 0.596722 | 00:01 |
1 | 0.122767 | 0.023109 | 0.781009 | 00:01 |
2 | 0.114166 | 0.018596 | 0.823780 | 00:01 |
plt.hist(np.array(scores))
(array([1., 1., 1., 0., 1., 0., 0., 1., 0., 4.]),
array([0.26426539, 0.31865215, 0.37303891, 0.42742567, 0.48181243,
0.53619919, 0.59058595, 0.64497271, 0.69935947, 0.75374623,
0.80813299]),
<BarContainer object of 10 artists>)
def stratified_split(df, target, valid_size=0.2, test_size=0.2, num_bins=2, seed=41):
    """Split `df` into stratified train / valid / test subsets.

    The index is reset first, so the returned indices are positions in the
    re-indexed copy of `df`, not the labels of the frame passed in.

    Args:
        df: DataFrame to split.
        target: Name of the column that is quantile-binned for stratification.
        valid_size: Fraction of the rows remaining after the test split that
            go to validation.
        test_size: Fraction of all rows that go to the test set.
        num_bins: Number of quantile bins passed to `pd.qcut`.
        seed: Random state used for both splits.

    Returns:
        Tuple `(train_df, train_idx, valid_df, valid_idx, test_df, test_idx)`.
    """
    from sklearn.model_selection import train_test_split

    df = df.reset_index(drop=True)
    train_df, test_df = train_test_split(
        df, test_size=test_size,
        stratify=pd.qcut(df[target], q=num_bins, labels=False),
        random_state=seed)

    # Use `valid_size` here: passing `test_size` again silently ignored the
    # caller's validation fraction.
    train_df, valid_df = train_test_split(
        train_df, test_size=valid_size,
        stratify=pd.qcut(train_df[target], q=num_bins, labels=False),
        random_state=seed)

    return train_df, train_df.index, valid_df, valid_df.index, test_df, test_df.index
data = stratified_split(df_selected, 'potassium_cmolkg', valid_size=0.3, test_size=0.2, num_bins=2)
train_df, train_idx, valid_df, valid_idx, test_df, test_idx = data
# Check they have nothing in common
def has_common_elements(list1, list2):
    "Return `True` when `list1` and `list2` share at least one element."
    shared = set(list1).intersection(set(list2))
    return len(shared) > 0
fc.test_eq(has_common_elements(train_idx, test_idx), False)
fc.test_eq(has_common_elements(train_idx, valid_idx), False)
fc.test_eq(has_common_elements(test_idx, valid_idx), False)
--------------------------------------------------------------------------- AssertionError Traceback (most recent call last) Cell In[422], line 4 1 # Check they have nothing in common 2 def has_common_elements(list1, list2): return bool(set(list1) & set(list2)) ----> 4 fc.test_eq(has_common_elements(train_idx, test_idx), False) 5 fc.test_eq(has_common_elements(train_idx, valid_idx), False) 6 fc.test_eq(has_common_elements(test_idx, valid_idx), False) File ~/mambaforge/envs/uhina/lib/python3.12/site-packages/fastcore/test.py:39, in test_eq(a, b) 37 def test_eq(a,b): 38 "`test` that `a==b`" ---> 39 test(a,b,equals, cname='==') File ~/mambaforge/envs/uhina/lib/python3.12/site-packages/fastcore/test.py:29, in test(a, b, cmp, cname) 27 "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails" 28 if cname is None: cname=cmp.__name__ ---> 29 assert cmp(a,b),f"{cname}:\n{a}\n{b}" AssertionError: ==: True False
train_idx, valid_idx
(Index([26, 40, 17, 67, 22, 13, 16, 41, 49, 18, 2, 9, 47, 44, 52, 63, 61, 37,
57, 59, 66, 30, 54, 33, 7, 55, 15, 46, 19, 53, 29, 60, 4, 3, 64, 45,
11, 20],
dtype='int64'),
Index([21, 58, 14, 65, 34, 56, 50, 6, 35, 62, 68, 32, 25, 24, 43, 10, 48], dtype='int64'))
len(train_df), len(valid_df), len(test_df)
(38, 17, 14)
test_df
fname | potassium_cmolkg | lab | |
---|---|---|---|
28 | kssl-rt-29.png | 0.306528 | kssl |
39 | kssl-rt-40.png | 0.673883 | kssl |
23 | kssl-rt-24.png | 0.173592 | kssl |
0 | kssl-rt-01.png | 0.238276 | kssl |
42 | kssl-rt-43.png | 0.262710 | kssl |
27 | kssl-rt-28.png | 0.676835 | kssl |
31 | kssl-rt-32.png | 0.338556 | kssl |
38 | kssl-rt-39.png | 0.396687 | kssl |
8 | kssl-rt-09.png | 0.976900 | kssl |
1 | kssl-rt-02.png | 0.209848 | kssl |
36 | kssl-rt-37.png | 0.134627 | kssl |
51 | kssl-rt-53.png | 0.400788 | kssl |
5 | kssl-rt-06.png | 0.495470 | kssl |
12 | kssl-rt-13.png | 0.446620 | kssl |
def stratified_splitter(items):
    "Return the precomputed `[train_idx, valid_idx]` split; `items` is ignored."
    return [train_idx, valid_idx]


# dblock = DataBlock(blocks=(ImageBlock, RegressionBlock),
#                    get_x=ColReader(0, pref='../../_data/ringtrial-tfm/im/'),
#                    get_y=ColReader(1),
#                    splitter=stratified_splitter,
#                    batch_tfms=[RatioResize(224)],
#                    item_tfms=[Quantize(n_valid=len(valid_idx))])

# Images are read from column 0 (file name) and the regression target from column 1.
dblock = DataBlock(
    blocks=(ImageBlock, RegressionBlock),
    get_x=ColReader(0, pref='../../_data/ringtrial-tfm/im/'),
    get_y=ColReader(1),
    splitter=stratified_splitter,
    item_tfms=[OrderedQuantize(n_valid=len(valid_idx))],
    batch_tfms=[
        OrderedRatioResize(224),
        Normalize.from_stats(*imagenet_stats)
    ]
)
# dblock.summary(df_selected)
Setting-up type transforms pipelines
Collecting items from fname potassium_cmolkg lab
483 kssl-rt-01.png 0.238276 kssl
484 kssl-rt-02.png 0.209848 kssl
485 kssl-rt-03.png 0.255487 kssl
486 kssl-rt-04.png 0.404965 kssl
487 kssl-rt-05.png 0.469860 kssl
488 kssl-rt-06.png 0.495470 kssl
489 kssl-rt-07.png 0.393716 kssl
490 kssl-rt-08.png 0.106628 kssl
491 kssl-rt-09.png 0.976900 kssl
492 kssl-rt-10.png 0.315519 kssl
493 kssl-rt-11.png 0.335250 kssl
494 kssl-rt-12.png 0.292252 kssl
495 kssl-rt-13.png 0.446620 kssl
496 kssl-rt-14.png 0.210804 kssl
497 kssl-rt-15.png 0.482117 kssl
498 kssl-rt-16.png 0.662054 kssl
499 kssl-rt-17.png 0.595782 kssl
500 kssl-rt-18.png 0.360761 kssl
501 kssl-rt-19.png 0.340229 kssl
502 kssl-rt-20.png 0.342816 kssl
503 kssl-rt-21.png 0.398024 kssl
504 kssl-rt-22.png 0.082409 kssl
505 kssl-rt-23.png 0.312743 kssl
506 kssl-rt-24.png 0.173592 kssl
507 kssl-rt-25.png 0.171825 kssl
508 kssl-rt-26.png 0.155182 kssl
509 kssl-rt-27.png 0.147696 kssl
510 kssl-rt-28.png 0.676835 kssl
511 kssl-rt-29.png 0.306528 kssl
512 kssl-rt-30.png 0.351649 kssl
513 kssl-rt-31.png 0.358387 kssl
514 kssl-rt-32.png 0.338556 kssl
515 kssl-rt-33.png 0.452482 kssl
516 kssl-rt-34.png 0.105971 kssl
517 kssl-rt-35.png 0.085691 kssl
518 kssl-rt-36.png 0.152824 kssl
519 kssl-rt-37.png 0.134627 kssl
520 kssl-rt-38.png 0.046282 kssl
521 kssl-rt-39.png 0.396687 kssl
522 kssl-rt-40.png 0.673883 kssl
523 kssl-rt-41.png 0.701621 kssl
524 kssl-rt-42.png 0.420853 kssl
525 kssl-rt-43.png 0.262710 kssl
526 kssl-rt-45.png 0.311652 kssl
527 kssl-rt-46.png 0.359945 kssl
528 kssl-rt-47.png 0.292476 kssl
529 kssl-rt-48.png 0.724215 kssl
530 kssl-rt-49.png 0.442131 kssl
531 kssl-rt-50.png 0.720475 kssl
532 kssl-rt-51.png 0.728205 kssl
533 kssl-rt-52.png 0.328282 kssl
534 kssl-rt-53.png 0.400788 kssl
535 kssl-rt-54.png 0.328678 kssl
536 kssl-rt-55.png 0.358884 kssl
537 kssl-rt-56.png 0.718293 kssl
538 kssl-rt-57.png 0.731367 kssl
539 kssl-rt-58.png 0.735776 kssl
540 kssl-rt-59.png 0.463199 kssl
541 kssl-rt-60.png 0.588014 kssl
542 kssl-rt-61.png 0.461032 kssl
543 kssl-rt-62.png 0.249902 kssl
544 kssl-rt-63.png 0.249902 kssl
545 kssl-rt-64.png 0.140544 kssl
546 kssl-rt-65.png 0.253882 kssl
547 kssl-rt-66.png 1.345212 kssl
548 kssl-rt-67.png 0.857869 kssl
549 kssl-rt-68.png 0.627098 kssl
550 kssl-rt-69.png 0.681596 kssl
551 kssl-rt-70.png 0.735732 kssl
Found 69 items
2 datasets of sizes 38,17
Setting up Pipeline: ColReader -- {'cols': 0, 'pref': '../../_data/ringtrial-tfm/im/', 'suff': '', 'label_delim': None} -> PILBase.create
Setting up Pipeline: ColReader -- {'cols': 1, 'pref': '', 'suff': '', 'label_delim': None} -> RegressionSetup -- {'c': None}
Building one sample
Pipeline: ColReader -- {'cols': 0, 'pref': '../../_data/ringtrial-tfm/im/', 'suff': '', 'label_delim': None} -> PILBase.create
starting from
fname kssl-rt-27.png
potassium_cmolkg 0.147696
lab kssl
Name: 509, dtype: object
applying ColReader -- {'cols': 0, 'pref': '../../_data/ringtrial-tfm/im/', 'suff': '', 'label_delim': None} gives
../../_data/ringtrial-tfm/im/kssl-rt-27.png
applying PILBase.create gives
PILImage mode=RGB size=669x221
Pipeline: ColReader -- {'cols': 1, 'pref': '', 'suff': '', 'label_delim': None} -> RegressionSetup -- {'c': None}
starting from
fname kssl-rt-27.png
potassium_cmolkg 0.147696
lab kssl
Name: 509, dtype: object
applying ColReader -- {'cols': 1, 'pref': '', 'suff': '', 'label_delim': None} gives
0.14769560487272498
applying RegressionSetup -- {'c': None} gives
tensor(0.1477)
Final sample: (PILImage mode=RGB size=669x221, tensor(0.1477))
Collecting items from fname potassium_cmolkg lab
483 kssl-rt-01.png 0.238276 kssl
484 kssl-rt-02.png 0.209848 kssl
485 kssl-rt-03.png 0.255487 kssl
486 kssl-rt-04.png 0.404965 kssl
487 kssl-rt-05.png 0.469860 kssl
488 kssl-rt-06.png 0.495470 kssl
489 kssl-rt-07.png 0.393716 kssl
490 kssl-rt-08.png 0.106628 kssl
491 kssl-rt-09.png 0.976900 kssl
492 kssl-rt-10.png 0.315519 kssl
493 kssl-rt-11.png 0.335250 kssl
494 kssl-rt-12.png 0.292252 kssl
495 kssl-rt-13.png 0.446620 kssl
496 kssl-rt-14.png 0.210804 kssl
497 kssl-rt-15.png 0.482117 kssl
498 kssl-rt-16.png 0.662054 kssl
499 kssl-rt-17.png 0.595782 kssl
500 kssl-rt-18.png 0.360761 kssl
501 kssl-rt-19.png 0.340229 kssl
502 kssl-rt-20.png 0.342816 kssl
503 kssl-rt-21.png 0.398024 kssl
504 kssl-rt-22.png 0.082409 kssl
505 kssl-rt-23.png 0.312743 kssl
506 kssl-rt-24.png 0.173592 kssl
507 kssl-rt-25.png 0.171825 kssl
508 kssl-rt-26.png 0.155182 kssl
509 kssl-rt-27.png 0.147696 kssl
510 kssl-rt-28.png 0.676835 kssl
511 kssl-rt-29.png 0.306528 kssl
512 kssl-rt-30.png 0.351649 kssl
513 kssl-rt-31.png 0.358387 kssl
514 kssl-rt-32.png 0.338556 kssl
515 kssl-rt-33.png 0.452482 kssl
516 kssl-rt-34.png 0.105971 kssl
517 kssl-rt-35.png 0.085691 kssl
518 kssl-rt-36.png 0.152824 kssl
519 kssl-rt-37.png 0.134627 kssl
520 kssl-rt-38.png 0.046282 kssl
521 kssl-rt-39.png 0.396687 kssl
522 kssl-rt-40.png 0.673883 kssl
523 kssl-rt-41.png 0.701621 kssl
524 kssl-rt-42.png 0.420853 kssl
525 kssl-rt-43.png 0.262710 kssl
526 kssl-rt-45.png 0.311652 kssl
527 kssl-rt-46.png 0.359945 kssl
528 kssl-rt-47.png 0.292476 kssl
529 kssl-rt-48.png 0.724215 kssl
530 kssl-rt-49.png 0.442131 kssl
531 kssl-rt-50.png 0.720475 kssl
532 kssl-rt-51.png 0.728205 kssl
533 kssl-rt-52.png 0.328282 kssl
534 kssl-rt-53.png 0.400788 kssl
535 kssl-rt-54.png 0.328678 kssl
536 kssl-rt-55.png 0.358884 kssl
537 kssl-rt-56.png 0.718293 kssl
538 kssl-rt-57.png 0.731367 kssl
539 kssl-rt-58.png 0.735776 kssl
540 kssl-rt-59.png 0.463199 kssl
541 kssl-rt-60.png 0.588014 kssl
542 kssl-rt-61.png 0.461032 kssl
543 kssl-rt-62.png 0.249902 kssl
544 kssl-rt-63.png 0.249902 kssl
545 kssl-rt-64.png 0.140544 kssl
546 kssl-rt-65.png 0.253882 kssl
547 kssl-rt-66.png 1.345212 kssl
548 kssl-rt-67.png 0.857869 kssl
549 kssl-rt-68.png 0.627098 kssl
550 kssl-rt-69.png 0.681596 kssl
551 kssl-rt-70.png 0.735732 kssl
Found 69 items
2 datasets of sizes 38,17
Setting up Pipeline: ColReader -- {'cols': 0, 'pref': '../../_data/ringtrial-tfm/im/', 'suff': '', 'label_delim': None} -> PILBase.create
Setting up Pipeline: ColReader -- {'cols': 1, 'pref': '', 'suff': '', 'label_delim': None} -> RegressionSetup -- {'c': None}
Setting up after_item: Pipeline: OrderedQuantize -- {'n_valid': 17, 'p': 1.0} -> ToTensor
Setting up before_batch: Pipeline:
Setting up after_batch: Pipeline: OrderedRatioResize -- {'max_sz': 224, 'resamples': (<Resampling.BILINEAR: 2>, <Resampling.NEAREST: 0>)} -> IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} -> Normalize -- {'mean': tensor([[[[0.4850]],
[[0.4560]],
[[0.4060]]]], device='mps:0'), 'std': tensor([[[[0.2290]],
[[0.2240]],
[[0.2250]]]], device='mps:0'), 'axes': (0, 2, 3)}
Building one batch
Applying item_tfms to the first sample:
Pipeline: OrderedQuantize -- {'n_valid': 17, 'p': 1.0} -> ToTensor
starting from
(PILImage mode=RGB size=669x221, tensor(0.1477))
applying OrderedQuantize -- {'n_valid': 17, 'p': 1.0} gives
(PILImage mode=RGB size=669x221, tensor(0.1477))
applying ToTensor gives
(TensorImage of size 3x221x669, tensor(0.1477))
Adding the next 3 samples
No before_batch transform to apply
Collating items in a batch
Applying batch_tfms to the batch built
Pipeline: OrderedRatioResize -- {'max_sz': 224, 'resamples': (<Resampling.BILINEAR: 2>, <Resampling.NEAREST: 0>)} -> IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} -> Normalize -- {'mean': tensor([[[[0.4850]],
[[0.4560]],
[[0.4060]]]], device='mps:0'), 'std': tensor([[[[0.2290]],
[[0.2240]],
[[0.2250]]]], device='mps:0'), 'axes': (0, 2, 3)}
starting from
(TensorImage of size 4x3x221x669, tensor([0.1477, 0.7016, 0.3608, 0.6816], device='mps:0'))
applying OrderedRatioResize -- {'max_sz': 224, 'resamples': (<Resampling.BILINEAR: 2>, <Resampling.NEAREST: 0>)} gives
(TensorImage of size 4x3x221x669, tensor([0.1477, 0.7016, 0.3608, 0.6816], device='mps:0'))
applying IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} gives
(TensorImage of size 4x3x221x669, tensor([0.1477, 0.7016, 0.3608, 0.6816], device='mps:0'))
applying Normalize -- {'mean': tensor([[[[0.4850]],
[[0.4560]],
[[0.4060]]]], device='mps:0'), 'std': tensor([[[[0.2290]],
[[0.2240]],
[[0.2250]]]], device='mps:0'), 'axes': (0, 2, 3)} gives
(TensorImage of size 4x3x221x669, tensor([0.1477, 0.7016, 0.3608, 0.6816], device='mps:0'))
dls = dblock.dataloaders(df_selected, bs=16)
dls.train.n, dls.valid.n
(38, 17)
dls.show_batch(nrows=6, ncols=2, figsize=(12, 13))
learn.dls = dls
# learn.summary()
learn.freeze()
# learn.summary()
learn.lr_find()
SuggestedLRs(valley=0.0014454397605732083)
learn.fit_one_cycle(1, 1.5e-3)
epoch | train_loss | valid_loss | r2_score | time |
---|---|---|---|---|
0 | 0.033352 | 0.007514 | 0.877129 | 00:02 |
val_preds, val_targets = learn.get_preds(dl=dls.valid)
r2_score(val_targets, val_preds)
--------------------------------------------------------------------------- InvalidParameterError Traceback (most recent call last) Cell In[131], line 2 1 val_preds, val_targets = learn.get_preds(dl=dls.valid) ----> 2 r2_score(val_targets, val_preds) File ~/mambaforge/envs/uhina/lib/python3.12/site-packages/sklearn/utils/_param_validation.py:203, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs) 200 to_ignore += ["self", "cls"] 201 params = {k: v for k, v in params.arguments.items() if k not in to_ignore} --> 203 validate_parameter_constraints( 204 parameter_constraints, params, caller_name=func.__qualname__ 205 ) 207 try: 208 with config_context( 209 skip_parameter_validation=( 210 prefer_skip_nested_validation or global_skip_validation 211 ) 212 ): File ~/mambaforge/envs/uhina/lib/python3.12/site-packages/sklearn/utils/_param_validation.py:95, in validate_parameter_constraints(parameter_constraints, params, caller_name) 89 else: 90 constraints_str = ( 91 f"{', '.join([str(c) for c in constraints[:-1]])} or" 92 f" {constraints[-1]}" 93 ) ---> 95 raise InvalidParameterError( 96 f"The {param_name!r} parameter of {caller_name} must be" 97 f" {constraints_str}. Got {param_val!r} instead." 98 ) InvalidParameterError: The 'y_true' parameter of r2_score must be an array-like. Got None instead.
Evaluate fine-tuned model
len(test_df)
14
# valid_pct=0 requests an empty validation split; the held-out rows are
# read back through `dls.train` below.
dblock = DataBlock(blocks=(ImageBlock, RegressionBlock),
                   get_x=ColReader(0, pref='../../_data/ringtrial-tfm/im/'),
                   get_y=ColReader(1),
                   splitter=RandomSplitter(valid_pct=0, seed=41),
                   item_tfms=[OrderedQuantize(n_valid=len(test_df))],
                   batch_tfms=[
                       OrderedRatioResize(224),
                       Normalize.from_stats(*imagenet_stats)]
                   )
dls = dblock.dataloaders(test_df, bs=len(test_df))
val_preds, val_targets = learn.get_preds(dl=dls.train)
r2_score(val_targets, val_preds)
0.18065004067315227
val_preds, val_targets = learn.tta(dl=dls.train, n=30)
r2_score(val_targets, val_preds)
0.31174768615708437
np.c_[val_preds, val_targets]
array([[0.28596842, 0.33855578],
[0.34872288, 0.30652836],
[0.16156155, 0.1735919 ],
[0.15021066, 0.20984755],
[0.7913804 , 0.6738828 ],
[0.6016631 , 0.9769003 ],
[0.12963496, 0.23827647],
[0.476924 , 0.3966867 ],
[0.41498667, 0.44661984],
[0.3885079 , 0.26271036],
[0.4440051 , 0.49547035],
[0.1741274 , 0.13462704],
[0.38449523, 0.40078753],
[0.5929916 , 0.6768349 ]], dtype=float32)
lab | lr | n_epochs (fine-tuning) | r2_score | n_tta |
---|---|---|---|---|
iaea-aug2022 | 1.5e-3 | 20 | 0.867 | 30 |
kssl | 1.5e-3 | 20 | 0.931 | 30 |
learn.predict('/Users/franckalbinet/pro/dev/uhina/_data/ringtrial-tfm/im/kssl-rt-01.png')
((0.2232206165790558,), tensor([0.2232]), tensor([0.2232]))
np.c_[val_preds, val_targets]
array([[0.573713 , 0.66205376],
[0.23120013, 0.25548682],
[0.29060498, 0.23827647],
[0.3890785 , 0.3588835 ],
[0.72690636, 0.6738828 ],
[0.52095914, 0.48211747],
[0.5956749 , 0.6768349 ],
[0.55163294, 0.71829337],
[0.41514462, 0.46985987],
[0.6461811 , 0.73577553],
[0.3656289 , 0.24990232],
[0.6311119 , 0.68159574],
[0.16741446, 0.1476956 ],
[0.3227809 , 0.31551853]], dtype=float32)
x, y = val_preds, val_targets
plt.plot(x, y, '.')
# Add the diagonal line
min_val = min(y.min(), x.min())
max_val = max(y.max(), x.max())
plt.plot([min_val, max_val], [min_val, max_val], 'k--', lw=1)
On single images
def predict_with_transforms(learn, img_path, n_predictions=5):
    """Predict one image `n_predictions` times and summarise the results.

    Each pass applies `RatioResize(224)` and then `Quantize()` to the image
    before calling `learn.predict`.

    Args:
        learn: Learner used for prediction.
        img_path: Path of the image to load with `PILImage.create`.
        n_predictions: Number of prediction passes.

    Returns:
        Tuple `(mean, std, median, mode, predictions)`, where `predictions`
        is the list of the individual predicted values.
    """
    from statistics import mode

    # Load the image
    img = PILImage.create(img_path)

    # Create instances of the transforms
    ratio_resize = RatioResize(224)
    quantize = Quantize()

    predictions = []
    for _ in range(n_predictions):
        # Apply transforms
        img_resized = ratio_resize(img)
        img_quantized = quantize(img_resized)

        # Predict; keep the first element of the decoded prediction
        pred, _, _ = learn.predict(img_quantized)
        predictions.append(pred[0])

    # Summary statistics over the individual predictions
    mean_pred = np.mean(predictions)
    std_pred = np.std(predictions)
    median_pred = np.median(predictions)
    mode_pred = mode(predictions)
    return mean_pred, std_pred, median_pred, mode_pred, predictions
test_df
fname | potassium_cmolkg | lab | |
---|---|---|---|
416 | iaea-aug2022-rt-03.png | 0.255487 | iaea-aug2022 |
453 | iaea-aug2022-rt-40.png | 0.673883 | iaea-aug2022 |
414 | iaea-aug2022-rt-01.png | 0.238276 | iaea-aug2022 |
441 | iaea-aug2022-rt-28.png | 0.676835 | iaea-aug2022 |
470 | iaea-aug2022-rt-58.png | 0.735776 | iaea-aug2022 |
423 | iaea-aug2022-rt-10.png | 0.315519 | iaea-aug2022 |
429 | iaea-aug2022-rt-16.png | 0.662054 | iaea-aug2022 |
468 | iaea-aug2022-rt-56.png | 0.718293 | iaea-aug2022 |
428 | iaea-aug2022-rt-15.png | 0.482117 | iaea-aug2022 |
467 | iaea-aug2022-rt-55.png | 0.358884 | iaea-aug2022 |
481 | iaea-aug2022-rt-69.png | 0.681596 | iaea-aug2022 |
440 | iaea-aug2022-rt-27.png | 0.147696 | iaea-aug2022 |
475 | iaea-aug2022-rt-63.png | 0.249902 | iaea-aug2022 |
418 | iaea-aug2022-rt-05.png | 0.469860 | iaea-aug2022 |
learn.predict('/Users/franckalbinet/pro/dev/uhina/_data/ringtrial-tfm/im/iaea-aug2022-rt-03.png')
((0.22924283146858215,), tensor([0.2292]), tensor([0.2292]))
def predict_with_tta_histogram(learn, img_path, n_tta=40):
    """Collect `n_tta` single-pass TTA predictions for one image.

    Args:
        learn: Learner whose `dls.test_dl` and `tta` are used.
        img_path: Path of the image to load with `PILImage.create`.
        n_tta: Number of independent `learn.tta(..., n=1)` calls.

    Returns:
        Tuple `(mean, std, median, all_preds)`, where `all_preds` is a NumPy
        array holding the `n_tta` individual predictions.
    """
    # Load the image
    img = PILImage.create(img_path)

    # Create a test DataLoader with a single image
    test_dl = learn.dls.test_dl([img])

    # Collect predictions
    all_preds = []
    for _ in range(n_tta):
        # Get prediction with TTA (n=1 for a single augmentation each time)
        preds, _ = learn.tta(dl=test_dl, n=1)
        all_preds.append(preds[0][0].item())  # Assuming single output

    all_preds = np.array(all_preds)

    # Calculate statistics
    mean_pred = np.mean(all_preds)
    std_pred = np.std(all_preds)
    median_pred = np.median(all_preds)

    return mean_pred, std_pred, median_pred, all_preds
# Use the function
fname = 'iaea-aug2022-rt-03.png'
img_path = Path('/Users/franckalbinet/pro/dev/uhina/_data/ringtrial-tfm/im/') / fname
mean, std, median, all_preds = predict_with_tta_histogram(learn, img_path, n_tta=30)

print(f"Mean prediction: {mean:.4f}")
print(f"Standard deviation: {std:.4f}")
print(f"Median prediction: {median:.4f}")
print(f"All predictions: {all_preds}")
# If you want to compare with the ground truth
print('Ground truth:', df[df.fname == fname]['potassium_cmolkg'].values[0])
# Plot histogram
plt.hist(all_preds, bins=10)
plt.title('Histogram of TTA Predictions')
plt.xlabel('Predicted Value')
plt.ylabel('Frequency')
plt.show()
Mean prediction: 0.2245
Standard deviation: 0.0293
Median prediction: 0.2370
All predictions: [0.2538105 0.20756826 0.16517167 0.18890977 0.23950726 0.25089669
0.23727572 0.1606092 0.23708239 0.24203241 0.24409012 0.23063052
0.22467479 0.22609089 0.21201754 0.24700734 0.24322104 0.1814348
0.23694187 0.21401702 0.24518737 0.23962407 0.24665055 0.23783752
0.23432088 0.13502732 0.24622732 0.22676304 0.24990481 0.23013265]
Ground truth: 0.29109
plt.plot(all_preds)
# Canonical fine-tuning
# from fastai.vision.all import *
# # Load the pretrained model
# learn = load_learner('./models/650-4000-epoch-25-lr-3e-3.pkl', cpu=False)
# # Prepare your new data
# path = 'path/to/your/data'
# dls = ImageDataLoaders.from_folder(path, valid_pct=0.2, item_tfms=Resize(224), batch_tfms=aug_transforms())
# # Set the new data
# learn.dls = dls
# # Fine-tune the head of the model
# learn.freeze()
# # alternatively: learn.freeze_to(n)
# learn.lr_find()
# learn.fit_one_cycle(5, 3e-3)
# # Fine-tune the entire model
# learn.unfreeze()
# learn.lr_find()
# learn.fit_one_cycle(5, slice(1e-5, 1e-3))
# learn = vision_learner(dls, resnet18, pretrained=False, metrics=R2Score()).to_fp16()
# learn.lr_find()
# learn.lr_find()
SuggestedLRs(valley=0.002511886414140463)
# learn.fit_one_cycle(5, 3e-3)
Evaluation
# Convert predictions and targets to numpy arrays
def assess_model(val_preds, val_targets):
    """Print the first predicted/actual pairs and the R2 score.

    Args:
        val_preds: Predictions; must expose `.numpy()`.
        val_targets: Targets; must expose `.numpy()`.

    Returns:
        None. Results are printed.
    """
    from sklearn.metrics import r2_score

    # Flatten both sides so that an (n, 1) array lines up with an (n,) array
    # when the DataFrame is built.
    val_preds = val_preds.numpy().flatten()
    val_targets = val_targets.numpy().flatten()

    # Create a DataFrame with the results
    results_df = pd.DataFrame({
        'Predicted': val_preds,
        'Actual': val_targets
    })
    # Display the first few rows of the results
    print(results_df.head())
    # Calculate and print the R2 score
    r2 = r2_score(val_targets, val_preds)
    print(f"R2 Score on validation set: {r2:.4f}")
dls.train.n
69
val_preds, val_targets = learn.get_preds(dl=dls.train)
assess_model(val_preds, val_targets)
Predicted Actual
0 0.046272 0.210804
1 0.528189 0.976900
2 0.465372 0.469860
3 0.258100 0.338556
4 0.112802 0.147696
R2 Score on validation set: 0.7392
val_preds, val_targets = learn.get_preds(dl=dls.train)
r2 = r2_score(val_targets, val_preds); r2
r2 = r2_score(val_targets, val_preds); r2
0.7391959435205914
# R2 score as a function of the number of TTA passes.
scores = []
for n in range(1, 20):
    val_preds, val_targets = learn.tta(dl=dls.train, n=n)
    scores.append(r2_score(val_targets, val_preds))
x = list(range(1, 20))
plt.plot(x, scores)
# EXAMPLE of TTA on single item
# from fastai.vision.all import *
# # Define your TTA transforms
# tta_tfms = [
# RandomResizedCrop(224, min_scale=0.5),
# Flip(),
# Rotate(degrees=(-15, 15)),
# Brightness(max_lighting=0.2),
# Contrast(max_lighting=0.2)
# ]
# # Create a pipeline of TTA transformations
# tta_pipeline = Pipeline(tta_tfms)
# # Load your model
# learn = load_learner('path/to/your/model.pkl')
# # Define the input data (e.g., an image)
# input_data = PILImage.create('path/to/your/image.jpg')
# # Apply TTA transforms to the input data and make predictions
# predictions = []
# for _ in range(5): # Apply 5 different augmentations
# augmented_data = tta_pipeline(input_data)
# prediction = learn.predict(augmented_data)
# predictions.append(prediction)
# # Average the predictions
# average_prediction = sum(predictions) / len(predictions)
# print(average_prediction)
# Assuming you have a new CSV file for your test data
# test_source = '../../_data/ossl-tfm/ossl-tfm-test.csv'
# test_df = pd.read_csv(test_source)
# # Create a new DataLoader for the test data
# test_dl = learn.dls.test_dl(test_df)
# # Get predictions on the test set
# test_preds, test_targets = learn.get_preds(dl=test_dl)
# # Now you can use test_preds and test_targets for further analysis
assess_model(val_preds, val_targets)
Predicted Actual
0 0.312483 0.000000
1 0.126990 0.184960
2 0.365726 0.194201
3 0.239089 0.262364
4 0.402980 0.355799
R2 Score on validation set: 0.8325
assess_model(val_preds_tta, val_targets_tta)
Predicted Actual
0 0.246857 0.000000
1 0.148590 0.184960
2 0.371643 0.194201
3 0.226535 0.262364
4 0.407333 0.355799
R2 Score on validation set: 0.8378
val_preds_np = val_preds
val_targets_np = val_targets

# Apply the transformation: exp(y) - 1
val_preds_transformed = np.exp(val_preds_np) - 1
val_targets_transformed = np.exp(val_targets_np) - 1

# Create a DataFrame with the results
results_df = pd.DataFrame({
    'Predicted': val_preds_transformed,
    'Actual': val_targets_transformed
})
# Display the first few rows of the results
print(results_df.head())
# Calculate and print the R2 score
from sklearn.metrics import r2_score
r2 = r2_score(val_targets_transformed, val_preds_transformed)
print(f"R2 Score on validation set (after transformation): {r2:.4f}")
# Calculate and print the MAPE, handling zero values
def mean_absolute_percentage_error(y_true, y_pred):
    """Return the mean absolute percentage error (in %) over non-zero targets.

    Entries where `y_true` is zero are skipped, since the relative error is
    undefined there.

    Args:
        y_true: Array-like of actual values.
        y_pred: Array-like of predicted values with the same number of elements.

    Returns:
        The MAPE in percent, or NaN when every target is zero.
    """
    # Flatten both inputs so that an (n, 1) prediction column cannot
    # broadcast against (n,) targets into an (n, n) error matrix.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    non_zero = (y_true != 0)
    if not non_zero.any():
        return float('nan')
    return np.mean(np.abs((y_true[non_zero] - y_pred[non_zero]) / y_true[non_zero])) * 100
mape = mean_absolute_percentage_error(val_targets_transformed, val_preds_transformed)
print(f"Mean Absolute Percentage Error (MAPE) on validation set: {mape:.2f}%")
# Calculate and print the MAE as an alternative metric
from sklearn.metrics import mean_absolute_error
mae = mean_absolute_error(val_targets_transformed, val_preds_transformed)
print(f"Mean Absolute Error (MAE) on validation set: {mae:.4f}")
Predicted Actual
0 0.366814 0.00000
1 0.135405 0.20317
2 0.441560 0.21434
3 0.270092 0.30000
4 0.496277 0.42732
R2 Score on validation set (after transformation): 0.6936
Mean Absolute Percentage Error (MAPE) on validation set: 50.72%
Mean Absolute Error (MAE) on validation set: 0.1956
plt.figure(figsize=(6, 6))

# Use logarithmic bins for the colormap
h = plt.hexbin(val_targets, val_preds, gridsize=65,
               bins='log', cmap='Spectral_r', mincnt=1,
               alpha=0.9)

# Get the actual min and max counts from the hexbin data
counts = h.get_array()
min_count = counts[counts > 0].min()  # Minimum non-zero count
max_count = counts.max()

# Create a logarithmic colorbar
cb = plt.colorbar(h, label='Count in bin', shrink=0.73)
tick_locations = np.logspace(np.log10(min_count), np.log10(max_count), 5)

cb.set_ticks(tick_locations)
cb.set_ticklabels([f'{int(x)}' for x in tick_locations])

# Add the diagonal line
min_val = min(val_targets.min(), val_preds.min())
max_val = max(val_targets.max(), val_preds.max())
plt.plot([min_val, max_val], [min_val, max_val], 'k--', lw=1)

# Set labels and title
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Predicted vs Actual Values (Hexbin with Log Scale)')

# Add grid lines
plt.grid(True, linestyle='--', alpha=0.65)

# Set the same limits for both axes
plt.xlim(min_val, max_val)
plt.ylim(min_val, max_val)
# Make the plot square
plt.gca().set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()
# Print the range of counts in the hexbins
print(f"Min non-zero count in hexbins: {min_count}")
print(f"Max count in hexbins: {max_count}")
Min non-zero count in hexbins: 1.0
Max count in hexbins: 157.0
path_model = Path('./models')
learn.export(path_model / '0.pkl')
Inference
ossl_source = Path('../../_data/ossl-tfm/img')
learn.predict(ossl_source / '0a0a0c647671fd3030cc13ba5432eb88.png')
((0.5229991674423218,), tensor([0.5230]), tensor([0.5230]))
df[df['fname'] == '0a0a0c647671fd3030cc13ba5432eb88.png']
fname | kex | |
---|---|---|
28867 | 0a0a0c647671fd3030cc13ba5432eb88.png | 0.525379 |
np.exp(3) - 1
19.085536923187668
Experiments:
Color scale: viridis
| Discretization: percentiles = [i for i in range(60, 100)]
Model | Image Size | Learning Rate | Epochs | R2 Score | Time per Epoch | Finetuning | with axis ticks |
---|---|---|---|---|---|---|---|
ResNet-18 | 100 | 1e-3 | 10 | 0.648 | 05:12 | No | Yes |
ResNet-18 | 224 | 2e-3 | 10 | 0.69 | 07:30 | No | Yes |
ResNet-18 | 750 (original size) | 1e-3 | 10 | 0.71 | 36:00 | No | Yes |
ResNet-18 | 224 | 2e-3 | 20 | 0.704 | 07:30 | No | Yes |
ResNet-18 | 224 | 2e-3 | 10 | 0.71 | 07:00 | No | No |
Discretization: percentiles = [i for i in range(20, 100)]
Model | Image Size | Learning Rate | Epochs | R2 Score | Time per Epoch | Finetuning | with axis ticks | colour scale |
---|---|---|---|---|---|---|---|---|
ResNet-18 | 224 | 2e-3 | 10 | 0.7 | 05:12 | No | No | viridis |
ResNet-18 | 224 | 3e-3 | 10 | 0.71 | 05:12 | No | No | jet |
From now on, `with axis ticks` is always `No`.
Discretization: estimated on 10000
cwt power percentiles [20, 30, 40, 50, 60, 70, 80, 90, 95, 97, 99]
Model | Image Size | Learning Rate | Epochs | R2 Score | Time per Epoch | Finetuning | remark | colour scale |
---|---|---|---|---|---|---|---|---|
ResNet-18 | 224 | 2e-3 | 10 | 0.71 | 05:12 | No | None | jet |
ResNet-18 | 224 | 2e-3 | 10 | 0.685 | 05:12 | No | y range added | jet |
From now on, a random splitter with 10% validation and random seed 41 is used.
Discretization: estimated on 10000
cwt power percentiles [20, 30, 40, 50, 60, 70, 80, 90, 95, 97, 99]
Model | Image Size | Learning Rate | Epochs | R2 Score | Time per Epoch | Finetuning | remark | colour scale |
---|---|---|---|---|---|---|---|---|
ResNet-18 | 224 | 2e-3 | 10 | 0.7 | 05:12 | No | Pre-train & normalize: True | jet |
ResNet-18 | 224 | 2e-3 | 10 | 0.796 | 08:12 | No | No Pre-train | jet |
ResNet-18 | 224 | 3e-3 | 10 | 0.7 | 05:12 | No | Pre-train & normalize: False | jet |
ResNet-18 (id=0) | 224 | 2e-3 | 20 | 0.829 | 08:12 | No | No Pre-train (try 18 epochs) | jet |