I am playing with numpy and digging through the documentation, and I have come across some magic. Namely, I am talking about numpy.where():
>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))
How do they achieve internally that you are able to pass something like x > 5 into a method? I guess it has something to do with __gt__, but I am looking for a detailed explanation.
How do they achieve internally that you are able to pass something like x > 5 into a method?
The short answer is that they don't.
Any comparison on a numpy array returns a boolean array (i.e. __gt__, __lt__, etc. all return boolean arrays that are True wherever the given condition holds).
E.g.
import numpy as np
x = np.arange(9).reshape(3,3)
print(repr(x > 5))
yields:
array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]])
This is the same reason why something like if x > 5: raises a ValueError if x is a numpy array. It's an array of True/False values, not a single value.
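To make that ValueError concrete, here is a minimal sketch (the variable names are just for illustration). It shows the error being raised and the usual way around it, which is to reduce the boolean array to a single value with any() or all():

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
mask = x > 5  # element-wise comparison: a boolean array, not a single bool

error_message = None
try:
    if mask:  # ambiguous: which of the nine values should decide?
        pass
except ValueError as exc:
    error_message = str(exc)

# Reduce the mask to a single bool explicitly instead.
any_above = mask.any()  # True if at least one element is > 5
all_above = mask.all()  # True only if every element is > 5
```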
Furthermore, numpy arrays can be indexed by boolean arrays. E.g. x[x>5] yields [6 7 8], in this case.
Honestly, it's fairly rare that you actually need numpy.where, but it just returns the indices where a boolean array is True. Usually you can do what you need with simple boolean indexing.
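As a quick sketch of how the two approaches relate (variable names are illustrative), boolean indexing and indexing with the result of np.where select the same elements:

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
mask = x > 5

values_by_mask = x[mask]         # boolean indexing
rows, cols = np.where(mask)      # index arrays, one per dimension
values_by_where = x[rows, cols]  # same selection via integer indexing
```

Reach for np.where when you need the positions themselves; if you only want the values, the mask alone is enough.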
Old Answer: It is kind of confusing. It gives you the LOCATIONS (all of them) where your statement is true.
So:
>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
99]),)
>>> np.where(a == 90)
(array([90]),)
>>> a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040
I use it as an alternative to list.index(), but it has many other uses as well. I have never used it with 2D arrays.
http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html
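To illustrate the list.index() comparison, here is a small sketch with made-up data. The main difference is that list.index() stops at the first match, while np.where reports every match:

```python
import numpy as np

values = [10, 30, 20, 30]
arr = np.array(values)

first_in_list = values.index(30)       # list.index stops at the first match
all_in_array = np.where(arr == 30)[0]  # np.where reports every match
first_in_array = int(all_in_array[0])  # take the first one to mimic list.index
```

Note that list.index() raises a ValueError when there is no match, whereas np.where returns an empty index array, so all_in_array[0] would raise an IndexError in that case.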
New Answer: It seems that the person was asking something more fundamental.
The question was how YOU could implement something that allows a function (such as where) to know what was requested.
First, note that using any of the comparison operators does an interesting thing.
>>> a > 1000
array([False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True], dtype=bool)
This is done by overloading the "__gt__" method. For instance:
>>> class demo(object):
...     def __gt__(self, item):
...         print(item)
...
>>> a = demo()
>>> a > 4
4
As you can see, "a > 4" was valid code.
You can get a full list and documentation of all overloaded functions here: http://docs.python.org/reference/datamodel.html
Something that is incredible is how simple it is to do this. Operators in Python are implemented in this way: saying a > b essentially calls a.__gt__(b)!
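Putting the pieces together, here is a rough sketch of how you could build a where-like function yourself. The Vector class and this where function are hypothetical toys written for this answer, not how NumPy is implemented. The point is that the comparison is evaluated first, so the function only ever receives the resulting booleans:

```python
class Vector:
    """Toy stand-in for an array (hypothetical, not NumPy's implementation)."""

    def __init__(self, values):
        self.values = list(values)

    def __gt__(self, other):
        # Called for `vector > other`; returns a plain list of booleans.
        return [v > other for v in self.values]


def where(mask):
    """Return the positions at which mask is true (hypothetical helper)."""
    return [i for i, flag in enumerate(mask) if flag]


v = Vector([5, 20, 3, 40])
mask = v > 10            # evaluated first: [False, True, False, True]
positions = where(mask)  # where() only ever sees the list of booleans
```

So where(v > 10) never knows that a comparison was written; it just gets whatever __gt__ returned.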
np.where returns a tuple whose length equals the number of dimensions of the numpy ndarray it is called on (in other words, ndim), and each item of the tuple is a numpy ndarray holding the indices, along that dimension, of all the values in the initial ndarray for which the condition is True. (Please don't confuse dimension with shape.)
For example:
x=np.arange(9).reshape(3,3)
print(x)
[[0 1 2]
 [3 4 5]
 [6 7 8]]
y = np.where(x>4)
print(y)
(array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))
y is a tuple of length 2 because x.ndim is 2. The 1st item in the tuple contains the row numbers of all elements greater than 4, and the 2nd item contains the column numbers of all elements greater than 4. As you can see, [1,2,2,2] corresponds to the row numbers of 5,6,7,8 and [2,0,1,2] corresponds to the column numbers of 5,6,7,8.
Note that the ndarray is traversed along the first dimension (row-wise).
Similarly,
x=np.arange(27).reshape(3,3,3)
np.where(x>4)
will return a tuple of length 3 because x has 3 dimensions.
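Here is a small sketch of the 3-D case. It uses a stricter condition (x > 24, chosen here only to keep the output short) and zips the three index arrays into (i, j, k) coordinates:

```python
import numpy as np

x = np.arange(27).reshape(3, 3, 3)
idx = np.where(x > 24)

n_index_arrays = len(idx)  # one index array per dimension
# Combine the per-dimension index arrays into (i, j, k) triples.
coords = list(zip(*(a.tolist() for a in idx)))
```

np.argwhere(x > 24) gives the same coordinates directly, as one row per matching element.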
But wait, there's more to np.where!
When two additional arguments are passed to np.where, it returns a new array of the same shape, taking the first of them wherever the condition is True and the second wherever it is False. The original array is not modified.
x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
[[0 0 0]
 [0 0 1]
 [1 1 1]]
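The two extra arguments don't have to be scalars. As a sketch (the names flags, clipped and manual are just for illustration), they can also be arrays, and the result can be reproduced by hand with boolean indexing:

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

flags = np.where(x > 4, 1, 0)     # scalars are broadcast to x's shape
clipped = np.where(x > 4, x, -1)  # keep x where True, -1 elsewhere

# Roughly the same result built by hand with boolean indexing.
manual = np.full_like(x, -1)
manual[x > 4] = x[x > 4]
```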
Source: https://stackoverflow.com/questions/5642457/how-does-python-numpy-where-work