MNIST - Data Augmentation Gone Wrong
What happens to MNIST accuracy when the input data is horizontally flipped?
#collapse-hide
# Setup: download MNIST, list the training images, and define the flip transform.
from fastai2.vision.all import *
from utils import *
path = untar_data(URLs.MNIST)  # downloads/extracts MNIST once; returns the local Path
train_dir = path/'training'
#val_dir = path/'testing'
fns_train = get_image_files(train_dir)  # recursive listing of image files under training/
#fns_val = get_image_files(val_dir)
print('train files: ', len(fns_train))
#print('val files: ', len(fns_val))
batch_tfms = [Flip(p=1)] # horizontal flip; p=1 so every batch is flipped, not randomly
def _mnist_datablock(tfms):
    # Shared MNIST DataBlock config; only the batch transforms differ
    # between the baseline and the flipped variant.
    return DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        splitter=RandomSplitter(valid_pct=0.2, seed=42),
        get_y=parent_label,
        batch_tfms=tfms,
    )

db = _mnist_datablock(None)             # baseline: no augmentation
db_flip = _mnist_datablock(batch_tfms)  # every image horizontally flipped
# Build dataloaders for both pipelines and preview a row of samples from each.
# Note both use the same seed-42 random split, so train/valid membership matches.
dls = db.dataloaders(train_dir, bs=256)
dls_flip = db_flip.dataloaders(train_dir, bs=256)
dls.show_batch(ncols=5,nrows=1)       # baseline digits, unmodified
dls_flip.show_batch(ncols=5,nrows=1)  # same pipeline but horizontally mirrored
# Baseline: train a resnet18 from scratch (no pretrained weights) on
# unmodified MNIST, using accuracy as the tracked metric.
learn = cnn_learner(dls, resnet18,
                    pretrained=False,
                    metrics=accuracy)
# lr_find returns learning-rate suggestions; index 0 is the
# minimum-loss-based suggestion in fastai2.
lr_min = learn.lr_find()[0]
# print() instead of a bare f-string expression: a bare expression only
# displays inside a notebook and is a silent no-op when run as a script.
print(f'lr_min: {lr_min:0.05f}')
# no horizontal flip
learn.fit_one_cycle(5, lr_min)
- With baseline MNIST, resnet18 is getting 99% accuracy
- Note: train_loss and valid_loss are both low
# Same architecture and schedule as the baseline, but fed the
# horizontally-flipped dataloaders so the two runs are directly comparable.
learn = cnn_learner(dls_flip, resnet18,
                    pretrained=False,
                    metrics=accuracy)
lr_min = learn.lr_find()[0]
# print() instead of a bare f-string expression: a bare expression only
# displays inside a notebook and is a silent no-op when run as a script.
print(f'lr_min: {lr_min:0.05f}')
# yes horizontal flip
learn.fit_one_cycle(5, lr_min)
- With horizontally flipped numbers, accuracy dropped to ~41%
- Note: train_loss is much lower than valid_loss -> overfitting
# Inspect what the flipped-input model got wrong: build an interpretation
# object from the last-trained learner and show the 4 highest-loss samples.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_top_losses(4, nrows=1)
- Model is predicting a 5 when seeing a 2
interp.plot_confusion_matrix()  # rows = actual digit, columns = predicted digit
- Model still predicts 0, 1, 4, and 8 correctly (the roughly symmetric digits) -> ~40% accuracy
- Model confuses 5 for 2 | 6 for 2 | 3 for 8 | 9 for 8
- Does this make sense? A mirrored 5 resembles a 2, and a mirrored 3 resembles an 8.
interp.most_confused()[:5]  # five most frequent (actual, predicted, count) pairs
# top number is actual, bottom number is prediction
learn.show_results(max_n=12)