How to use model input in loss function?

Submitted by 梦想的初衷 on 2021-02-10 03:28:07

Question


I am trying to use a custom loss function that depends on arguments other than the y_true and y_pred that Keras passes to it.

The model has two inputs (mel_specs and pred_inp) and expects a labels tensor for training:

def to_keras_example(example):
    # Preparing inputs
    return (mel_specs, pred_inp), labels

# A tf.data.Dataset for model.fit(train_data, ...)
train_data = load_dataset(fp, 'train').map(to_keras_example).repeat()
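For anyone reproducing this, the element structure matters: Keras takes the first item of each ((mel_specs, pred_inp), labels) element as the model inputs and hands only the second item to the loss as y_true. A minimal sketch with random placeholder arrays (the shapes and the names BATCH, FRAMES, MELS, LABEL_LEN and VOCAB are illustrative assumptions, not taken from the question):

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes only; the real values depend on the dataset.
BATCH, FRAMES, MELS, LABEL_LEN, VOCAB = 4, 50, 80, 10, 30

# Random placeholder data standing in for the output of to_keras_example.
mel_specs = np.random.rand(16, FRAMES, MELS).astype("float32")
pred_inp = np.random.randint(0, VOCAB, size=(16, LABEL_LEN)).astype("int32")
labels = np.random.randint(1, VOCAB, size=(16, LABEL_LEN)).astype("int32")

# Each element is ((mel_specs, pred_inp), labels): fit() feeds the first item
# to the model and passes only the second item to the loss as y_true.
train_data = (
    tf.data.Dataset.from_tensor_slices(((mel_specs, pred_inp), labels))
    .batch(BATCH)
    .repeat()
)
```

Because mel_specs lives only on the input side of this structure, a standard (y_true, y_pred) loss never sees it, which is the root of the question.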

In my loss function I need to calculate the lengths of mel_specs and pred_inp. This means my loss looks like this:

def rnnt_loss_wrapper(y_true, y_pred, mel_specs_inputs_):
    input_lengths = get_padded_length(mel_specs_inputs_[:, :, 0])
    label_lengths = get_padded_length(y_true)
    return rnnt_loss(
        acts=y_pred,
        labels=tf.cast(y_true, dtype=tf.int32),
        input_lengths=input_lengths,
        label_lengths=label_lengths
    )
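get_padded_length is not shown in the question. A hypothetical version, assuming zero padding that appears only at the end of each row, returns the 1-based position of the last non-padding entry in each row:

```python
import tensorflow as tf


def get_padded_length(x, pad_value=0):
    """Hypothetical helper: number of valid (non-padding) steps per row.

    Assumes x has shape (batch, time) and that padding (pad_value) only
    appears at the end of each row. Interior pad-valued entries do not
    shorten the length, because we take the last non-padding position.
    """
    not_pad = tf.cast(tf.not_equal(x, tf.cast(pad_value, x.dtype)), tf.int32)
    # 1-based time indices: 1, 2, ..., T
    time_steps = tf.range(1, tf.shape(x)[1] + 1)
    # Position of each non-padding entry, 0 where the entry is padding.
    positions = not_pad * time_steps[tf.newaxis, :]
    return tf.reduce_max(positions, axis=1)
```

For [[1, 2, 0, 0], [3, 0, 5, 0]] this returns [2, 3]; the zero inside the second row does not cut its length short.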

However, I run into a problem with every approach I have tried.


Option 1) Setting the loss function in model.compile()

If I wrap the loss function so that it returns a function taking only y_true and y_pred, like this:

def rnnt_loss_wrapper(mel_specs_inputs_):
    def inner_(y_true, y_pred):
        input_lengths = get_padded_length(mel_specs_inputs_[:, :, 0])
        label_lengths = get_padded_length(y_true)
        return rnnt_loss(
            acts=y_pred,
            labels=tf.cast(y_true, dtype=tf.int32),
            input_lengths=input_lengths,
            label_lengths=label_lengths
        )
    return inner_

model = create_model(hparams)
model.compile(
    optimizer=optimizer,
    loss=rnnt_loss_wrapper(model.inputs[0])
)

Here I get a _SymbolicException after calling model.fit():

tensorflow.python.eager.core._SymbolicException: Inputs to eager execution function cannot be Keras symbolic tensors, but found [...]
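The error is consistent with inner_ closing over model.inputs[0], a symbolic Keras tensor, rather than the batch of mel_specs that fit() actually feeds in. One workaround, available since TensorFlow 2.2, is to override train_step so the real input batch is in scope when the loss is computed. A minimal sketch with a toy model and a stand-in loss (input_aware_loss and InputAwareModel are hypothetical names; rnnt_loss_wrapper would take the place of input_aware_loss):

```python
import numpy as np
import tensorflow as tf


def input_aware_loss(y_true, y_pred, mel_specs):
    # Hypothetical stand-in for rnnt_loss_wrapper: a loss that needs the
    # real mel_specs batch. Here, MSE weighted by mean |mel| per example.
    weight = tf.reduce_mean(tf.abs(mel_specs), axis=[1, 2])
    per_example = tf.reduce_mean(tf.square(y_true - y_pred), axis=1)
    return tf.reduce_mean(weight * per_example)


class InputAwareModel(tf.keras.Model):
    def train_step(self, data):
        # data is one ((mel_specs, pred_inp), labels) batch from the dataset.
        (mel_specs, pred_inp), labels = data
        with tf.GradientTape() as tape:
            y_pred = self([mel_specs, pred_inp], training=True)
            loss = input_aware_loss(labels, y_pred, mel_specs)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


# Toy functional model with the same two-input layout as the question.
mel_in = tf.keras.Input(shape=(50, 80), name="mel_specs")
pred_in = tf.keras.Input(shape=(10,), name="pred_inp")
x = tf.keras.layers.Concatenate()([tf.keras.layers.Flatten()(mel_in), pred_in])
out = tf.keras.layers.Dense(10)(x)
model = InputAwareModel(inputs=[mel_in, pred_in], outputs=out)
model.compile(optimizer="adam")  # no loss here: train_step computes it

# Random placeholder data.
mel = np.random.rand(16, 50, 80).astype("float32")
pred = np.random.rand(16, 10).astype("float32")
labels = np.random.rand(16, 10).astype("float32")
ds = tf.data.Dataset.from_tensor_slices(((mel, pred), labels)).batch(4)
history = model.fit(ds, epochs=1, verbose=0)
```

The design choice here is that the loss is computed where both the labels and the concrete inputs of the current batch exist, instead of trying to smuggle inputs into the (y_true, y_pred) signature.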

Option 2) Using model.add_loss()

The documentation of add_loss() states:

[Adds a..] loss tensor(s), potentially dependent on layer inputs.
..
Arguments:
  losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
    may also be zero-argument callables which create a loss tensor.
  inputs: Ignored when executing eagerly. If anything ...

So I tried to do the following:

def rnnt_loss_wrapper(y_true, y_pred, mel_specs_inputs_):
    input_lengths = get_padded_length(mel_specs_inputs_[:, :, 0])
    label_lengths = get_padded_length(y_true)
    return rnnt_loss(
        acts=y_pred,
        labels=tf.cast(y_true, dtype=tf.int32),
        input_lengths=input_lengths,
        label_lengths=label_lengths
    )

model = create_model(hparams)
model.add_loss(
    rnnt_loss_wrapper(
        y_true=model.inputs[2],
        y_pred=model.outputs[0],
        mel_specs_inputs_=model.inputs[0],
    ),
    inputs=True
)
model.compile(
    optimizer=optimizer
)

However, calling model.fit() throws a ValueError:

ValueError: No gradients provided for any variable: [...]
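Another route is the "endpoint layer" pattern from the Keras training guide: feed the labels in as an extra model input (as Option 2 already does with model.inputs[2]) and call add_loss() from inside a layer's call(), where the real tensors of each batch are available. A minimal sketch with a toy model and a stand-in loss; InputAwareLossEndpoint is a hypothetical name, and this is not a diagnosis of the specific "No gradients provided" error above:

```python
import numpy as np
import tensorflow as tf


class InputAwareLossEndpoint(tf.keras.layers.Layer):
    """Hypothetical endpoint layer: registers a loss that needs the raw
    mel_specs input and passes the predictions through unchanged."""

    def call(self, inputs):
        labels, y_pred, mel_specs = inputs
        # Stand-in for rnnt_loss_wrapper(labels, y_pred, mel_specs).
        weight = tf.reduce_mean(tf.abs(mel_specs), axis=[1, 2])
        per_example = tf.reduce_mean(tf.square(labels - y_pred), axis=1)
        self.add_loss(tf.reduce_mean(weight * per_example))
        return y_pred


# Labels enter the model as a third input, as in Option 2.
inputs = {
    "mel_specs": tf.keras.Input(shape=(50, 80), name="mel_specs"),
    "pred_inp": tf.keras.Input(shape=(10,), name="pred_inp"),
    "labels": tf.keras.Input(shape=(10,), name="labels"),
}
x = tf.keras.layers.Concatenate()(
    [tf.keras.layers.Flatten()(inputs["mel_specs"]), inputs["pred_inp"]])
y_pred = tf.keras.layers.Dense(10)(x)
outputs = InputAwareLossEndpoint()([inputs["labels"], y_pred, inputs["mel_specs"]])
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam")  # no loss argument: add_loss supplies it

# Random placeholder data. A dict element is treated as x only, so fit()
# needs no separate y (a 3-tuple would be read as (x, y, sample_weight)).
data = {
    "mel_specs": np.random.rand(16, 50, 80).astype("float32"),
    "pred_inp": np.random.rand(16, 10).astype("float32"),
    "labels": np.random.rand(16, 10).astype("float32"),
}
ds = tf.data.Dataset.from_tensor_slices(data).batch(4)
history = model.fit(ds, epochs=1, verbose=0)
```

At inference time the labels input is still required by this model, so a separate prediction model built from the same layers (without the endpoint) is usually needed.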

Is either of the above options supposed to work?


Answer 1:


Does using a lambda function work? (https://www.w3schools.com/python/python_lambda.asp)

loss = lambda x1, x2: rnnt_loss(x1, x2, acts, labels, input_lengths,
                                label_lengths, blank_label=0)

This way, your loss function takes only the two parameters x1 and x2, while rnnt_loss still has access to acts, labels, input_lengths, label_lengths and blank_label.
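A plain-Python sketch of the closure the answer describes: make_loss binds the extra arguments and returns a function with the two-argument signature Keras expects. toy_rnnt_loss and make_loss are hypothetical stand-ins. The bound values are fixed when make_loss is called, so this suits constants such as blank_label but not per-batch lengths; binding symbolic tensors such as model.inputs[0] is exactly what fails in Option 1.

```python
def toy_rnnt_loss(y_true, y_pred, input_lengths, label_lengths, blank_label=0):
    # Hypothetical stand-in for rnnt_loss: sums squared errors over the first
    # label_lengths[i] positions of each row. input_lengths and blank_label
    # are accepted only to mirror the real function's extra arguments.
    total = 0.0
    for true_row, pred_row, n in zip(y_true, y_pred, label_lengths):
        total += sum((t - p) ** 2 for t, p in zip(true_row[:n], pred_row[:n]))
    return total


def make_loss(input_lengths, label_lengths, blank_label=0):
    # Returns a (y_true, y_pred) function that closes over the extra arguments.
    return lambda y_true, y_pred: toy_rnnt_loss(
        y_true, y_pred, input_lengths, label_lengths, blank_label=blank_label)


loss = make_loss(input_lengths=[3, 3], label_lengths=[2, 1])
print(loss([[1, 2, 3], [4, 5, 6]], [[1, 0, 9], [2, 9, 9]]))  # prints 8.0
```

functools.partial(toy_rnnt_loss, input_lengths=..., label_lengths=...) would bind the same arguments without a lambda.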



Source: https://stackoverflow.com/questions/62691100/how-to-use-model-input-in-loss-function
