Avoid overflow with softplus function in Python

不思量自难忘° 2021-02-14 05:20

I am trying to implement the following softplus function:

log(1 + exp(x))

I've tried it with math/numpy and float64 as data type, but whenever

4 Answers
  • 2021-02-14 06:08

    Since for x>30 we have log(1+exp(x)) ~= log(exp(x)) = x, a simple stable implementation is

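    import numpy as np

    def safe_softplus(x, limit=30):
        # For x > limit, log(1 + exp(x)) equals x to within about 1e-13
        if x > limit:
            return x
        # Below the limit exp(x) cannot overflow; log1p stays accurate when exp(x) is tiny
        return np.log1p(np.exp(x))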
    

    In fact |log(1+exp(30)) - 30| = log(1+exp(-30)) ≈ 9.4e-14 < 1e-10, so this implementation has an error below 1e-10 and never overflows. In particular, for x=1000 the error of this approximation is far below float64 resolution, so it cannot even be measured on the computer.
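
    The error bound above can be checked numerically. The following is a minimal sketch, not part of the original answer; the helper name `softplus_error_at` is hypothetical, and it relies on the identity softplus(x) - x = log(1 + exp(-x)).

```python
import math

def softplus_error_at(limit):
    """Gap between the exact softplus(limit) and the shortcut value `limit`."""
    # softplus(x) - x = log(1 + exp(-x)); log1p keeps this tiny result accurate
    return math.log1p(math.exp(-limit))

err = softplus_error_at(30)
print(err)          # roughly 9.4e-14
print(err < 1e-10)  # True
```

    For limit=1000, exp(-1000) underflows to 0.0 in float64, so the computed gap is exactly zero.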

  • 2021-02-14 06:08

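    I use this code to work with arrays:

    def safe_softplus(x):
        # Boolean mask: True where exp(x) is safe to evaluate
        inRanges = (x < 100)
        # Masked-out entries are evaluated at 0 inside exp, then replaced by x itself
        return np.log(1 + np.exp(x*inRanges))*inRanges + x*(1-inRanges)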
    
  • 2021-02-14 06:08

    What I'm currently using (slightly inefficient but clean and vectorized):

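    def safe_softplus(x, limit=30):
        # np.where evaluates both branches, so clip x first to keep np.exp from overflowing
        return np.where(x>limit, x, np.log1p(np.exp(np.minimum(x, limit))))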
    
  • 2021-02-14 06:11

    There is a relation which one can use:

    log(1+exp(x)) = log(1+exp(x)) - log(exp(x)) + x = log(1+exp(-x)) + x
    

    So a safe implementation, as well as mathematically sound, would be:

    log(1+exp(-abs(x))) + max(x,0)
    

    This works both for math and numpy functions (use e.g.: np.log, np.exp, np.abs, np.maximum).
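
    As a rough illustration of the formula above with numpy, here is a minimal sketch. The function name `stable_softplus` is my own, and swapping `np.log(1 + ...)` for `np.log1p` is an assumption made to keep precision when exp(-|x|) is very small.

```python
import numpy as np

def stable_softplus(x):
    """Softplus via log(1 + exp(-|x|)) + max(x, 0); works on scalars and arrays."""
    x = np.asarray(x, dtype=np.float64)
    # exp(-|x|) always lies in (0, 1], so it can never overflow
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)

xs = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
print(stable_softplus(xs))
```

    numpy also provides `np.logaddexp(0, x)`, which computes log(exp(0) + exp(x)) and should agree with this sketch.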
