Error in element-wise weighted averaging between 2 layers in a Keras CNN

梦谈多话 2021-01-28 04:03

I am getting an error when computing an element-wise weighted average between two layers in a CNN. My base model is:

model_base = Sequential()
# Conv Layer 1
model_base.add(layers.Sepa         


        
1 answer
  • 2021-01-28 04:40

    Given your model_base, this is the correct way to build the code block:

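    # Imports inferred from the names used below (assumes tf.keras; not shown in the original answer)
    from tensorflow.keras import backend as K
    from tensorflow.keras.layers import (GlobalAveragePooling2D, Dense, Lambda,
                                         Flatten, Activation, Reshape)
    
    l1 = model_base.layers[2].output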
    l1 = GlobalAveragePooling2D()(l1) 
    c2 = model_base.layers[4].output
    c2 = GlobalAveragePooling2D()(c2) 
    c3 = model_base.layers[6].output
    
    c = c3.shape[-1]  # channel count of c3; important for the dimensionality (l1 and c2 are projected to this size)
    l1 = Dense(c)(l1)
    c2 = Dense(c)(c2) 
    
    c13 = Lambda(lambda lam: K.squeeze(K.map_fn(lambda xy: K.dot(xy[0], xy[1]), 
                                                elems=(lam[0], K.expand_dims(lam[1], -1)), dtype='float32'), 3), name='cdp1')([c3, l1])  # batch*x*y
    
    c23 = Lambda(lambda lam: K.squeeze(K.map_fn(lambda xy: K.dot(xy[0], xy[1]), 
                                                elems=(lam[0], K.expand_dims(lam[1], -1)), dtype='float32'), 3), name='cdp2')([c3, c2])  # batch*x*y
    
    flatc13 = Flatten(name='flatc1')(c13)  # batch*xy
    flatc23 = Flatten(name='flatc2')(c23)  # batch*xy
    
    a1 = Activation('softmax', name='softmax1')(flatc13) # batch*xy
    a2 = Activation('softmax', name='softmax2')(flatc23) # batch*xy
    
    reshaped = Reshape((-1,c), name='reshape1')(c3)  # batch*xy*c
    
    g1 = Lambda(lambda lam: K.squeeze(K.batch_dot(K.expand_dims(lam[0], 1), lam[1]), 1), 
                name='g1')([a1,reshaped])  # batch*c
    g2 = Lambda(lambda lam: K.squeeze(K.batch_dot(K.expand_dims(lam[0], 1), lam[1]), 1), 
                name='g2')([a2,reshaped])  # batch*c
    

    Pay attention to the dimensionality: in your case you cannot operate with 512 units, only with 256, and the c variable handles this automatically. Also pay attention to the order of the inputs passed to the Lambda layers: for example, c13 takes ([c3, l1]), not ([l1, c3]).
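To make the shapes concrete, here is a minimal NumPy sketch of what the Lambda layers above appear to compute, following the shape comments in the code (batch*x*y, batch*xy, batch*c). It is not the Keras code itself. The function name `attention_pool`, the 4x4 spatial size, and the random inputs are assumptions made only for illustration.

```python
import numpy as np

def attention_pool(feature_map, query):
    """Pool a (batch, x, y, c) feature map using a (batch, c) query vector."""
    batch, x, y, c = feature_map.shape
    # Dot product over the channel axis at every location -> (batch, x, y).
    # This mirrors c13/c23; it fails if the query does not have c entries.
    scores = np.einsum("bxyc,bc->bxy", feature_map, query)
    # Softmax over all x*y locations -> (batch, x*y), mirroring a1/a2.
    flat = scores.reshape(batch, x * y)
    flat = flat - flat.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(flat)
    weights = exp / exp.sum(axis=1, keepdims=True)
    # Weighted average of the per-location feature vectors -> (batch, c),
    # mirroring g1/g2.
    reshaped = feature_map.reshape(batch, x * y, c)
    pooled = np.einsum("bn,bnc->bc", weights, reshaped)
    return weights, pooled

rng = np.random.default_rng(0)
c3 = rng.normal(size=(2, 4, 4, 256))  # hypothetical c3 output: 256 channels
l1 = rng.normal(size=(2, 256))        # hypothetical l1 after Dense(c)

a1, g1 = attention_pool(c3, l1)
print(a1.shape, g1.shape)  # (2, 16) (2, 256)

# A 512-dimensional vector cannot be combined with a 256-channel map.
try:
    attention_pool(c3, rng.normal(size=(2, 512)))
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```

The last call shows why the Dense(c) projection matters: without it, a vector whose length differs from the channel count of c3 cannot be multiplied against the feature map.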

    Here is the running notebook: https://colab.research.google.com/drive/1m0pB5GlYRtIsOnHUTz6LxRQblcvtVU3Y?usp=sharing
