How to fine-tune a keras model with existing plus newer classes?

Submitted by 南笙酒味 on 2021-01-27 14:03:10

Question


Good day!

I have a celebrity dataset on which I want to fine-tune a Keras built-in model. From what I have explored and done so far, we remove the top layers of the original model (or, preferably, pass include_top=False), add our own layers, and then train the newly added layers while keeping the previous layers frozen. This whole process is fairly intuitive.
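A minimal sketch of that standard setup, assuming ResNet50 as the base model and a hypothetical two-class celebrity dataset:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load an ImageNet-pretrained base without its 1000-class head.
base = keras.applications.ResNet50(include_top=False, weights="imagenet",
                                   input_shape=(224, 224, 3), pooling="avg")
base.trainable = False  # freeze the pretrained layers

# Add a new head for the 2 celebrity classes.
inputs = keras.Input(shape=(224, 224, 3))
x = base(inputs, training=False)
outputs = layers.Dense(2, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(celebrity_dataset, epochs=5)  # hypothetical dataset
```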

Now what I require is for my model to learn to identify the celebrity faces while still being able to detect all the other objects it was trained on before. Models trained on ImageNet come with an output layer of 1000 neurons, each representing a separate class. I'm confused about how the model should be able to detect the new classes. All the transfer learning and fine-tuning articles and blogs tell us to replace the original 1000-neuron output layer with an N-neuron layer (N = number of new classes). In my case I have two celebrities, so if I add a new layer with 2 neurons, I don't see how the model can still classify the original 1000 ImageNet objects.

I need a pointer on how exactly I can teach a pre-trained model two new celebrity faces while maintaining its ability to recognize all 1000 ImageNet objects as well.

Thanks!


Answer 1:


CNNs are prone to forgetting previously learned knowledge when retrained for a new task on a novel domain. This phenomenon is called catastrophic forgetting, and it is an active and challenging research area.

Coming to the point, one obvious way to enable a model to classify new classes along with the old ones is to train it from scratch on the accumulated (old + new) dataset, which is time-consuming.

In contrast, several alternative approaches have been proposed in the (class-incremental) continual learning literature in recent years to tackle this scenario:

  1. Firstly, you can use a small subset of the old dataset alongside the new dataset to train your new model; this is referred to as a rehearsal-based approach. Instead of storing a subset of raw samples, you can also train a GAN to generate pseudo-samples of the old classes. During training, a distillation loss makes the new model mimic the predictions of the old model (whose weights are frozen), which helps it avoid forgetting the old knowledge (see the sketch after this list).
  2. Secondly, since the neurons in a model do not all contribute equally, while training the new model you may update only the neurons that are less important for the old classes, so that the old knowledge is retained. Check out the Elastic Weight Consolidation (EWC) paper for details; a rough sketch of its penalty term also follows below.
  3. Thirdly, you can grow your model dynamically to extract features specific to the new classes without harming the weights that are important for the old classes. Check out Dynamically Expandable Networks (DEN) for details.
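A minimal sketch of the distillation idea from point 1, assuming the new model's output layer covers old + new classes and the frozen old model supplies soft targets over the old classes (the function names, logit slicing, temperature, and alpha are all illustrative choices):

```python
import tensorflow as tf

def distillation_loss(old_logits, new_logits_old_part, temperature=2.0):
    """Make the new model's predictions over the OLD classes match
    the frozen old model's predictions (soft targets)."""
    soft_targets = tf.nn.softmax(old_logits / temperature)
    soft_preds = tf.nn.log_softmax(new_logits_old_part / temperature)
    return -tf.reduce_mean(tf.reduce_sum(soft_targets * soft_preds, axis=-1))

def total_loss(y_true, new_logits, old_logits, alpha=0.5):
    # Standard cross-entropy over the full (old + new) label space ...
    ce = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, new_logits, from_logits=True)
    # ... plus distillation on the slice of logits covering the old classes.
    kd = distillation_loss(old_logits, new_logits[:, :old_logits.shape[-1]])
    return tf.reduce_mean(ce) + alpha * kd
```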

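And a rough sketch of the EWC penalty from point 2: each old weight is anchored in proportion to its estimated importance (diagonal Fisher information), so important weights move less. Here, old_params and fisher are assumed to have been computed beforehand on the original task, and lam is an illustrative strength:

```python
import tensorflow as tf

def ewc_penalty(model, old_params, fisher, lam=100.0):
    """Quadratic penalty pulling each weight toward its old value,
    weighted by that weight's Fisher-information importance."""
    penalty = 0.0
    for w, w_old, f in zip(model.trainable_variables, old_params, fisher):
        penalty += tf.reduce_sum(f * tf.square(w - w_old))
    return lam / 2.0 * penalty

# Inside a custom training step, the total loss would be:
#   loss = task_loss + ewc_penalty(model, old_params, fisher)
```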


Answer 2:


With transfer learning, you can make the trained model classify among the new classes it was just trained on, using both the features learned from the new dataset and the features the model learned from the dataset it was originally trained on. Unfortunately, you cannot make the model classify among all the classes (original-dataset classes + new-dataset classes), because when you add the new classes, it keeps weights only for classifying them. But suppose, for experimentation, you change the number of output neurons in the last layer to equal the number of old + new classes; those neurons will then be given random weights, and on prediction they will not give you a meaningful result.
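If you nonetheless want to set up that experiment, one common warm start is to copy the original 1000-class weights into the first 1000 units of a new (1000 + 2)-unit head, so that only the 2 new units start out random. A sketch assuming ResNet50, where layer indices and the head name are illustrative:

```python
from tensorflow import keras
from tensorflow.keras import layers

# Original ImageNet model with its 1000-way softmax head.
old_model = keras.applications.ResNet50(weights="imagenet")
old_dense = old_model.layers[-1]            # the original 1000-way classifier
W_old, b_old = old_dense.get_weights()

# New head with room for 1000 old + 2 new classes, attached to the
# same feature extractor.
features = old_model.layers[-2].output
new_head = layers.Dense(1002, activation="softmax", name="extended_head")
outputs = new_head(features)
new_model = keras.Model(old_model.input, outputs)

# Warm start: copy the old classifier weights into the first 1000 units;
# only the 2 new units keep their random initialization.
W_new, b_new = new_head.get_weights()
W_new[:, :1000] = W_old
b_new[:1000] = b_old
new_head.set_weights([W_new, b_new])
```

Even with the copied weights, the extended model still needs to be trained on data covering both old and new classes before the two new outputs become meaningful.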

Making a model classify among old + new classes in this way is still an open research area. However, one way you can achieve it is to train your model from scratch on the whole (old + new) dataset.



Source: https://stackoverflow.com/questions/58027839/how-to-fine-tune-a-keras-model-with-existing-plus-newer-classes
