Change loss function dynamically during training in Keras, without recompiling other model properties like optimizer

Submitted by 末鹿安然 on 2020-05-13 02:01:07

Question


Is it possible to set model.loss in a callback without calling model.compile(...) again afterwards (since that resets the optimizer state), so that only the loss is recompiled? For example:

from keras.callbacks import Callback

class NewCallback(Callback):

    def __init__(self):
        super().__init__()

    def on_epoch_end(self, epoch, logs=None):
        self.model.loss = [loss_wrapper(t_change, current_epoch=epoch)]
        self.model.compile_only_loss()  # is there a version or hack of
                                        # model.compile(...) like this?

To expand on this with earlier examples from Stack Overflow:

To get a loss function that depends on the epoch number (as in this Stack Overflow question), one might write:

from keras import backend as K

def loss_wrapper(t_change, current_epoch):
    def custom_loss(y_true, y_pred):
        c_epoch = K.get_value(current_epoch)
        if c_epoch < t_change:
            loss = ...  # compute loss_1
        else:
            loss = ...  # compute loss_2
        return loss
    return custom_loss

where "current_epoch" is a Keras variable updated with a callback:

current_epoch = K.variable(0.)
model.compile(optimizer=opt,
              loss=loss_wrapper(5, current_epoch),
              metrics=...)

class NewCallback(Callback):
    def __init__(self, current_epoch):
        super().__init__()
        self.current_epoch = current_epoch

    def on_epoch_end(self, epoch, logs=None):
        # epoch is the 0-based index of the epoch that has just finished
        K.set_value(self.current_epoch, epoch)
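
The timing of this update is easy to misread, so here is a minimal, framework-free sketch of it. `EpochHolder` and `simulate` are hypothetical stand-ins (not Keras APIs) for the Keras variable and the training loop; the sketch only assumes that `on_epoch_end(epoch)` runs after each epoch with a 0-based index.

```python
class EpochHolder:
    """Hypothetical stand-in for the Keras variable current_epoch."""

    def __init__(self, value=0.0):
        self.value = value


def simulate(num_epochs, holder):
    """Record the value the loss would read during each epoch."""
    seen = []
    for epoch in range(num_epochs):
        # Training for this epoch runs with whatever the holder currently stores.
        seen.append(holder.value)
        # Mirrors on_epoch_end(epoch): the update only takes effect afterwards.
        holder.value = float(epoch)
    return seen


seen = simulate(4, EpochHolder())
print(seen)  # [0.0, 0.0, 1.0, 2.0]
```

Under these assumptions the loss sees a value that lags one epoch behind, so a threshold of 5 would switch one epoch later than a naive reading suggests. If that matters, updating the variable in `on_epoch_begin` instead is one possible adjustment.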

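In graph mode the Python `if` above is evaluated only once, when the loss is built, not on every batch. For the loss to actually switch, one can instead express the branch as a composition of backend functions, as follows: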

def loss_wrapper(t_change, current_epoch):
    def custom_loss(y_true, y_pred):
        # compute loss_1 and loss_2
        bool_case_1 = K.less(current_epoch, t_change)
        num_case_1 = K.cast(bool_case_1, "float32")  # 1.0 before t_change, else 0.0
        loss = num_case_1 * loss_1 + (1 - num_case_1) * loss_2
        return loss
    return custom_loss

This works.
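
The arithmetic of the masked sum can be checked without Keras. The following is a plain-Python sketch, not the backend implementation: `blended_loss` is a hypothetical helper, and the two loss values are made-up numbers chosen only for illustration.

```python
def blended_loss(current_epoch, t_change, loss_1, loss_2):
    """Plain-Python mirror of the masked sum used in custom_loss."""
    # Plays the role of K.cast(K.less(current_epoch, t_change), "float32").
    mask = 1.0 if current_epoch < t_change else 0.0
    return mask * loss_1 + (1.0 - mask) * loss_2


# Hypothetical loss values, for illustration only.
loss_1, loss_2 = 0.25, 4.0
schedule = [blended_loss(epoch, 5, loss_1, loss_2) for epoch in range(8)]
print(schedule)  # [0.25, 0.25, 0.25, 0.25, 0.25, 4.0, 4.0, 4.0]
```

Note that both losses are computed on every step and only the mask selects between them. One consequence is that a NaN in the inactive branch still propagates, because 0 * NaN is NaN.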

I am not satisfied with these hacks. Is it possible to set model.loss in a callback without calling model.compile(...) again afterwards (since that resets the optimizer state), and recompile only the loss?

Source: https://stackoverflow.com/questions/55979176/change-loss-function-dynamically-during-training-in-keras-without-recompiling-o
