Number of parameters in Caffe LeNet or ImageNet models

北海茫月 2020-12-17 21:51

How do I calculate the number of parameters in a model, e.g. LeNet for MNIST or a ConvNet for the ImageNet model? Is there any specific function in Caffe that returns or saves the number of parameters?

2 Answers
  • 2020-12-17 22:07

    I can offer an explicit way to do this via the Matlab interface (make sure matcaffe is installed first). Basically, you extract the set of parameters from each network layer and count them. In Matlab:

    % load the network (replace the placeholders with your own file paths)
    net_model = '<path to your *deploy.prototxt file>';
    net_weights = '<path to your *.caffemodel file>';
    phase = 'test';
    test_net = caffe.Net(net_model, net_weights, phase);
    
    % get the list of layers
    layers_list = test_net.layer_names;
    % for those layers which have parameters, count them
    counter = 0;
    for j = 1:length(layers_list)
        layer_params = test_net.layers(layers_list{j}).params;
        if ~isempty(layer_params)
            % params(1) is the weight blob; the bias blob params(2) is not counted
            feat = layer_params(1).get_data();
            counter = counter + numel(feat);
        end
    end
    disp(counter);
    

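    In the end, 'counter' contains the number of parameters. Only the first parameter blob of each layer (the weights) is read, so biases are not included in this count.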
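    Counting every blob of each layer instead gives the total including biases. A minimal Python sketch of that arithmetic, working on a plain dictionary that maps each layer name to its list of blob shapes (the same layer-to-blobs layout as pycaffe's `net.params`) rather than on a loaded network. `count_parameters` is a hypothetical helper, and the LeNet shapes are assumed from Caffe's MNIST example (`examples/mnist/lenet.prototxt`, 28x28 single-channel input); check them against your own prototxt.

```python
from math import prod


def count_parameters(blob_shapes):
    """Return (per-layer counts, total) for a dict of layer name -> list of blob shapes.

    Every blob is counted, so bias blobs are included alongside the weights.
    """
    per_layer = {name: sum(prod(shape) for shape in shapes)
                 for name, shapes in blob_shapes.items()}
    return per_layer, sum(per_layer.values())


# Assumed LeNet blob shapes: weights first, then biases
lenet_shapes = {
    "conv1": [(20, 1, 5, 5), (20,)],
    "conv2": [(50, 20, 5, 5), (50,)],
    "ip1": [(500, 800), (500,)],  # 800 = 50 channels x 4 x 4 after the second pooling
    "ip2": [(10, 500), (10,)],
}

per_layer, total = count_parameters(lenet_shapes)
print(per_layer)  # {'conv1': 520, 'conv2': 25050, 'ip1': 400500, 'ip2': 5010}
print(total)      # 431080
```

    With pycaffe, a dictionary of this form could be built from a loaded net as `{k: [b.data.shape for b in v] for k, v in net.params.items()}`.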

  • 2020-12-17 22:29

    Here is a Python snippet that computes the number of parameters in a Caffe model. It counts only the first blob of each layer (the weights), so biases are not part of the total:

    from pprint import pprint
    
    import numpy as np
    import caffe
    
    caffe.set_mode_cpu()
    
    
    def print_net_parameters(deploy_file):
        print("Net: " + deploy_file)
        # Only the network definition is needed to count parameters, so no .caffemodel is loaded
        net = caffe.Net(deploy_file, caffe.TEST)
        print("Layer-wise parameters: ")
        # net.params maps each layer name to its list of blobs; v[0] is the weight blob
        pprint([(k, v[0].data.shape) for k, v in net.params.items()])
        # Only v[0] is counted, so bias blobs (v[1]) are not included in the total
        total = sum(int(np.prod(v[0].data.shape)) for k, v in net.params.items())
        print("Total number of parameters: " + str(total))
    
    
    deploy_file = "/home/ubuntu/deploy.prototxt"
    print_net_parameters(deploy_file)
    
    # Sample output:
    # Net: /home/ubuntu/deploy.prototxt
    # Layer-wise parameters: 
    #[('conv1', (96, 3, 11, 11)),
    # ('conv2', (256, 48, 5, 5)),
    # ('conv3', (384, 256, 3, 3)),
    # ('conv4', (384, 192, 3, 3)),
    # ('conv5', (256, 192, 3, 3)),
    # ('fc6', (4096, 9216)),
    # ('fc7', (4096, 4096)),
    # ('fc8', (819, 4096))]
    # Total number of parameters: 60213280
    

    https://gist.github.com/kaushikpavani/a6a32bd87fdfe5529f0e908ed743f779
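    In the sample output, some convolution weights have half as many input channels as the previous layer produces (e.g. `conv2` is `(256, 48, 5, 5)` although `conv1` outputs 96 channels). That is consistent with grouped convolution (`group: 2` in the layer's `convolution_param`, as in the AlexNet-style ImageNet model), where each filter sees only `channels / group` inputs. The sketch below assumes those hyperparameters (read off the sample shapes, not taken from the actual deploy file), rebuilds the weight shapes, and checks that the weight-only total matches the 60213280 printed above. The helper names are hypothetical, and the bias figure assumes `bias_term` is left enabled in every layer.

```python
from math import prod


def conv_weight_shape(num_output, in_channels, kernel_size, group=1):
    """Caffe convolution weight blob shape: (num_output, in_channels / group, k, k)."""
    assert in_channels % group == 0 and num_output % group == 0
    return (num_output, in_channels // group, kernel_size, kernel_size)


def fc_weight_shape(num_output, num_input):
    """Caffe InnerProduct weight blob shape: (num_output, num_input)."""
    return (num_output, num_input)


# Hyperparameters assumed from the sample output (AlexNet-style net, fc8 with 819 outputs)
weight_shapes = {
    "conv1": conv_weight_shape(96, 3, 11),
    "conv2": conv_weight_shape(256, 96, 5, group=2),
    "conv3": conv_weight_shape(384, 256, 3),
    "conv4": conv_weight_shape(384, 384, 3, group=2),
    "conv5": conv_weight_shape(256, 384, 3, group=2),
    "fc6": fc_weight_shape(4096, 256 * 6 * 6),  # assumed 256 x 6 x 6 input after pool5
    "fc7": fc_weight_shape(4096, 4096),
    "fc8": fc_weight_shape(819, 4096),
}

weights_total = sum(prod(s) for s in weight_shapes.values())
# One bias per output unit, assuming bias_term is enabled in every layer
biases_total = sum(s[0] for s in weight_shapes.values())

print(weights_total)                 # 60213280 (weights only, as in the sample output)
print(weights_total + biases_total)  # 60223667 (weights + biases)
```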
