Variable length output in keras

情书的邮戳 2021-01-06 06:16

I'm trying to create an autoencoder in Keras with bucketing, where the input and the output have different time steps.

model = Sequential()

#encoder
model.a         


        
1 Answer
  • If you mean "inputs with variable lengths" and "outputs with the same lengths as the inputs", you can do this:

    Warning: this solution only works with batch size = 1.
    You will need to write an external loop and pass each sample as a numpy array with its exact length.
    You cannot use masking in this solution, and the output is only correct if the input has the correct length.
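
The external loop mentioned in the warning might be prepared roughly like this. It is only a sketch of the data handling: `to_single_sample_batches` is a hypothetical helper name (not part of Keras) and the token ids are made up. Each sequence becomes its own array of shape (1, length), so nothing is padded and no mask is needed.

```python
import numpy as np

def to_single_sample_batches(sequences):
    """Turn each variable-length sequence into its own batch of size 1.

    Hypothetical helper for illustration; not part of Keras.
    """
    return [np.asarray(seq, dtype="int32").reshape(1, -1) for seq in sequences]

# Made-up token ids with three different lengths
sequences = [[4, 8, 15], [16, 23, 42, 7, 1], [9, 2]]
batches = to_single_sample_batches(sequences)

for batch in batches:
    # Each batch keeps the exact length of its sample
    print(batch.shape)
```

Each element of `batches` could then be passed on its own to a call such as train_on_batch or predict.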

    This is working code using Keras + TensorFlow:

    Imports:

    from keras.layers import *
    from keras.models import Model
    import numpy as np
    import keras.backend as K
    from keras.utils import to_categorical
    

    Custom functions to use in Lambda layers:

    #this function gets the length from the original input 
    #and stores it in the final output of the encoder
    def storeLength(x):
        inputTensor = x[0]
        storeInto = x[1] #the final output
    
        length = K.shape(inputTensor)[1]
        length = K.cast(length,K.floatx())
        length = K.reshape(length,(1,1))
    
        #will put length as the first element in the final output
        return K.concatenate([length,storeInto])
    
    
    #this function expands the length of the input in the decoder
    def expandLength(x):
        #length is the first element in the encoded input
        length = K.cast(x[0,0],'int32') #or int64 if necessary
    
        #the remaining elements are the actual data to be decoded
        data = x[:,1:]
    
        #a tensor with shape (length,)
        length = K.ones_like(K.arange(0,length))
    
        #make both length tensor and data tensor 3D and with paired dimensions 
        length = K.cast(K.reshape(length,(1,-1,1)),K.floatx())
        data = K.reshape(data,(1,1,-1))
    
        #this automatically repeats the elements based on the paired shapes
        return data*length 
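
To see why the multiplication in expandLength repeats the encoded vector, the same shape arithmetic can be reproduced in plain NumPy. This is a sketch that mirrors the two Lambda functions, not the Keras code itself; `store_length_np` and `expand_length_np` are hypothetical names and the encoded values are made up.

```python
import numpy as np

def store_length_np(input_ids, encoded):
    # input_ids: (1, length), encoded: (1, hidden_size)
    length = np.array([[input_ids.shape[1]]], dtype=encoded.dtype)  # shape (1, 1)
    # The length goes in front of the encoded vector: (1, hidden_size + 1)
    return np.concatenate([length, encoded], axis=-1)

def expand_length_np(x):
    length = int(x[0, 0])            # first element is the stored length
    data = x[:, 1:]                  # remaining elements: (1, hidden_size)
    ones = np.ones((1, length, 1), dtype=x.dtype)
    data = data.reshape(1, 1, -1)    # (1, 1, hidden_size)
    # Broadcasting (1, length, 1) * (1, 1, hidden_size) -> (1, length, hidden_size)
    return data * ones

input_ids = np.arange(5).reshape(1, 5)                   # a sample of length 5
encoded = np.array([[0.1, 0.2, 0.3]], dtype="float32")   # made-up encoder output

stored = store_length_np(input_ids, encoded)
expanded = expand_length_np(stored)
print(stored.shape, expanded.shape)  # (1, 4) (1, 5, 3)
```

Every one of the 5 time steps in `expanded` holds the same encoded vector, which is the sequence the decoder LSTM receives as input.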
    

    Creating the models:

    I assumed the output is equal to the input, but since you're using an Embedding, I made "num_class" equal to the number of words.

    This solution uses branching, so I had to use the functional API Model. That pays off later, because you will want to train with autoencoder.train_on_batch and then just encode with encoder.predict() or just decode with decoder.predict().

    vocab_size = 100
    embedding_size = 7
    num_class=vocab_size
    hidden_size = 3
    
    #encoder
    inputs = Input(batch_shape = (1,None))
    outputs = Embedding(vocab_size, embedding_size)(inputs)
    outputs = LSTM(units=hidden_size, return_sequences=False)(outputs)
    outputs = Lambda(storeLength)([inputs,outputs])
    encoder = Model(inputs,outputs)
    
    #decoder
    inputs = Input(batch_shape=(1,hidden_size+1))
    outputs = Lambda(expandLength)(inputs)
    outputs = LSTM(units=hidden_size, return_sequences=True)(outputs)
    outputs = TimeDistributed(Dense(num_class, activation='softmax'))(outputs)
    decoder = Model(inputs,outputs)
    
    #autoencoder
    inputs = Input(batch_shape=(1,None))
    outputs = encoder(inputs)
    outputs = decoder(outputs)
    autoencoder = Model(inputs,outputs)
    
    #see each model's shapes 
    encoder.summary()
    decoder.summary()
    autoencoder.summary()
    

    Just an example with fake data and the method that should be used for training:

    inputData = []
    outputData = []
    for i in range(7,10):
        inp = np.arange(i).reshape((1,i))
        inputData.append(inp)
    
        outputData.append(to_categorical(inp,num_class))
    
    autoencoder.compile(loss='mse',optimizer='adam')
    
    for epoch in range(1):
        for inputSample,outputSample in zip(inputData,outputData):
    
            print(inputSample.shape,outputSample.shape)
            autoencoder.train_on_batch(inputSample,outputSample)
    
    for inputSample in inputData:
        print(autoencoder.predict(inputSample).shape)
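
The targets in this example are one-hot vectors over num_class, and the autoencoder returns one probability vector per time step. A minimal NumPy sketch of that round trip follows. `one_hot` is a hypothetical stand-in for what to_categorical is assumed to produce here, and the "prediction" is faked by reusing the target, since no trained model is involved.

```python
import numpy as np

num_class = 100

def one_hot(ids, num_classes):
    # ids: integer array of shape (1, length) -> (1, length, num_classes)
    return np.eye(num_classes, dtype="float32")[ids]

ids = np.arange(7).reshape(1, 7)
target = one_hot(ids, num_class)       # shape (1, 7, 100)

# Stand-in for autoencoder.predict(ids): here simply the target itself
fake_prediction = target

# The most likely class at each time step gives back the token ids
decoded_ids = fake_prediction.argmax(axis=-1)
print(target.shape, decoded_ids.shape)
```

With a real model, taking the argmax over the last axis of the prediction in the same way would recover one token id per time step.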
    