1) Import libraries and set up file paths

#collapse-hide
from fastai2.vision.all import *
from utils import *

path = untar_data(URLs.MNIST)

train_dir = path/'training'
#val_dir = path/'testing'
fns_train = get_image_files(train_dir)
#fns_val = get_image_files(val_dir)

print('train files: ', len(fns_train))
#print('val files: ', len(fns_val))

train files:  60000
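
As a quick sanity check (a sketch, assuming the standard folder layout that untar_data produces for MNIST), the training directory holds one subdirectory per digit, which is what parent_label relies on later:

# one subdirectory per digit class (0-9)
print(train_dir.ls().sorted())
# a single example path: its parent folder name is the label
print(fns_train[0])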

2) Set up two dataloaders: baseline and horizontal flip

batch_tfms = [Flip(p=1)]  # always apply a horizontal flip

db = DataBlock(
    blocks = (ImageBlock, CategoryBlock), 
    get_items = get_image_files, 
    splitter = RandomSplitter(valid_pct=0.2, seed=42),
    get_y = parent_label,
    batch_tfms = None
)

db_flip = DataBlock(
    blocks = (ImageBlock, CategoryBlock), 
    get_items = get_image_files, 
    splitter = RandomSplitter(valid_pct=0.2, seed=42),
    get_y = parent_label,
    batch_tfms = batch_tfms
)

dls = db.dataloaders(train_dir, bs=256)
dls_flip = db_flip.dataloaders(train_dir, bs=256)
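
To confirm the 80/20 split before training, a quick sketch (the exact counts assume all 60,000 training images were picked up):

# RandomSplitter(valid_pct=0.2) should leave roughly 48,000 train and 12,000 validation items
print(len(dls.train_ds), len(dls.valid_ds))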

3) Check that each dataloader is working

dls.show_batch(ncols=5, nrows=1)
dls_flip.show_batch(ncols=5, nrows=1)
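
show_batch draws random samples, so the two dataloaders can look similar at a glance. A handy check (a sketch using show_batch's unique flag) is to repeat a single image so the effect of the flip stands out:

# unique=True repeats the same image, so any augmentation is easy to spot
dls_flip.train.show_batch(max_n=5, nrows=1, unique=True)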

4) Train resnet18 on the baseline and check accuracy

learn = cnn_learner(dls, resnet18,
                    pretrained=False,
                    metrics=accuracy)
lr_min = learn.lr_find()[0]
f'lr_min: {lr_min:0.05f}'
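
For reference, lr_find in this version of fastai returns a SuggestedLRs tuple, so indexing [0] picks lr_min (the minimum-loss learning rate divided by 10). A sketch that keeps both suggestions:

# lr_find returns (lr_min, lr_steep); either is a reasonable starting point
lr_min, lr_steep = learn.lr_find()
print(f'lr_min: {lr_min:0.05f}, lr_steep: {lr_steep:0.05f}')
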
# no horizontal flip
learn.fit_one_cycle(5, lr_min)
epoch  train_loss  valid_loss  accuracy  time
0      0.207485    0.554384    0.887250  00:08
1      0.106351    0.065353    0.983833  00:08
2      0.073446    0.105541    0.972750  00:08
3      0.039243    0.042821    0.988333  00:08
4      0.017257    0.026517    0.992750  00:08
  • On baseline MNIST, resnet18 reaches ~99% accuracy (quick check below)
  • Note: train_loss and valid_loss are both low
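
To read the final numbers off programmatically rather than from the table, Learner.validate runs one pass over the validation set and returns the loss plus each metric (a quick sketch):

# returns [valid_loss, accuracy] for the trained model
valid_loss, acc = learn.validate()
print(f'baseline accuracy: {acc:0.4f}')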

5) Train a new resnet18 on the horizontally flipped dataset

learn = cnn_learner(dls_flip, resnet18,
                    pretrained=False,
                    metrics=accuracy)
lr_min = learn.lr_find()[0]
f'lr_min: {lr_min:0.05f}'
# with horizontal flip
learn.fit_one_cycle(5, lr_min)
epoch  train_loss  valid_loss  accuracy  time
0      0.215689    6.608447    0.339833  00:09
1      0.098538    3.857758    0.394500  00:09
2      0.076361    3.930472    0.387083  00:09
3      0.040679    3.130253    0.437750  00:09
4      0.015541    3.993914    0.410417  00:09
  • With horizontally flipped digits, accuracy dropped to ~41%
  • Note: train_loss is a lot lower than valid_loss -> overfitting (see the check below)
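
One detail worth verifying here (sketched below, resting on an assumption about fastai defaults): Flip is a RandTransform, and RandTransforms default to split_idx=0, i.e. batch_tfms augment only the training split. If that holds, the model trains on mirrored digits but is validated on unflipped ones.

# check the assumption that the flip is applied only to the training split
print(batch_tfms[0].split_idx)               # expected: 0 -> training set only
dls_flip.valid.show_batch(max_n=5, nrows=1)  # validation digits should appear unflipped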

6) What happened?

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_top_losses(4, nrows=1)
  • The model predicts a 5 when it sees a 2
interp.plot_confusion_matrix()
  • The model predicts 0, 1, 4, and 8 correctly -> roughly 40% accuracy (per-class sketch below)
  • It confuses 5 for 2 | 6 for 2 | 3 for 8 | 9 for 8
  • Does this make sense?
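
To put numbers on the "0, 1, 4, and 8" observation, a per-class accuracy sketch computed from the confusion matrix:

# per-class accuracy = diagonal of the confusion matrix / row totals
cm = interp.confusion_matrix()
per_class = cm.diagonal() / cm.sum(axis=1)
for digit, acc in zip(dls_flip.vocab, per_class):
    print(digit, f'{acc:0.2f}')
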
interp.most_confused()[:5]
[('6', '2', 969),
 ('9', '8', 935),
 ('3', '8', 855),
 ('5', '2', 823),
 ('2', '5', 810)]
# in the show_results grid, the top label is the actual digit, the bottom is the prediction
learn.show_results(max_n=12)