How to implement the ReLU function in NumPy

野性不改 2020-12-02 09:53

I want to make a simple neural network which uses the ReLU function. Can someone give me a clue about how I can implement the function using NumPy?

9 answers
  • 2020-12-02 10:21

    ReLU(x) is also equal to (x + abs(x)) / 2
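
    A quick check of that identity against np.maximum (a minimal sketch; relu_abs is a hypothetical name for illustration):

    import numpy as np

    def relu_abs(x):
        # ReLU via the identity (x + |x|) / 2
        return (x + np.abs(x)) / 2

    x = np.array([-2.0, -0.5, 0.0, 1.5])
    print(relu_abs(x))                                 # [0.  0.  0.  1.5]
    print(np.allclose(relu_abs(x), np.maximum(x, 0)))  # True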

  • 2020-12-02 10:22

    Richard Möhn's comparison is not fair.
    As Andrea Di Biagio's comment points out, the in-place method np.maximum(x, 0, x) modifies x on the first loop, so in that benchmark the subsequent loops time an already-rectified array.
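
    A quick way to see that mutation (a minimal sketch):

    import numpy as np

    x = np.array([-1.0, 2.0, -3.0])
    np.maximum(x, 0, x)  # the third argument is the output array, so x is overwritten
    print(x)             # [0. 2. 0.] -- the negative entries are gone for good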

    So here is my benchmark:

    import numpy as np
    
    def baseline():
        # Array creation alone, so its cost can be subtracted from the others
        x = np.random.random((5000, 5000)) - 0.5
        return x
    
    def relu_mul():
        # Multiplication by a boolean mask
        x = np.random.random((5000, 5000)) - 0.5
        out = x * (x > 0)
        return out
    
    def relu_max():
        # np.maximum with a freshly allocated output array
        x = np.random.random((5000, 5000)) - 0.5
        out = np.maximum(x, 0)
        return out
    
    def relu_max_inplace():
        # x is regenerated inside the function, so mutating it in place
        # cannot contaminate the next timing loop
        x = np.random.random((5000, 5000)) - 0.5
        np.maximum(x, 0, x)
        return x
    

    Timing it:

    print("baseline:")
    %timeit -n10 baseline()
    print("multiplication method:")
    %timeit -n10 relu_mul()
    print("max method:")
    %timeit -n10 relu_max()
    print("max inplace method:")
    %timeit -n10 relu_max_inplace()
    

    The results:

    baseline:
    10 loops, best of 3: 425 ms per loop
    multiplication method:
    10 loops, best of 3: 596 ms per loop
    max method:
    10 loops, best of 3: 682 ms per loop
    max inplace method:
    10 loops, best of 3: 602 ms per loop
    

    The in-place maximum method is only slightly faster than the plain maximum method, probably because it skips allocating a separate output array for 'out', and it is still slower than the multiplication method. Also, since you're implementing the ReLU function, you may need to keep 'x' around for backpropagation through ReLU, which rules out clobbering it in place. E.g.:

    def relu_backward(dout, cache):
        x = cache
        # Pass the upstream gradient through only where the input was positive
        dx = np.where(x > 0, dout, 0)
        return dx
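
    For completeness, a matching forward pass might look like this (a sketch; relu_forward is my naming, just following the (out, cache) convention of relu_backward above):

    def relu_forward(x):
        # Multiplication method; keep the pre-activation x as the cache for backprop
        out = x * (x > 0)
        cache = x
        return out, cache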
    

    So I recommend using the multiplication method.

  • 2020-12-02 10:30

    You can do it in a much easier way:

    def ReLU(x):
        # Elementwise ReLU: multiplying by the boolean mask zeroes out negatives
        return x * (x > 0)
    
    def dReLU(x):
        # Derivative of ReLU: 1 where x > 0, else 0
        return 1. * (x > 0)
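
    For example, applied elementwise to an array (a quick sketch using the definitions above):

    import numpy as np

    x = np.array([-3.0, -1.0, 0.0, 2.0, 4.0])
    print(ReLU(x))   # [0. 0. 0. 2. 4.]
    print(dReLU(x))  # [0. 0. 0. 1. 1.]  (dReLU(0) is taken to be 0 here)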
    