How does python numpy.where() work?

礼貌的吻别 2020-12-04 14:54

I am playing with numpy and digging through the documentation, and I have come across some magic. Namely, I am talking about numpy.where(). How do they achieve internally that you are able to pass something like x > 5 into a method?



        
3 Answers
  • 2020-12-04 15:25

    How do they achieve internally that you are able to pass something like x > 5 into a method?

    The short answer is that they don't.

    Any comparison operation on a numpy array returns a boolean array (i.e. __gt__, __lt__, etc. all return boolean arrays that are True wherever the given condition holds).

    E.g.

    import numpy as np
    x = np.arange(9).reshape(3, 3)
    print(x > 5)

    yields:

    [[False False False]
     [False False False]
     [ True  True  True]]
    

    This is the same reason why something like if x > 5: raises a ValueError when x is a numpy array with more than one element. It's an array of True/False values, not a single value.
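    A minimal sketch of that ValueError, and of the usual fix: collapse the boolean array to a single bool explicitly with .any() or .all() (both are standard ndarray methods).

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

# Using a multi-element boolean array directly in an `if` is ambiguous.
try:
    if x > 5:
        pass
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Reduce the boolean array to a single bool explicitly instead.
print((x > 5).any())  # True: at least one element is greater than 5
print((x > 5).all())  # False: not every element is greater than 5
```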

    Furthermore, numpy arrays can be indexed by boolean arrays. E.g. x[x>5] yields [6 7 8], in this case.
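    A short sketch of boolean indexing on the same 3x3 array, both for reading and for in-place assignment:

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
mask = x > 5          # boolean array with the same shape as x

# Selecting with the mask returns a 1-D array of the matching elements.
selected = x[mask]
print(selected)       # [6 7 8]

# The same mask can be used for in-place assignment.
# Boolean indexing returned a copy above, so `selected` is unaffected.
x[mask] = 0
print(x)
```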

    Honestly, it's fairly rare that you actually need numpy.where, but it just returns the indices where a boolean array is True. Usually you can do what you need with simple boolean indexing.
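    To make the relationship concrete: with a single argument, np.where(cond) returns the same tuple of index arrays as np.nonzero(cond), and indexing with that tuple selects the same elements as the boolean mask itself. A small sketch:

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

rows, cols = np.where(x > 5)     # one index array per dimension
print(rows, cols)                # [2 2 2] [0 1 2]

# With a single argument, np.where(cond) matches np.nonzero(cond).
same = all(np.array_equal(a, b) for a, b in zip(np.where(x > 5), np.nonzero(x > 5)))
print(same)                      # True

# Indexing with the returned tuple picks the same elements as the mask.
print(x[np.where(x > 5)])        # [6 7 8]
print(x[x > 5])                  # [6 7 8]
```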

  • 2020-12-04 15:33

    Old Answer: It is kind of confusing. It gives you the LOCATIONS (all of them) of where your statement is true.

    So:

    >>> import numpy as np
    >>> a = np.arange(100)
    >>> np.where(a > 30)
    (array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
           48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
           65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
           82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
           99]),)
    >>> np.where(a == 90)
    (array([90]),)
    
    >>> a = a*40
    >>> np.where(a > 1000)
    (array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
           43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
           60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
           77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
           94, 95, 96, 97, 98, 99]),)
    >>> a[25]
    1000
    >>> a[26]
    1040
    

    I use it as an alternative to list.index(), but it has many other uses as well. I have never used it with 2D arrays.
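    A minimal sketch of how the two compare on a 1-D array: list.index returns only the first match (and raises ValueError when there is none), while np.where returns every match (and an empty array when there is none).

```python
import numpy as np

values = [3, 7, 1, 7, 9]
arr = np.array(values)

# list.index returns only the first matching position ...
print(values.index(7))              # 1

# ... while np.where returns every matching position.
(positions,) = np.where(arr == 7)
print(positions)                    # [1 3]
print(positions[0])                 # 1, the same as list.index

# With no match, list.index raises ValueError; np.where returns an empty array.
(missing,) = np.where(arr == 42)
print(missing.size)                 # 0
```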

    http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

    New Answer: It seems that the person was asking something more fundamental.

    The question was how YOU could implement something that allows a function (such as where) to know what was requested.

    First, note that calling any of the comparison operators does an interesting thing:

    >>> a > 1000
    array([False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)
    

    This is done by overloading the "__gt__" method. For instance:

    >>> class demo(object):
    ...     def __gt__(self, item):
    ...         print(item)
    ...
    >>> a = demo()
    >>> a > 4
    4
    

    As you can see, "a > 4" was valid code: Python turned it into the call a.__gt__(4), which printed 4.
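    To tie this back to np.where: by the time the function runs, x > 5 has already been evaluated, so np.where only ever receives the boolean array that __gt__ returned. A minimal pure-Python sketch of the same idea, where the class Vec and the function my_where are hypothetical names for illustration (this is not NumPy's actual implementation):

```python
class Vec:
    """Hypothetical 1-D container whose comparisons work element-wise."""

    def __init__(self, items):
        self.items = list(items)

    def __gt__(self, other):
        # Return a list of booleans instead of a single bool, like NumPy does.
        return [item > other for item in self.items]


def my_where(mask):
    """Hypothetical stand-in for one-argument np.where on 1-D data."""
    # The function only receives the already-computed mask.
    return [i for i, flag in enumerate(mask) if flag]


v = Vec([4, 9, 2, 7])
mask = v > 5           # Python evaluates this first, via Vec.__gt__(v, 5)
print(mask)            # [False, True, False, True]
print(my_where(mask))  # [1, 3]
print(my_where(v > 5)) # [1, 3]
```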

    You can get a full list and documentation of all the special methods you can overload here: http://docs.python.org/reference/datamodel.html

    Something that is incredible is how simple it is to do this. Almost all operators in Python are implemented this way: saying a > b is essentially equivalent to a.__gt__(b)!
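    More precisely, Python evaluates a > b by calling type(a).__gt__(a, b) first (unless b's type is a subclass of a's type), and if that returns NotImplemented it falls back to the reflected b.__lt__(a). A small sketch of that fallback, using hypothetical classes Left and Right; operator.gt from the standard library goes through the same machinery as the > operator:

```python
import operator


class Left:
    def __gt__(self, other):
        # Declining lets Python try the reflected operation on the other operand.
        return NotImplemented


class Right:
    def __lt__(self, other):
        # Called as the reflection of `Left() > Right()`.
        return "Right.__lt__ was used"


a, b = Left(), Right()
print(a > b)               # Right.__lt__ was used
print(operator.gt(a, b))   # same call path as the `>` operator

# For ordinary objects, a > b and a.__gt__(b) agree.
print((7 > 3) == (7).__gt__(3))  # True
```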

  • 2020-12-04 15:34

    When called with only a condition, np.where returns a tuple whose length equals the number of dimensions of the condition array (in other words, its ndim). Each item of the tuple is an ndarray holding the indices, along one axis, of all the elements for which the condition is True. (Please don't confuse dimension with shape.)

    For example:

    import numpy as np
    x = np.arange(9).reshape(3, 3)
    print(x)
    [[0 1 2]
     [3 4 5]
     [6 7 8]]
    y = np.where(x > 4)
    print(y)
    (array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))
    


    y is a tuple of length 2 because x.ndim is 2. The 1st item in the tuple contains the row numbers of all elements greater than 4, and the 2nd item contains their column numbers. As you can see, [1,2,2,2] corresponds to the row numbers of 5, 6, 7, 8 and [2,0,1,2] corresponds to their column numbers. Note that the ndarray is traversed in row-major order, i.e. row by row along the first dimension.
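    A short sketch of turning those parallel row and column arrays into explicit (row, col) coordinates; np.argwhere gives the same coordinates stacked as an (N, ndim) array:

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
rows, cols = np.where(x > 4)

# Pair the row and column arrays to get one (row, col) coordinate per match.
coords = list(zip(rows.tolist(), cols.tolist()))
print(coords)            # [(1, 2), (2, 0), (2, 1), (2, 2)]

# np.argwhere returns the same coordinates as an (N, ndim) array.
print(np.argwhere(x > 4).tolist())  # [[1, 2], [2, 0], [2, 1], [2, 2]]

# Each coordinate points back at an element greater than 4.
print([int(x[r, c]) for r, c in coords])  # [5, 6, 7, 8]
```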

    Similarly,

    x=np.arange(27).reshape(3,3,3)
    np.where(x>4)
    


    will return a tuple of length 3 because x has 3 dimensions.
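    A quick check of the 3-D case: the tuple has one index array per axis, and all three arrays have one entry per matching element (the values 5 through 26, i.e. 22 of them).

```python
import numpy as np

x = np.arange(27).reshape(3, 3, 3)
result = np.where(x > 4)

print(len(result))        # 3, one index array per axis
print(x.ndim)             # 3

# Every index array has one entry per matching element.
print([idx.size for idx in result])   # [22, 22, 22]
print(int((x > 4).sum())) # 22

# The first match is the value 5, stored at position (0, 1, 2).
first = tuple(int(idx[0]) for idx in result)
print(first)              # (0, 1, 2)
```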

    But wait, there's more to np.where!

    When two additional arguments are passed, np.where(condition, a, b) does a replace-style operation: it returns a new array that takes the value from a wherever the condition is True (the positions given by the tuple above) and from b everywhere else.

    x = np.arange(9).reshape(3, 3)
    y = np.where(x > 4, 1, 0)
    print(y)
    [[0 0 0]
     [0 0 1]
     [1 1 1]]
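    A short sketch of the three-argument form with an array as one of the branches. The extra arguments are broadcast against the condition, so a scalar and an array can be mixed, and the input array is left unchanged:

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

# Keep x, but replace values above 4 with 4 (the scalar is broadcast).
clipped = np.where(x > 4, 4, x)
print(clipped)
# [[0 1 2]
#  [3 4 4]
#  [4 4 4]]

# Keep even values, negate odd ones (both branches are arrays).
negated = np.where(x % 2 == 0, x, -x)
print(negated)
# [[ 0 -1  2]
#  [-3  4 -5]
#  [ 6 -7  8]]

# np.where builds a new array; x itself is unchanged.
print(x[2, 2])   # 8
```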
    