Keras Conv2D and input channels

面向向阳花 2021-01-30 16:32

The Keras layer documentation specifies the input and output sizes for convolutional layers: https://keras.io/layers/convolutional/

Input shape: (samples, channels

3 answers
  •  不思量自难忘°
    2021-01-30 17:12

    I also needed to convince myself, so I ran a simple example with a 3×3 RGB image.

    # red    # green        # blue
    1 1 1    100 100 100    10000 10000 10000
    1 1 1    100 100 100    10000 10000 10000    
    1 1 1    100 100 100    10000 10000 10000
    

    The filter is initialised to ones (the same 2×2 kernel slice for each of the three input channels):

    1 1
    1 1
    

    I have also set the convolution to have these properties:

    • no padding
    • strides = 1
    • relu activation function
    • bias initialised to 0
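    With `padding='valid'` and `strides=1`, the output size along each spatial axis should follow the usual formula floor((n - k) / s) + 1, which gives 2×2 for a 3×3 image and a 2×2 kernel. A minimal sketch of that arithmetic; `valid_output_size` is a hypothetical helper, not part of Keras:

```python
def valid_output_size(n: int, k: int, stride: int = 1) -> int:
    """Output length along one spatial axis for 'valid' (no) padding:
    floor((n - k) / stride) + 1."""
    if k > n:
        raise ValueError("kernel is larger than the input; 'valid' padding gives no output")
    return (n - k) // stride + 1

# 3x3 image, 2x2 kernel, stride 1 -> 2x2 output
rows = valid_output_size(3, 2, stride=1)
cols = valid_output_size(3, 2, stride=1)
print(rows, cols)  # 2 2
```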

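    We would expect the (aggregated) output to be the sum of each 2×2 window over all three channels, i.e. 4 × 1 + 4 × 100 + 4 × 10000 = 40404 at every position: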

    40404 40404
    40404 40404
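    To check the expected values without Keras, here is a plain NumPy sketch of the same operation under the settings listed above (all-ones 2×2 kernel across the three channels, zero bias, ReLU, no padding, stride 1). `conv2d_single_filter` is a hypothetical name; like Keras, it computes a cross-correlation (the kernel is not flipped), which makes no difference for an all-ones kernel.

```python
import numpy as np

# Channels-last 3x3x3 image: red = 1, green = 100, blue = 10000
img = np.stack([np.full((3, 3), 1.0),
                np.full((3, 3), 100.0),
                np.full((3, 3), 10000.0)], axis=-1)

kernel = np.ones((2, 2, 3))  # one 2x2 slice per input channel, all ones
bias = 0.0

def conv2d_single_filter(x, w, b):
    """Valid, stride-1 cross-correlation of an (H, W, C) input with a (kh, kw, C) kernel.
    Each output value sums over the 2D window AND over all C input channels."""
    kh, kw, _ = w.shape
    out_h = x.shape[0] - kh + 1
    out_w = x.shape[1] - kw + 1
    out = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            window = x[i:i + kh, j:j + kw, :]
            out[i, j] = np.sum(window * w) + b
    return np.maximum(out, 0.0)  # ReLU

result = conv2d_single_filter(img, kernel, bias)
print(result)
# [[40404. 40404.]
#  [40404. 40404.]]
```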
    

    Also, from the picture above, the number of parameters is:

    3 kernel slices (one per input channel) × 4 weights + 1 bias (not shown) = 13 parameters
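    The same count generalises to any standard Conv2D layer: each filter has kernel_h × kernel_w × in_channels weights plus one bias. A small sketch, assuming `groups=1` and the default `use_bias=True`; `conv2d_param_count` is a hypothetical helper:

```python
def conv2d_param_count(kernel_h, kernel_w, in_channels, filters, use_bias=True):
    """Trainable parameters of a standard (non-grouped) Conv2D layer.
    Each filter spans all input channels: kernel_h * kernel_w * in_channels weights,
    plus one bias per filter when use_bias is True."""
    weights = kernel_h * kernel_w * in_channels * filters
    biases = filters if use_bias else 0
    return weights + biases

print(conv2d_param_count(2, 2, 3, 1))   # 13, the example in this answer
print(conv2d_param_count(3, 3, 3, 64))  # 1792, e.g. 64 3x3 filters on an RGB input
```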


    Here's the code.

    Import modules:

    import numpy as np
    from keras.layers import Input, Conv2D
    from keras.models import Model
    

    Create the red, green and blue channels:

    red   = np.array([1]*9).reshape((3,3))
    green = np.array([100]*9).reshape((3,3))
    blue  = np.array([10000]*9).reshape((3,3))
    

    Stack the channels to form an RGB image:

    img = np.stack([red, green, blue], axis=-1)
    img = np.expand_dims(img, axis=0)
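    `np.stack(..., axis=-1)` places the colour planes on the last axis, giving (samples, rows, cols, channels). That matches Keras's `channels_last` data format, the usual default, whereas the documentation excerpt in the question lists channels right after samples (channels-first). A sketch that checks where each colour ended up and shows the transpose to channels-first, should a model use `data_format='channels_first'`:

```python
import numpy as np

red = np.full((3, 3), 1)
green = np.full((3, 3), 100)
blue = np.full((3, 3), 10000)

img = np.expand_dims(np.stack([red, green, blue], axis=-1), axis=0)
print(img.shape)  # (1, 3, 3, 3) -> (samples, rows, cols, channels)

# Channels-last: the last index selects the colour plane
assert np.all(img[0, :, :, 0] == 1)       # red
assert np.all(img[0, :, :, 2] == 10000)   # blue

# Reorder to channels-first: (samples, channels, rows, cols)
img_cf = np.transpose(img, (0, 3, 1, 2))
assert np.all(img_cf[0, 1] == 100)        # green is now selected by axis 1
```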
    

    Create a model that just does a Conv2D convolution:

    inputs = Input((3, 3, 3))  # (rows, cols, channels): channels last
    conv = Conv2D(filters=1,
                  kernel_size=2,
                  strides=1,
                  padding='valid',
                  activation='relu',
                  kernel_initializer='ones',
                  bias_initializer='zeros')(inputs)
    model = Model(inputs, conv)
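    Only `filters=1` is used here. In the channels-last case Keras keeps the Conv2D kernel as a tensor of shape (kernel_h, kernel_w, in_channels, filters), here (2, 2, 3, 1), so every filter spans all input channels and `filters` only sets the output depth. A NumPy sketch of the forward pass with more than one filter; `conv2d_forward` is hypothetical, and the second filter (all twos) is an illustrative assumption:

```python
import numpy as np

def conv2d_forward(x, kernel, bias):
    """Valid, stride-1 Conv2D with ReLU on a channels-last batch.
    x: (N, H, W, C_in); kernel: (kh, kw, C_in, F); bias: (F,) -> (N, H', W', F)."""
    kh, kw, _, n_filters = kernel.shape
    n, h, w, _ = x.shape
    out_h, out_w = h - kh + 1, w - kw + 1
    out = np.empty((n, out_h, out_w, n_filters))
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, i:i + kh, j:j + kw, :]  # (N, kh, kw, C_in)
            # contract over kh, kw and C_in; keep the batch and filter axes
            out[:, i, j, :] = np.einsum('nabc,abcf->nf', window, kernel)
    return np.maximum(out + bias, 0.0)

img = np.stack([np.full((3, 3), v) for v in (1.0, 100.0, 10000.0)], axis=-1)[None]
kernel = np.stack([np.ones((2, 2, 3)), 2 * np.ones((2, 2, 3))], axis=-1)  # (2, 2, 3, 2)
out = conv2d_forward(img, kernel, np.zeros(2))
print(out.shape)     # (1, 2, 2, 2)
print(out[0, 0, 0])  # [40404. 80808.]
```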
    

    Input the image in the model:

    model.predict(img)
    # array([[[[40404.],
    #          [40404.]],
    
    #         [[40404.],
    #          [40404.]]]], dtype=float32)
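    The uniform 40404 is the sum of three separate per-channel responses. The sketch below (NumPy only) convolves each colour plane on its own, roughly what a depthwise convolution would keep as separate outputs, and then adds them to recover the Conv2D value; `valid_conv_2d` is a hypothetical helper:

```python
import numpy as np

def valid_conv_2d(plane, k):
    """Valid, stride-1 cross-correlation of a single 2D plane with a 2D kernel."""
    kh, kw = k.shape
    out_h, out_w = plane.shape[0] - kh + 1, plane.shape[1] - kw + 1
    return np.array([[np.sum(plane[i:i + kh, j:j + kw] * k) for j in range(out_w)]
                     for i in range(out_h)])

planes = {"red": np.full((3, 3), 1.0),
          "green": np.full((3, 3), 100.0),
          "blue": np.full((3, 3), 10000.0)}
k = np.ones((2, 2))

per_channel = {name: valid_conv_2d(p, k) for name, p in planes.items()}
for name, resp in per_channel.items():
    print(name, resp[0, 0])  # red 4.0 / green 400.0 / blue 40000.0

# Summing over channels is what Conv2D does (bias is 0 and ReLU is a no-op here)
combined = sum(per_channel.values())
print(combined[0, 0])  # 40404.0
```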
    

    Run a summary to get the number of params:

    model.summary()
    
