11. Transfer Learning for Domain-Specific Image Classification with Small Datasets (2019)


Hello everybody, my name is Chris, and today we're going to talk about transfer learning. I think transfer learning is really cool because it lets you take a small dataset and still build a really accurate model. We do this by leveraging very large networks that were trained for many hours or even days on much larger datasets than ours, and transferring that knowledge into a network built specifically for our classification problem. Today we're working with the Freiburg Groceries dataset, a small dataset of about four thousand images of various grocery products, and we want to train a classifier to tell us what type of grocery product each image shows. So let's take a look at the data.

First, we import our Keras model layers so we can build a Keras model, and we also import ResNet50, a very large image classification network out of Microsoft Research whose knowledge we'll be able to transfer into our own network.
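The exact code isn't shown in this post, but the imports described above look roughly like this (a sketch using the tf.keras paths; the original notebook may import from standalone Keras instead):

```python
# Sketch of the imports described above; paths assume tf.keras.
import numpy as np
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.resnet50 import (
    ResNet50, preprocess_input, decode_predictions)
```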
Next, we need to load our training data. Here we split the data into a train set and a test set, and extract the class names using a small utility library called groceries. Let's take a look at one of those images: there you go, a jar of pickles. Hopefully we can train our machine to tell us that. Looking at the other classes in the dataset, we've got beans, cake, pasta, and my favorite, vinegar. Alright, let's see how the data is actually distributed. As you can see, some of the classes don't have nearly as many examples as others; hopefully transfer learning can help compensate for this.
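One way to check that class balance yourself, assuming y_train is an array of integer labels and class_names is the list of class strings from the groceries utility:

```python
import numpy as np
import matplotlib.pyplot as plt

# Count how many examples each class has (y_train is assumed to hold integer labels)
counts = np.bincount(y_train, minlength=len(class_names))

plt.figure(figsize=(10, 4))
plt.bar(class_names, counts)
plt.xticks(rotation=90)
plt.ylabel("number of images")
plt.tight_layout()
plt.show()
```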
Before we can train a model, we need to convert our categories, which are numbers between 0 and 25, into one-hot encoded vectors, so we call to_categorical on our labels.
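Something like the following, assuming y_train and y_test hold the integer labels:

```python
from tensorflow.keras.utils import to_categorical

num_classes = len(class_names)               # 25 classes in the Freiburg set
y_train_onehot = to_categorical(y_train, num_classes)
y_test_onehot = to_categorical(y_test, num_classes)
```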
Now, just to see how we can do on this dataset with a very simple model, let's normalize our data and create a single-layer perceptron. We use categorical cross-entropy for our loss, because this is a multi-class classification problem, along with our good old friend the Adam optimizer, and we also track accuracy so we have a more interpretable metric of what's going on. Lastly, we call wandb.init so we can visualize our metrics. Let's go ahead and train this model.
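A sketch of that baseline, assuming X_train and X_test hold the raw images and the variable names from the earlier snippets are defined (the project name is illustrative):

```python
import wandb
from wandb.keras import WandbCallback
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense

# Normalize pixel values to [0, 1]
X_train_norm = X_train.astype("float32") / 255.0
X_test_norm = X_test.astype("float32") / 255.0

wandb.init(project="groceries-transfer-learning")  # illustrative project name

# A single-layer "perceptron": flatten the image, then one softmax layer
model = Sequential([
    Flatten(input_shape=X_train_norm.shape[1:]),
    Dense(num_classes, activation="softmax"),
])
model.compile(loss="categorical_crossentropy",
              optimizer="adam",
              metrics=["accuracy"])
model.fit(X_train_norm, y_train_onehot,
          validation_data=(X_test_norm, y_test_onehot),
          epochs=10,                      # illustrative
          callbacks=[WandbCallback()])
```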
Okay, it looks like we aren't doing so well: our validation accuracy is 0.04 percent. This is very troubling, and our accuracy on the training data is even lower. I look at this and I feel ill.
There has to be a better way. Keras makes it really easy to leverage the research community's progress in computer vision, so here we import ResNet50 and download weights pre-trained on ImageNet, an image dataset with millions of images that takes many days to train on. With this one line we're pulling in cutting-edge computer vision research. Let's take a look at a model summary to see what this network looks like.
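That one line looks roughly like this:

```python
from tensorflow.keras.applications.resnet50 import ResNet50

# Downloads the ImageNet-pretrained weights the first time it runs
resnet = ResNet50(weights="imagenet")
resnet.summary()
```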
Oh man, so many layers! ResNet50 is much more complicated than our simple perceptron. You can see things like batch normalization, many different convolutions, and even this funny "add" layer. What ResNet does is branch off, take features from earlier in the network, and add them back in at later layers. This helps the network train better and lets researchers build an even deeper network, which gives it more expressive power and accuracy. I could just keep on scrolling.
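To make those "add" layers concrete, here is a simplified residual block in Keras. This is not ResNet-50's exact bottleneck block, just the idea of a shortcut connection:

```python
from tensorflow.keras import layers

def residual_block(x, filters):
    # Keep a shortcut to the incoming features
    # (assumes x already has `filters` channels so the shapes line up).
    shortcut = x
    # Push the features through a couple of conv + batch-norm layers...
    y = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    y = layers.BatchNormalization()(y)
    y = layers.Conv2D(filters, 3, padding="same")(y)
    y = layers.BatchNormalization()(y)
    # ...then add the shortcut back in: this is the "add" layer in the summary.
    y = layers.Add()([shortcut, y])
    return layers.Activation("relu")(y)
```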
To see what this network can actually do, let's run it on a picture of an elephant, because why not. We load our elephant image and resize it to 224 by 224 pixels, because that's the input size the network expects. Then we expand the dimensions, because we need to include a batch dimension, and we call this really important function, preprocess_input: when the researchers trained ResNet they preprocessed the images in a very specific way, and we apply their exact same logic to our own data so that we get accurate results out of the model. Lastly, we call predict and use the nice helper method decode_predictions, which translates the indices of the last layer into the categories the network is predicting.
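Roughly, assuming the photo is saved as elephant.jpg (the file name is illustrative):

```python
import numpy as np
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions

img = image.load_img("elephant.jpg", target_size=(224, 224))  # resize to what ResNet expects
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)   # add the batch dimension
x = preprocess_input(x)         # the exact preprocessing the ResNet authors used

preds = resnet.predict(x)
print(decode_predictions(preds, top=3)[0])
```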
Look at that: the network outputs "tusker" with 49 percent confidence, "Indian elephant" with 34 percent, and a slight chance that we're looking at an African elephant. I personally probably couldn't tell you the difference between these three kinds of elephants, but a network this powerful can do it with a high degree of confidence. But we don't want this network to tell us the categories it was trained on; we want it to tell us our categories for our grocery dataset, so let's look at a way we can actually do that.
First, we take our grocery dataset and preprocess it exactly the same way the ResNet authors did. Now we can go into the ResNet model and pull out specific layers that we want to use: in this case, the second-to-last layer, which is called the average pool layer. Then we create a new model with the same input as our ResNet model, but instead of outputting a thousand categories, it outputs that average pool layer.
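A sketch of that feature extractor, assuming the grocery images have already been resized to 224 by 224:

```python
from tensorflow.keras.models import Model
from tensorflow.keras.applications.resnet50 import preprocess_input

# Preprocess the grocery images the same way the ResNet authors did
X_train_pre = preprocess_input(X_train.astype("float32"))
X_test_pre = preprocess_input(X_test.astype("float32"))

# New model: same input as ResNet, but its output is the second-to-last
# ("avg_pool") layer instead of the 1000-way ImageNet classifier.
feature_extractor = Model(inputs=resnet.input,
                          outputs=resnet.get_layer("avg_pool").output)
feature_extractor.summary()
```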
Let's take a look at what this model actually looks like. It's still a massive model, but now, instead of a thousand categories at the bottom, we have a vector of length 2048, which we hope contains the most important features of our data. Now we can take our preprocessed grocery dataset, run it through this new model, and extract the features. We're transforming our images into 2048-length vectors of numbers that we can train a new model on, and we hope ResNet has created features that are much easier to learn from than our original image data. We do the same for our test data, and then finally we create a new model, a simple perceptron again, with 25 categories for our dataset, using the same loss and optimizer as before. Let's go ahead and fit it and see if we can get better accuracy than our first try.
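Extracting the feature vectors and training a new perceptron on them might look like this (again a sketch; variable names carry over from the earlier snippets):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Run every image through the truncated ResNet once and cache the features
train_features = feature_extractor.predict(X_train_pre)   # roughly shape (n_train, 2048)
test_features = feature_extractor.predict(X_test_pre)

# The same simple softmax classifier as before, but trained on ResNet features
feature_model = Sequential([
    Dense(num_classes, activation="softmax", input_shape=(train_features.shape[1],)),
])
feature_model.compile(loss="categorical_crossentropy",
                      optimizer="adam",
                      metrics=["accuracy"])
feature_model.fit(train_features, y_train_onehot,
                  validation_data=(test_features, y_test_onehot),
                  epochs=10,
                  callbacks=[WandbCallback()])
```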
Look at that: right off the bat we're getting around 80% validation accuracy. You might also notice that we have a bit of an overfitting problem, but there are additional techniques we can use to make sure the network generalizes well across our dataset and fix that issue.

Just extracting the features is great because it makes our model train really fast, but it has a disadvantage: if we deploy this model, we have to deploy two models side by side, always pushing our input imagery through all of ResNet and then separately passing that output into the next model. Keras makes it really easy to build a single model where the output of the ResNet layers goes directly into our perceptron. We do this by creating a new model, adding our ResNet layers and our new final dense layer, and then setting trainable to False on all of the layers in the ResNet portion: while training this network we don't want any of the ResNet layers to update, we just want to tune the weights in our final dense layer.
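One way to wire this up (a sketch; the exact construction in the video isn't shown, and using include_top=False with pooling="avg" is one convenient way to get the same 2048-length pooled output):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.applications.resnet50 import ResNet50

# Headless ResNet: no 1000-way ImageNet classifier, pooled down to a 2048 vector
resnet_base = ResNet50(weights="imagenet", include_top=False, pooling="avg")
for layer in resnet_base.layers:
    layer.trainable = False            # freeze every ResNet layer

combined = Sequential([
    resnet_base,
    Dense(num_classes, activation="softmax"),
])
combined.compile(loss="categorical_crossentropy",
                 optimizer="adam",
                 metrics=["accuracy"])
combined.summary()
```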
In the summary you can now see there are about 23 million parameters in this network, but only about 51,000 of them are trainable. If we run training, you'll see it takes a lot longer: every batch, we pass the data all the way through the ResNet network and do all of those convolutions and other arithmetic, so it's much slower than using the cached output features we used before. The advantage is that you now have a single model that you can keep retraining as your dataset grows or your labels change, and it's much easier to deploy. And as you can see, we get essentially the same accuracy as we got by just extracting the features and training the last layer.
Alright, there's one more technique we can use with transfer learning that can give us even more accuracy: fine-tuning. Instead of only training the layers we added at the end of the network, we can take a subset of the layers in the ResNet network and allow them to train as well. The reason this works is that the layers near the input of the network tend to extract generic features, simple shapes and edges that are shared across pretty much any set of classes, whereas the layers deeper in the network tend to be much more specific to the classes the model was trained on, in this case the classes ResNet was trained on. So we can take those final layers and fine-tune them, letting their weights change so they're better suited to our classes, while the very generic earlier layers still pass the most meaningful information down to our new classifier.

To do this, we set the ResNet layers to be trainable again, and then go into the network and, in this case, allow only the final 11 layers out of the hundred or so in ResNet to train, keeping all of the earlier layers frozen.

One thing to note when fine-tuning: because these weights were trained on a very large dataset and are quite specific to it, when we start to move them we'll want to do so much more slowly than we would in a normal network. So instead of just setting the optimizer to "adam", we instantiate a new instance of the optimizer and slow down the learning rate. We really want to nudge the weights in those last layers a little at a time, and this helps prevent overfitting on our dataset, which can happen easily given how small it is and how many parameters we have.
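A sketch of that setup, with an illustrative learning rate and epoch count:

```python
from wandb.keras import WandbCallback
from tensorflow.keras.optimizers import Adam

# Un-freeze the ResNet base, then re-freeze everything except the last 11 layers
for layer in resnet_base.layers:
    layer.trainable = True
for layer in resnet_base.layers[:-11]:
    layer.trainable = False

# Recompile with a much smaller learning rate so the pretrained weights
# only move a little at a time (Adam's default is 1e-3; 1e-5 is illustrative)
combined.compile(loss="categorical_crossentropy",
                 optimizer=Adam(learning_rate=1e-5),
                 metrics=["accuracy"])
combined.fit(X_train_pre, y_train_onehot,
             validation_data=(X_test_pre, y_test_onehot),
             epochs=5,                        # illustrative
             callbacks=[WandbCallback()])
```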
And look at that: with only a few lines of code, we were able to leverage cutting-edge models to get 72 percent accuracy on our Freiburg Groceries dataset. Remember, we started at less than 1% accuracy. I'd say that's a pretty good day at the office!

2 Comments

  • Vassilios Moustakidis

    Hey WANDB guys, amazing videos!!! What's the reason you chose ResNet50 instead of, for instance, Inception-v3 or v4? Is it because ResNet requires a lower number of operations (G-Ops)?

    In a Stanford paper/presentation (Fei-Fei Li, Justin Johnson, and Serena Yeung) I found a comparison (slides 84-90) of different image CNNs, so I was curious what the reason is. Thanks
