LabelPropagation - How to avoid division by zero?

Happy的楠姐 2021-02-18 23:45

When using LabelPropagation, I often run into this warning (IMHO it should be an error, because it makes the propagation fail completely):

/usr/local/lib/python

1 Answer
  • 2021-02-19 00:31

    Basically, you're computing a softmax, right?

    The usual way to keep a softmax from overflowing or underflowing is this (from here):

    import numpy as np

    # Instead of this...
    def softmax(x, axis=0):
        return np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True)

    # ...do this: subtract the max along `axis` before exponentiating
    def softmax(x, axis=0):
        e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return e_x / e_x.sum(axis=axis, keepdims=True)
    

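    This bounds every entry of e_x to [0, 1] and ensures that at least one entry equals 1 (the element at np.argmax(x)), so the denominator is always at least 1. That prevents both overflow (np.exp(x) exceeding the float64 range) and the division by zero that happens when every np.exp(x) underflows to 0 (i.e. when np.exp(x.max()) is too small for float64 to represent).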
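A quick check of the two versions on inputs that overflow and underflow. This is a minimal sketch assuming only NumPy; `softmax_naive` and `softmax_stable` are hypothetical renames so both versions can coexist in one script:

```python
import numpy as np

def softmax_naive(x, axis=0):
    # Direct formula: exp() can overflow to inf or underflow to 0.0
    return np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True)

def softmax_stable(x, axis=0):
    # Shift so the largest exponent is 0; every exp() is then in [0, 1]
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

big = np.array([1000.0, 1001.0])      # exp() overflows: inf / inf = nan
small = np.array([-1001.0, -1000.0])  # exp() underflows: 0 / 0 = nan

with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(big), softmax_naive(small))  # [nan nan] [nan nan]

print(softmax_stable(big))    # finite and sums to 1
print(softmax_stable(small))  # same result: softmax depends only on differences
```

The `small` case is exactly the division by zero from the question: every term underflows, the denominator becomes 0, and the naive version returns nan.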

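    In your case, since you can't change the algorithm, I would shift the input instead: take D and compute D_ = D - D.min(). This should be numerically equivalent to the trick above, because the largest exponent in W, -gamma * D, occurs at -gamma * D.min() (the sign flip turns the smallest distance into the largest exponent), so the shift only multiplies every weight by the same constant exp(gamma * D.min()). Then run your algorithm on D_.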
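A sketch of the shift on a toy distance matrix, assuming (as the answer implies) that the weights are W = exp(-gamma * D) and that each row is normalized to sum to 1. `row_normalized_weights` is a hypothetical helper and the distances are made up for illustration. Note that a single global shift only guarantees a weight of exactly 1 in the row containing D.min(); a row whose distances are all much larger than D.min() can still underflow.

```python
import numpy as np

def row_normalized_weights(D, gamma):
    # Hypothetical helper: W = exp(-gamma * D), then each row scaled to sum to 1
    W = np.exp(-gamma * D)
    return W / W.sum(axis=1, keepdims=True)

gamma = 20.0
# Illustrative distances: -gamma * D is about -800, so exp() underflows to 0.0
D = np.array([[40.0, 41.0],
              [42.0, 40.5]])

with np.errstate(invalid="ignore"):
    naive = row_normalized_weights(D, gamma)    # 0 / 0 -> all nan

D_ = D - D.min()                                # smallest distance becomes 0
shifted = row_normalized_weights(D_, gamma)     # finite, rows sum to 1

# On well-scaled input the shift changes nothing: exp(gamma * D.min()) cancels
D_small = np.array([[0.1, 0.3],
                    [0.2, 0.5]])
print(np.allclose(row_normalized_weights(D_small, 1.0),
                  row_normalized_weights(D_small - D_small.min(), 1.0)))  # True
```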

    EDIT:

    As @PaulBrodersen recommends below, you can build a "safe" RBF kernel based on the sklearn implementation here:

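    import numpy as np
    from sklearn.metrics.pairwise import check_pairwise_arrays, euclidean_distances

    def rbf_kernel_safe(X, Y=None, gamma=None):
        # Like sklearn's rbf_kernel, but shifted so the largest entry is exp(0) = 1
        X, Y = check_pairwise_arrays(X, Y)
        if gamma is None:
            gamma = 1.0 / X.shape[1]

        K = euclidean_distances(X, Y, squared=True)  # squared distances
        K *= -gamma
        K -= K.max()           # largest exponent becomes 0
        np.exp(K, out=K)       # exponentiate K in place
        return K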
    

    Then pass it to LabelPropagation:

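    import functools
    # A callable kernel is called as kernel(X, Y), so LabelPropagation's gamma is not passed on; bind it here
    LabelPropagation(kernel=functools.partial(rbf_kernel_safe, gamma=20), tol=0.01).fit(X, Y)
    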

    Unfortunately I only have v0.18, which doesn't accept user-defined kernel functions for LabelPropagation, so I can't test it.
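As a stand-in test, here is a NumPy-only sketch of the same shift. `squared_distances`, `rbf_plain` and `rbf_shifted` are hypothetical helpers, not sklearn functions, and they skip sklearn's input validation. The example compares a plain RBF kernel with the shifted one for a new point far from two training points:

```python
import numpy as np

def squared_distances(X, Y):
    # ||x||^2 + ||y||^2 - 2 x.y, clipped at 0 to absorb rounding error
    d2 = (X**2).sum(axis=1)[:, None] + (Y**2).sum(axis=1)[None, :] - 2 * X @ Y.T
    return np.maximum(d2, 0.0)

def rbf_plain(X, Y, gamma):
    return np.exp(-gamma * squared_distances(X, Y))

def rbf_shifted(X, Y, gamma):
    K = -gamma * squared_distances(X, Y)
    K -= K.max()          # largest entry becomes exp(0) = 1
    return np.exp(K)

gamma = 20.0
X_train = np.array([[0.0, 0.0], [1.0, 0.0]])
X_new = np.array([[10.0, 0.0]])               # far from every training point

plain = rbf_plain(X_new, X_train, gamma)      # exp(-2000), exp(-1620): both underflow to 0
safe = rbf_shifted(X_new, X_train, gamma)     # exp(-380) and exactly 1.0
print(plain, safe)
```

When Y is X, every point is at distance 0 from itself, so K.max() is already 0 and the shift leaves K unchanged; it only makes a difference when no pair of points coincides, for example when kernelizing new points against the training set.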

    EDIT2:

    Looking at your source to see why your gamma values are so large, I wonder whether you are using gamma = D.min()/3, which would be incorrect. The source defines sigma = D.min()/3, and sigma enters w as

    w = exp(-d**2/sigma**2)  # Equation (1)
    

    which would make the correct gamma value 1/sigma**2, i.e. 9/D.min()**2.
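A numeric check of that conversion. sklearn's RBF kernel computes exp(-gamma * ||x - y||**2), so with d the Euclidean distance, gamma = 1/sigma**2 reproduces Equation (1). The distance matrix below is hypothetical; only the definition sigma = D.min()/3 comes from the source:

```python
import numpy as np

# Hypothetical pairwise distances d; sigma uses the source's definition sigma = D.min() / 3
D = np.array([[0.5, 1.2],
              [0.9, 2.0]])
sigma = D.min() / 3
gamma = 1.0 / sigma**2            # equivalently 9 / D.min()**2

w_eq1 = np.exp(-D**2 / sigma**2)  # Equation (1)
w_rbf = np.exp(-gamma * D**2)     # RBF form: exp(-gamma * d**2)

print(np.allclose(w_eq1, w_rbf))          # True
print(np.isclose(gamma, 9 / D.min()**2))  # True
```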
