I have a multi-input Keras model. Here are the inputs:
[,
Try this:
import tensorflow as tf
train_x_list = [tf.squeeze(tx, axis=0) for tx in tf.split(train_x, num_or_size_splits=train_x.shape[0], axis=0)]
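This produces a list of tensors, one per slice of train_x along dimension 0. Passing axis=0 to tf.squeeze removes only the split dimension, so other dimensions of size 1 are left intact. Then use your second solution and pass this list to fit().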
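Here is a minimal end-to-end sketch. It assumes that train_x is stacked with shape (num_inputs, num_samples, num_features), so that dimension 0 indexes the model inputs. It also assumes a functional model with one Input per slice. The shapes, layers, and random data are hypothetical placeholders. The sketch uses tf.unstack(train_x, axis=0) as a shorthand for the split-and-squeeze line above, and checks that both produce the same list.

```python
import tensorflow as tf

# Hypothetical stacked data: 2 inputs, 100 samples, 8 features per input.
num_inputs, num_samples, num_features = 2, 100, 8
train_x = tf.random.normal((num_inputs, num_samples, num_features))
train_y = tf.random.normal((num_samples, 1))

# Split along dimension 0 and drop only that dimension: one tensor per input.
train_x_list = [
    tf.squeeze(tx, axis=0)
    for tx in tf.split(train_x, num_or_size_splits=train_x.shape[0], axis=0)
]

# tf.unstack along axis 0 gives the same list in a single call.
unstacked = tf.unstack(train_x, axis=0)
same_as_unstack = all(
    bool(tf.reduce_all(a == b)) for a, b in zip(train_x_list, unstacked)
)

# Hypothetical multi-input model: one Input per slice, concatenated, then a Dense head.
inputs = [tf.keras.Input(shape=(num_features,)) for _ in range(num_inputs)]
merged = tf.keras.layers.Concatenate()(inputs)
outputs = tf.keras.layers.Dense(1)(merged)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mse")

# fit() accepts a list of arrays/tensors, matched to the model inputs by position.
history = model.fit(train_x_list, train_y, epochs=1, batch_size=32, verbose=0)
```

Keras matches list elements to the model's inputs by position. The order of train_x along dimension 0 must therefore match the order of the Input layers passed to tf.keras.Model.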