Training models: tips and tricks

At this stage, you hopefully have your data organized and curated so that it is ready to be used to train a deep learning model. In this episode of the BENDER Series, we go over some tips and tricks to make the process of training, and more importantly of logging metrics and intermediate results, manageable.


Initial version:

We start with the notebook DermaMNIST Initial Version to demonstrate our recommended process. If you prefer a Python script, consider opening this. This is admittedly a first attempt, and there are many improvements we can make.
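
To give a sense of the moving parts before you open the notebook, here is a minimal sketch of what such an initial setup might look like. This is illustrative rather than the notebook's exact code: the tiny model and the data loading details are placeholders, assuming the medmnist and torchvision packages are installed.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist import DermaMNIST

# Load the training split of DermaMNIST (28x28 RGB images, 7 categories).
train_set = DermaMNIST(split="train", transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

# Placeholder model: a throwaway CNN, not the architecture used in the notebook.
model = nn.Sequential(
    nn.Conv2d(3, 32, (3, 3), padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, (3, 3), padding=1, stride=2), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(64 * 14 * 14, 7),
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.000005, momentum=0.5)

# One pass over the training data; the notebook trains for many more iterations
# and periodically evaluates the validation accuracy.
for images, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(images), labels.squeeze(1).long())  # MedMNIST labels have shape (N, 1)
    loss.backward()
    optimizer.step()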

We invite you to make a copy of the notebook, and then make changes to it (you could simply copy changes from the scripts we point to) as we make progress in the versions below.

Model is not really learning all categories

Note from the image above that only the melanocytic nevi category is being learned by the model; since it has the largest representation in both the training and the validation/test sets, the weighted average accuracy is quite high even though all the other categories have 0 accuracy.

training loss for initial version

Training loss for the initial version: it seems to reduce with increasing iterations, but then flattens out. Once it flattens out, it is not very useful to train for more iterations, as the accuracy also flattens out. We will see in the third version how this wasteful training can be avoided using 'validation patience'.

validation accuracy for initial version

Validation accuracy for initial version: even if we stopped at 10000 iterations, we would end up with the same validation accuracy as we do after 80000 iterations.


Second version:

For this version, we modify only the momentum hyperparameter, keeping everything else the same. The difference between v1 and v2 is just the following line:

optimizer = torch.optim.SGD(model.parameters(), lr=0.000005, momentum=0.5)

changed to

optimizer = torch.optim.SGD(model.parameters(), lr=0.000005, momentum=0.9)

Make this change in your notebook, and verify that you see results similar to ours for the training loss and validation accuracy.

training loss after momentum change

Training loss after changing momentum from 0.5 to 0.9: note that this is still as noisy as v1, but the loss reduces faster, dropping below 1.0 in under 10000 iterations compared to about 20000 earlier.

validation accuracy after momentum change

Validation accuracy after changing momentum from 0.5 to 0.9: there is not much change here; the accuracy remains in the same range, indicating that we should explore other hyperparameters, or even change the model itself.


Third version:

In the spirit of tweaking hyperparameters, we make another change here: we increase the learning rate (one of the most sensitive, and hence most tweaked, hyperparameters) so that the network weights can learn 'faster'. The functional difference between v2 and v3 is just the following line:

optimizer = torch.optim.SGD(model.parameters(), lr=0.000005, momentum=0.9)

changed to

optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

Additionally, we make another modification which doesn't change the training algorithm per se, but adds a check on whether training is still making progress. Each time we evaluate the validation accuracy, we track whether it is higher or lower than the previous best. If the accuracy is lower than or equal to the previous best for a fixed number of consecutive evaluations, we conclude that the network is not learning anything new, and we stop training. This fixed number is called the 'validation patience', and it helps avoid situations where the validation accuracy is flat (like in the two versions above), or worse, where the network overfits (and the validation accuracy may fall).

To make this change in your notebook, look for the patience parameter in this script, and include all the lines that contain it in your training function.
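
The patience logic itself is short. Here is a sketch of the idea as a self-contained training function; the variable names, the epoch-level granularity of the check, and the checkpoint filename are illustrative choices and not necessarily those used in the script.

import torch

def train_with_patience(model, optimizer, criterion, train_loader, val_loader,
                        max_epochs=100, patience=10, device="cpu"):
    # Stop training once validation accuracy has not improved for `patience` checks.
    best_val_acc = 0.0
    checks_without_improvement = 0
    for epoch in range(max_epochs):
        model.train()
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.squeeze(1).long().to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # Evaluate validation accuracy (here once per epoch).
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                preds = model(images.to(device)).argmax(dim=1)
                labels = labels.squeeze(1).long().to(device)
                correct += (preds == labels).sum().item()
                total += labels.shape[0]
        val_acc = correct / total

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checks_without_improvement = 0
            torch.save(model.state_dict(), "best_model.pt")  # keep the best model so far
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                break  # validation accuracy has stopped improving
    return best_val_acc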

training loss after change in learning rate

Training loss after changing the learning rate from 0.000005 to 0.005. Note that the training loss appears very noisy, but is lower than in the previous two versions.

validation accuracy after learning rate change

Validation accuracy after changing the learning rate from 0.000005 to 0.005. This already shows a significant improvement compared to the previous two versions: the accuracy breaches the 0.75 level, and the hope is that the test accuracy is similar as well (so that we verify the generalization capability rather than overfitting to the validation data ;-)).


Fourth version:

For this version (v4), we make a minor change in the optimizer: we switch from SGD (Stochastic Gradient Descent) to Adam (Adaptive Moment Estimation), which is known to be 'safer' (see here for more) and less sensitive to hyperparameter variations.

Hence, functionally, the only change here is the following line:

optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

changed to

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

We also introduce TensorBoard to log the training progress for the same quantities: the training (and now also validation) loss, as well as the validation (and now also training) accuracy.
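
If you haven't used TensorBoard with PyTorch before, the logging boils down to a SummaryWriter and a few add_scalar calls inside the training loop. The tag names and log directory below are illustrative, not necessarily those used in the notebook.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/dermamnist_v4")  # illustrative log directory

# Inside the training loop, log scalars against the iteration count, e.g.:
#   writer.add_scalar("Loss/train", train_loss, iteration)
#   writer.add_scalar("Loss/validation", val_loss, iteration)
#   writer.add_scalar("Accuracy/train", train_acc, iteration)
#   writer.add_scalar("Accuracy/validation", val_acc, iteration)

writer.close()

# Then browse the curves with:  tensorboard --logdir runs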

Since the code changes are relatively large compared to the previous versions, we have prepared a new notebook to start off from again, if you prefer. Run this notebook on your own to verify that you see results confirming an improvement over the previous versions, and also browse through the TensorBoard logs (as shown in the videos) to get a better handle on the performance of this model.

Training and Validation losses and accuracies

Here are the training (left) and validation (right) accuracies (top) and losses (bottom) as shown in TensorBoard. The curves are smoothed for better visualization (smoothing = 0.898). Note how the training loss keeps falling and the training accuracy keeps rising, while the validation loss plateaus and even rises, and the validation accuracy plateaus and then falls slightly. Our code is set up to save the model with the highest validation accuracy, but this indicates that our model has overfit to the training data; the next versions attempt to handle just that.

Test accuracy and classification report

The test accuracy for our model is 0.762, and as you can see, all the categories (not just melanocytic nevi) have non-zero precision, which is an improvement over the first naive version! Also, 0.762 is already the second highest among the DermaMNIST benchmarks on the MedMNIST webpage!
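
If you want to reproduce such a per-category report for your own run, one way (not necessarily how the notebook computes it) is scikit-learn's classification_report, with the category names taken from the medmnist metadata:

import torch
from medmnist import INFO
from sklearn.metrics import classification_report

def print_test_report(model, test_loader, device="cpu"):
    # Collect predictions over the test set and print per-category precision/recall/F1.
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            preds = model(images.to(device)).argmax(dim=1).cpu()
            y_pred += preds.tolist()
            y_true += labels.squeeze(1).long().tolist()
    label_dict = INFO["dermamnist"]["label"]
    names = [label_dict[str(i)] for i in range(len(label_dict))]
    print(classification_report(y_true, y_pred,
                                labels=list(range(len(names))),
                                target_names=names, zero_division=0))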


Fifth version:

In this version, to try to avoid overfitting to the training data, we increase the capacity of our model by adding more layers, with the intention of allowing the model to be more expressive and hence learn the nuances of the actual category distribution (rather than simply memorizing the training data).

To do so, v5 includes just the following change in the model's __init__ method:

...
nn.Conv2d(64, 128, (3, 3), padding=1, stride=2, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(),
# (128, 16, 16)
nn.Conv2d(128, 128, (3, 3), padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(),
# (128, 16, 16)
nn.Conv2d(128, 128, (3, 3), padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(),
# (128, 8, 8)
...

Our network is now 6 layers deep (each 'layer' here means convolution + batch norm + ReLU). Everything else is kept the same, and we continue to use TensorBoard to log and compare the results.

Training and Validation losses and accuracies

In the image above, the blue curves correspond to version 4, and the gray curves to the current version. Note that both the training and validation losses are lower with this deeper network, and the training accuracy rises faster, while the validation accuracy still tracks close to version 4. This suggests that we could make the model even more capable/powerful to try and capture this distribution.

Test accuracy and classification report

For this version, the test accuracy is in fact lower than the previous one, at 0.755 (still competitive in the benchmark!). Note also that the dermatofibroma metrics are all 0, indicating that this model completely ignores this category yet achieves good average results. This is a good reason to inspect the entire confusion matrix to make sure such behavior does not go unnoticed!


Sixth version:

This version again changes only the network, making it even deeper, with 8 layers now. Everything else is left the same, so that we are able to investigate the impact of each of these individual changes. The code change is again only in the __init__ method of the network definition, where the following lines are included (see the full changes between v5 and v6).

...
nn.ReLU(),
# (128, 16, 16)
nn.Conv2d(128, 256, (3, 3), padding=1, stride=2, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
# (256, 8, 8)
nn.Conv2d(256, 256, (3, 3), padding=1, bias=False),
nn.BatchNorm2d(256),
...

Training and Validation losses and accuracies

In the image above, the blue curves correspond to v4, the pink curves to v5, and the green ones to the latest version.

Even though the losses for both training and validation are lower than in the previous versions, the overfitting phenomenon still persists (see the validation accuracy curve). This points to the need for regularization: making it harder for the network to memorize, and providing more varied input data to learn from.

Test accuracy and classification report

For this version, the test accuracy is again lower than the previous one, pointing to the same issues as earlier. In the final version below, we introduce data augmentation to find out if it can resolve these problems.


Seventh version:

Finally, we attempt to address the overfitting problem by including some regularization (a method to better condition the model training) in the form of data augmentation. As with previous versions, the only change between v6 and this version is in the load_datasets helper function, where the training data loader uses the following transforms:

training_transform_medmnist = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Pad(2),
        transforms.RandomCrop(
            size=(32, 32), padding=(0, 0, 5, 5), padding_mode="reflect"
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
    ]
)
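
These augmentations are applied to the training split only; the validation and test splits keep a deterministic transform so that evaluation data is not augmented. One plausible way to wire this up (assuming the evaluation images are padded to the same 32x32 size the network now sees during training) is:

from torchvision import transforms
from medmnist import DermaMNIST

# Deterministic transform for validation/test: no random cropping or flipping.
eval_transform_medmnist = transforms.Compose([transforms.ToTensor(), transforms.Pad(2)])

train_set = DermaMNIST(split="train", transform=training_transform_medmnist, download=True)
val_set = DermaMNIST(split="val", transform=eval_transform_medmnist, download=True)
test_set = DermaMNIST(split="test", transform=eval_transform_medmnist, download=True)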

Training and Validation losses and accuracies

In the image above, the blue curves correspond to v4 (small network), the cyan curves to v5 (deeper network), the pink curves to v6 (even deeper network), and finally the green curves to the current version.

Note here how the training accuracy appears to be the lowest amongst all the variants, but the validation accuracy does not drop off with more iterations: this is precisely how overfitting is avoided!

Test accuracy and classification report

For this version, we note that the test accuracy is now 0.770, higher than all the benchmarks listed on the MedMNIST webpage! Note also that the dermatofibroma category no longer has all-zero metrics, and nearly all categories have higher precision than in all previous versions.

We hope you had as much fun following along on this journey as we did, and that you will experiment with these methods to explore how to better train your deep neural networks!


References

  • For more general tips and tricks around model training (general because it isn't specific to medical imaging), Andrej Karpathy's recipe from 2019 is highly recommended. The content of this episode is an extension of it, with a focus on medical image datasets.

  • If you're inclined to use MONAI, consider following this tutorial. It follows an older version of the MedMNIST dataset, and uses MONAI to load the data and build models more easily.

  • For a beginner's guide to image segmentation, covering the complete workflow using MRI and CT data, see this paper from Dr. Leticia Rittner's group at UNICAMP. They have associated code on GitHub to follow along as well.

  • This integrated medical image visualization tool for Jupyter notebooks, itkwidgets, together with this getting-started guide with MONAI, could be very useful!

  • For medical image visualization within TensorBoard ("tensorboard3d"), this nice plugin developed by Kitware could be super useful for volumetric data.

  • See this blog post for more great tips on training models. It was a previous entry to the MICCAI Education Challenge as well!

  • MMAR is a Medical Model Archive designed by NVIDIA to organize all artifacts produced during the model development life cycle. This may be too heavyweight for research prototypes that one may begin with, but as things become more stable, such standards (much like BIDS for data storage) may help avoid further pains down the line.

  • Consider following the MICCAI Hackathon reproducibility checklist to ensure that your pipeline is not too exotic, and future researchers can build on your work!

For questions/suggestions for improvements, please create an issue in the BENDER repository.