问题
I'm trying to use a distribution from tensorflow_probability to define a custom loss function in Keras. More specifically, I'm trying to build a Mixture Density Network.
My model works on a toy dataset when `batch_size = 1` (it learns to predict the correct mixture distribution for `y` using `x`). But it "fails" when `batch_size > 1` (it predicts the same distribution for all `y`, ignoring `x`). This makes me think my problem has to do with `batch_shape` vs. `sample_shape`.
To reproduce:
import random
import keras
from keras import backend as K
from keras.layers import Dense, Activation, LSTM, Input, Concatenate, Reshape, concatenate, Flatten, Lambda
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from keras.models import Sequential, Model
import tensorflow
import tensorflow_probability as tfp
tfd = tfp.distributions
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# generate toy dataset
# Every draw below comes from NumPy (np.random) or TensorFlow Probability, so those
# RNGs must be seeded for reproducibility; seeding Python's `random` alone has no
# effect on them.
random.seed(12902)
np.random.seed(12902)
tensorflow.random.set_seed(12902)

n_obs = 20000
x = np.random.uniform(size=(n_obs, 4))
df = pd.DataFrame(x, columns=['x_{0}'.format(i) for i in np.arange(4)])

# 2 latent classes, with noisy assignment based on x_0, x_1 (x_2 and x_3 are noise)
df['latent_class'] = 0
df.loc[df.x_0 + df.x_1 + np.random.normal(scale=.05, size=n_obs) > 1, 'latent_class'] = 1
# A bare expression is only displayed in an interactive session; print it so a
# plain script run shows the class balance too.
print(df.latent_class.value_counts())

# The latent class determines which mixture distribution y is drawn from.
d0 = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
        loc=[-1., 1], scale=[0.1, 0.5]))
d0_samples = d0.sample(sample_shape=(df.latent_class == 0).sum()).numpy()

d1 = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.Normal(
        loc=[-2., 2], scale=[0.2, 0.6]))
d1_samples = d1.sample(sample_shape=(df.latent_class == 1).sum()).numpy()

df.loc[df.latent_class == 0, 'y'] = d0_samples
df.loc[df.latent_class == 1, 'y'] = d1_samples

# Overlay the per-class target histograms (bins of width 0.25 over [-4, 5]).
fig, ax = plt.subplots()
bins = np.linspace(-4, 5, 9*4 + 1)
df.y[df.latent_class == 0].hist(ax=ax, bins=bins, label='Class 0', alpha=.4, density=True)
df.y[df.latent_class == 1].hist(ax=ax, bins=bins, label='Class 1', alpha=.4, density=True)
ax.legend();
# mixture density network
N_COMPONENTS = 2  # number of components in the mixture
input_feature_space = 4

# `shape` already implies batch shape (None, input_feature_space). Also passing
# `batch_shape` is redundant, and tf.keras / Keras 3 reject receiving both.
flat_input = Input(shape=(input_feature_space,), name='inputs')
x = Dense(6, activation='relu',
          kernel_initializer='glorot_uniform',
          bias_initializer='ones')(flat_input)
x = Dense(6, activation='relu',
          kernel_initializer='glorot_uniform',
          bias_initializer='ones')(x)

# 3 raw params per component. get_mixture_coef slices them in this order:
# weight logits, log-scales, locs.
output = Dense(N_COMPONENTS*3,
               kernel_initializer='glorot_uniform',
               bias_initializer='ones')(x)
model = Model(inputs=[flat_input],
              outputs=[output])
I suspect the problem is in the next 3 functions:
def get_mixture_coef(output, num_components):
    """Split raw network output into mixture parameters.

    Column layout of ``output`` (k = num_components):
      [0:k]   unnormalised mixture-weight logits
      [k:2k]  log standard deviations
      [2k:]   component means (all remaining columns)

    Returns:
        (pi, sigma, mu): softmax-normalised weights, positive scales, and means.
    """
    k = num_components
    logits = output[:, :k]
    log_sigma = output[:, k:2*k]
    mu = output[:, 2*k:]

    # Stable softmax: shift each row by its max before exponentiating, then
    # scale by the reciprocal of the row sum.
    row_max = K.max(logits, axis=1, keepdims=True)
    unnormalised = K.exp(logits - row_max)
    inv_total = 1 / K.sum(unnormalised, axis=1, keepdims=True)
    pi = inv_total * unnormalised

    # exp keeps every scale strictly positive.
    sigma = K.exp(log_sigma)
    return pi, sigma, mu
def get_lossfunc(out_pi, out_sigma, out_mu, y):
    """Return the per-example negative log-likelihood of ``y`` under a Gaussian mixture.

    Args:
        out_pi: mixture weights, shape (batch, num_components).
        out_sigma: component standard deviations, same shape as ``out_pi``.
        out_mu: component means, same shape as ``out_pi``.
        y: targets. Keras passes ``y_true`` here, which this script feeds as a
            (batch, 1) column. A flat (batch,) vector is accepted too.

    Returns:
        Tensor of shape (batch,) holding ``-log p(y_i)`` for each example.
    """
    d0 = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=out_pi),
        components_distribution=tfd.Normal(loc=out_mu, scale=out_sigma),
    )
    # The mixture has batch_shape (batch,) and a scalar event. A (batch, 1) target
    # would broadcast against it into a (batch, batch) matrix, scoring every
    # target under every example's mixture, so training only worked with
    # batch_size=1. Flatten y so that target i is scored only by mixture i.
    y = tensorflow.reshape(y, [-1])
    return -1 * d0.log_prob(y)
def mdn_loss(num_components):
    """Build a Keras loss ``loss(y_true, y_pred)`` for a mixture density network.

    ``y_pred`` is the raw network output. It is decoded with get_mixture_coef and
    scored with get_lossfunc.
    """
    def loss(y_true, y_pred):
        mixture_params = get_mixture_coef(y_pred, num_components)
        return get_lossfunc(*mixture_params, y_true)

    return loss
# NOTE(review): newer Keras releases name this argument `learning_rate`;
# `lr` is kept for compatibility with the Keras version this script targets.
opt = Adam(lr=.001)
model.compile(
    optimizer=opt,
    loss=mdn_loss(N_COMPONENTS),
)
es = EarlyStopping(monitor='val_loss',
                   min_delta=1e-5,
                   patience=5,
                   verbose=1, mode='auto')

# Hold out 15% of rows (sampled without replacement) for validation.
validation = .15
validate_idx = np.random.choice(df.index.values,
                                size=int(validation * df.shape[0]),
                                replace=False)
# Boolean isin-mask: O(n), and it keeps training rows in their original order.
# Testing `i not in validate_idx` row by row scans the whole array each time.
train_idx = df.index[~df.index.isin(validate_idx)].tolist()

x_cols = ['x_0', 'x_1', 'x_2', 'x_3']
model.fit(x=df.loc[train_idx, x_cols].values,
          y=df.loc[train_idx, 'y'].values[:, np.newaxis],
          validation_data=(
              df.loc[validate_idx, x_cols].values,
              df.loc[validate_idx, 'y'].values[:, np.newaxis]),
          # y is fed as a (batch, 1) column. batch_size > 1 only trains correctly
          # if the loss flattens y_true to (batch,) before calling log_prob.
          epochs=2, batch_size=1, verbose=1, callbacks=[es])
def sample(output, n_samples, num_components):
    """Draw samples from the Gaussian mixture encoded by raw model output.

    Args:
        output: raw network output, one row per example (layout as in
            get_mixture_coef).
        n_samples: number of draws, used as the leading ``sample_shape``.
        num_components: number of mixture components encoded in ``output``.

    Returns:
        NumPy array of shape (n_samples, number of rows in ``output``).
    """
    weights, scales, locs = get_mixture_coef(output, num_components)
    selector = tfd.Categorical(probs=weights)
    components = tfd.Normal(loc=locs, scale=scales)
    mixture = tfd.MixtureSameFamily(
        mixture_distribution=selector,
        components_distribution=components,
    )
    draws = mixture.sample(sample_shape=n_samples)
    return draws.numpy()
# Predict raw mixture parameters for every training row (same order as train_idx).
yhat = model.predict(df.loc[train_idx, x_cols].values)
# Decoded parameters, kept for interactive inspection.
out_pi, out_sigma, out_mu = get_mixture_coef(yhat, N_COMPONENTS)

# The histogram is labelled "Class 1", so sample from the predicted mixture of a
# training row whose latent class actually is 1.
train_classes = df.loc[train_idx, 'latent_class'].values
class_1_row = int(np.flatnonzero(train_classes == 1)[0])
latent_1_samples = sample(yhat[class_1_row:class_1_row + 1],
                          n_samples=1000,
                          num_components=N_COMPONENTS)
latent_1_samples = pd.DataFrame({'latent_1_samples': latent_1_samples.ravel()})

# Compare predicted samples against the true per-class target distributions.
fig, ax = plt.subplots()
bins = np.linspace(-4, 5, 9*4 + 1)
latent_1_samples.latent_1_samples.hist(ax=ax, bins=bins, label='Class 1: yHat', alpha=.4, density=True)
df.y[df.latent_class == 0].hist(ax=ax, bins=bins, label='Class 0: True', density=True, histtype='step')
df.y[df.latent_class == 1].hist(ax=ax, bins=bins, label='Class 1: True', density=True, histtype='step')
ax.legend();
Thanks in advance!
Update
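I found two ways to solve the problem, guided by this answer (the link was not preserved). Both point to the same cause: Keras passes `y` with shape `(batch, 1)`, which broadcasts against the distribution's batch shape `(batch,)` in `log_prob`. The result is a `(batch, batch)` loss instead of one value per example: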
def get_lossfunc(out_pi, out_sigma, out_mu, y):
    """Return the per-example negative log-likelihood of ``y`` under a Gaussian mixture.

    Args:
        out_pi: mixture weights, shape (batch, num_components).
        out_sigma: component standard deviations, same shape as ``out_pi``.
        out_mu: component means, same shape as ``out_pi``.
        y: targets, either a (batch, 1) column (as Keras passes ``y_true`` in
            this script) or a flat (batch,) vector.

    Returns:
        Tensor of shape (batch,) holding ``-log p(y_i)`` for each example.
    """
    d0 = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=out_pi),
        components_distribution=tfd.Normal(loc=out_mu, scale=out_sigma),
    )
    # Flatten y to match the mixture's batch_shape (batch,), so that target i is
    # scored only by mixture i. Unlike y[:, 0], reshape also accepts rank-1
    # targets. An equivalent alternative is log_prob(transpose(y)), which
    # yields a (1, batch) result.
    y = tensorflow.reshape(y, [-1])
    return -1 * d0.log_prob(y)
来源:https://stackoverflow.com/questions/58386664/keras-model-using-tensorflow-distribution-for-loss-fails-with-batch-size-1