How can I make a numpy function that accepts a numpy array, an iterable, or a scalar?

死守一世寂寞 2021-02-19 20:50

Suppose I have this:

def incrementElements(x):
   return x+1

but I want to modify it so that it can take either a numpy array, an iterable, or

4 Answers
  •  别跟我提以往
    2021-02-19 21:50

    You could try

    import numpy as np

    def incrementElements(x):
        x = np.asarray(x)  # scalars and iterables become an ndarray
        return x + 1
    

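    In NumPy 1.x, np.asarray(x) is equivalent to np.array(x, copy=False): a scalar or an iterable is converted to an ndarray, but if x is already an ndarray, its data is not copied. In NumPy 2.0 and later, copy=False instead raises a ValueError when a copy cannot be avoided, and the equivalent call is np.array(x, copy=None).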
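
    A quick way to check this behavior, as a minimal sketch (the sample values are chosen only for illustration):

```python
import numpy as np

def incrementElements(x):
    # Scalars and iterables are converted; an existing ndarray passes through uncopied
    x = np.asarray(x)
    return x + 1

a = np.array([1, 2, 3])
same = np.asarray(a) is a                  # ndarray input is returned as-is, no copy
from_list = incrementElements([1, 2, 3])   # list is converted to an ndarray first
from_scalar = incrementElements(5)         # scalar becomes a 0-d array internally

print(same, from_list, from_scalar, np.ndim(from_scalar))
```

    Note that the scalar case yields a zero-dimensional result, not a 1-D array.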

    If you pass a scalar and want an ndarray as output (not a scalar), you can use:

    def incrementElements(x):
        # copy=False is left out: on NumPy 2.0+ it raises ValueError when a copy is required
        x = np.array(x, ndmin=1)
        return x + 1
    

    The ndmin=1 argument forces the array to have at least one dimension; use ndmin=2 for at least two dimensions, and so on. You can also use the equivalent np.atleast_1d (or np.atleast_2d for the 2-D case).
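
    As a rough sketch of the np.atleast_1d route (incrementElements1d is a hypothetical name, used here only to keep it apart from the versions above):

```python
import numpy as np

def incrementElements1d(x):
    # Hypothetical variant: the result always has at least one dimension
    x = np.atleast_1d(x)
    return x + 1

print(incrementElements1d(5).shape)           # scalar input gives shape (1,)
print(incrementElements1d([1, 2, 3]).shape)   # list input gives shape (3,)
print(np.array(5, ndmin=2).shape)             # ndmin=2 gives shape (1, 1)
print(np.atleast_2d([1, 2, 3]).shape)         # 1-D input gives shape (1, 3)
```

    Because np.atleast_1d takes no copy argument, this version sidesteps the change in copy=False semantics between NumPy versions.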
