Keras Model using Tensorflow Distribution for loss fails with batch size > 1

那年仲夏 提交于 2020-01-24 12:16:57

问题


I'm trying to use a distribution from tensorflow_probability to define a custom loss function in Keras. More specifically, I'm trying to build a Mixture Density Network.
My model works on a toy dataset when batch_size = 1 (it learns to predict the correct mixture distribution for y using x). But it "fails" when batch_size > 1 (it predicts the same distribution for all y, ignoring x). This makes me think my problem has to do with batch_shape vs. sample_shape.

To reproduce:

import random

import keras
from keras import backend as K
from keras.layers import Dense, Activation, LSTM, Input, Concatenate, Reshape, concatenate, Flatten, Lambda
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from keras.models import Sequential, Model
import tensorflow
import tensorflow_probability as tfp
tfd = tfp.distributions

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# generate toy dataset
random.seed(12902)
n_obs = 20000
x = np.random.uniform(size=(n_obs, 4))
df = pd.DataFrame(x, columns = ['x_{0}'.format(i) for i in np.arange(4)])

# 2 latent classes, with noisy assignment based on x_0, x_1, (x_2 and x_3 are noise)
df['latent_class'] = 0
df.loc[df.x_0 + df.x_1 + np.random.normal(scale=.05, size=n_obs) > 1, 'latent_class'] = 1
df.latent_class.value_counts()

# Latent class will determines which mixture distribution we draw from
d0 = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
      loc=[-1., 1], scale=[0.1, 0.5])) 
d0_samples = d0.sample(sample_shape=(df.latent_class == 0).sum()).numpy()

d1 = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.Normal(
      loc=[-2., 2], scale=[0.2, 0.6]))

d1_samples = d1.sample(sample_shape=(df.latent_class == 1).sum()).numpy()

df.loc[df.latent_class == 0, 'y'] = d0_samples
df.loc[df.latent_class == 1, 'y'] = d1_samples

fig, ax = plt.subplots()
bins = np.linspace(-4, 5, 9*4 + 1)
df.y[df.latent_class == 0].hist(ax=ax, bins=bins, label='Class 0', alpha=.4, density=True)
df.y[df.latent_class == 1].hist(ax=ax, bins=bins, label='Class 1', alpha=.4, density=True)
ax.legend();

# mixture density network
N_COMPONENTS = 2  # number of components in the mixture
input_feature_space = 4

flat_input = Input(shape=(input_feature_space,), 
                   batch_shape=(None, input_feature_space), 
                   name='inputs')
x = Dense(6, activation='relu', 
             kernel_initializer='glorot_uniform',
             bias_initializer='ones')(flat_input)
x = Dense(6, activation='relu',
             kernel_initializer='glorot_uniform',
             bias_initializer='ones')(x)
# 3 params per component: weight, loc, scale
output = Dense(N_COMPONENTS*3,
             kernel_initializer='glorot_uniform',
             bias_initializer='ones')(x)

model = Model(inputs=[flat_input],
              outputs=[output])

I suspect the problem is in the next 3 functions:


def get_mixture_coef(output, num_components):
    """
    Extract mixture params from output
    """
    out_pi = output[:, :num_components]
    out_sigma = output[:, num_components:2*num_components]
    out_mu = output[:, 2*num_components:]
    # use softmax to normalize pi into prob distribution
    max_pi = K.max(out_pi, axis=1, keepdims=True)
    out_pi = out_pi - max_pi
    out_pi = K.exp(out_pi)
    normalize_pi = 1 / K.sum(out_pi, axis=1, keepdims=True)
    out_pi = normalize_pi * out_pi
    # use exp to ensure sigma is pos
    out_sigma = K.exp(out_sigma)
    return out_pi, out_sigma, out_mu

def get_lossfunc(out_pi, out_sigma, out_mu, y):
    d0 = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(
            probs=out_pi),
        components_distribution=tfd.Normal(
          loc=out_mu, scale=out_sigma,
        ),
    )
    # I suspect the problem is here
    return -1 * d0.log_prob(y)

def mdn_loss(num_components):
    def loss(y_true, y_pred):
        out_pi, out_sigma, out_mu = get_mixture_coef(y_pred, num_components)
        return get_lossfunc(out_pi, out_sigma, out_mu, y_true)
    return loss

opt = Adam(lr=.001)
model.compile(
    optimizer=opt,
    loss = mdn_loss(N_COMPONENTS),
)

es = EarlyStopping(monitor='val_loss',
                  min_delta=1e-5,
                  patience=5,
                  verbose=1, mode='auto')

validation = .15
validate_idx = np.random.choice(df.index.values, 
                                size=int(validation * df.shape[0]), 
                                replace=False)
train_idx = [i for i in df.index.values if i not in validate_idx]

x_cols = ['x_0', 'x_1', 'x_2', 'x_3']

model.fit(x=df.loc[train_idx, x_cols].values,
          y=df.loc[train_idx, 'y'].values[:, np.newaxis],
          validation_data=(
            df.loc[validate_idx, x_cols].values,
            df.loc[validate_idx, 'y'].values[:, np.newaxis]), 
          # model works when batch_size = 1
          # model fails when batch_size > 1
          epochs=2, batch_size=1, verbose=1, callbacks=[es])

def sample(output, n_samples, num_components):
    """Sample from a mixture distribution parameterized by
    model output."""
    pi, sigma, mu = get_mixture_coef(output, num_components)    
    d0 = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=pi),
    components_distribution=tfd.Normal(
      loc=mu,
      scale=sigma)) 
    return d0.sample(sample_shape=n_samples).numpy()

yhat = model.predict(df.loc[train_idx, x_cols].values)

out_pi, out_sigma, out_mu = get_mixture_coef(yhat, 2)

latent_1_samples = sample(yhat[:1], n_samples=1000, num_components=2)
latent_1_samples = pd.DataFrame({'latent_1_samples': latent_1_samples.ravel()})

fig, ax = plt.subplots()
bins = np.linspace(-4, 5, 9*4 + 1)
latent_1_samples.latent_1_samples.hist(ax=ax, bins=bins, label='Class 1: yHat', alpha=.4, density=True)
df.y[df.latent_class == 0].hist(ax=ax, bins=bins, label='Class 0: True', density=True, histtype='step')
df.y[df.latent_class == 1].hist(ax=ax, bins=bins, label='Class 1: True', density=True, histtype='step')
ax.legend();

Thanks in advance!

Update

I found two ways to solve the problem, guided by this answer. Both solutions point to the fact that Keras is awkwardly broadcasting y to match y_pred:

def get_lossfunc(out_pi, out_sigma, out_mu, y):
    d0 = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(
            probs=out_pi),
        components_distribution=tfd.Normal(
          loc=out_mu, scale=out_sigma,
        ),
    )
    # this also works: 
    # return -1 * d0.log_prob(tensorflow.transpose(y))
    return -1 * d0.log_prob(y[:, 0])

来源:https://stackoverflow.com/questions/58386664/keras-model-using-tensorflow-distribution-for-loss-fails-with-batch-size-1

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!