How can I make a NumPy function that accepts a NumPy array, an iterable, or a scalar?

死守一世寂寞 2021-02-19 20:50

Suppose I have this:

def incrementElements(x):
    return x + 1

but I want to modify it so that it can take either a NumPy array, an iterable, or a scalar.

  •  长发绾君心
    2021-02-19 21:30

    Pierre GM's answer is great so long as your function exclusively uses ufuncs (or something similar) to loop implicitly over the input values. If your function needs to iterate over the inputs explicitly, np.asarray isn't enough: it turns a scalar input into a 0-d array, and you can't iterate over a 0-d array:

    import numpy as np
    
    x = np.asarray(1)
    for xval in x:
        print(np.exp(xval))
    
    Traceback (most recent call last):
      File "Untitled 2.py", line 4, in <module>
        for xval in x:
    TypeError: iteration over a 0-d array
    

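For contrast, here is a minimal sketch of the ufunc-only case, applied to the question's incrementElements. It assumes the whole operation can be written with ufuncs such as `+`. This is an illustration of the general np.asarray idea, not a reproduction of the answer referred to above. No explicit loop is needed, so the 0-d iteration problem never comes up:

```python
import numpy as np

def incrementElements(x):
    # np.asarray turns a scalar into a 0-d array and a list/tuple into a
    # 1-D (or higher) array; an existing ndarray is passed through uncopied.
    # The + operator then calls the np.add ufunc, which loops implicitly.
    return np.asarray(x) + 1

print(incrementElements(1))          # 0-d input: the ufunc returns a NumPy scalar
print(incrementElements([1, 2, 3]))
print(incrementElements(np.array([[1.5], [2.5]])))
```

One caveat about "iterable": np.asarray does not consume generators. It wraps a generator in a 0-d object array, so convert generators with list() first.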
    If your function needs to iterate over the input, something like the following will work, using np.atleast_1d and np.squeeze (see "Array manipulation routines" in the NumPy manual). I included an aaout ("Always Array OUT") argument so you can specify whether scalar inputs should produce single-element array outputs; drop it if you don't need it:

    def scalar_or_iter_in(x, aaout=False):
        """
        Gather function evaluations over scalar or iterable `x` values.
    
        aaout :: boolean
            "Always array output" flag:  If True, scalar input produces
            a 1-D, single-element array output.  If False, scalar input
            produces scalar output.
        """
        x = np.asarray(x)
        scalar_in = x.ndim==0
    
        # Could use np.array instead of np.atleast_1d, as follows:
        # xvals = np.array(x, copy=False, ndmin=1)
        xvals = np.atleast_1d(x)
        y = np.empty_like(xvals, dtype=float)  # float dtype so integer input doesn't truncate the results
        for i, xx in enumerate(xvals):
            y[i] = np.exp(xx)  # YOUR OPERATIONS HERE!
    
        if scalar_in and not aaout:
            return np.squeeze(y)
        else:
            return y
    
    
    print(scalar_or_iter_in(1.))
    print(scalar_or_iter_in(1., aaout=True))
    print(scalar_or_iter_in([1,2,3]))
    
    
    2.718281828459045
    [2.71828183]
    [ 2.71828183  7.3890561  20.08553692]
    

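To make the "YOUR OPERATIONS HERE!" slot concrete, here is a sketch with a hypothetical per-element rule: log(x) where x > 0, otherwise 0.0. It is written with a plain Python `if`, which works on one element at a time but not on a whole array. The function name log_or_zero is made up for illustration. This particular rule could also be written with np.where, so treat it only as a stand-in for logic that genuinely can't be vectorized. Like the loop above, it assumes 0-d or 1-D input.

```python
import math
import numpy as np

def log_or_zero(x, aaout=False):
    """Hypothetical example: log(x) where x > 0, else 0.0, element by element.

    Same scalar/iterable handling as scalar_or_iter_in; assumes 0-d or 1-D input.
    """
    x = np.asarray(x)
    scalar_in = x.ndim == 0
    xvals = np.atleast_1d(x)
    y = np.empty_like(xvals, dtype=float)
    for i, xx in enumerate(xvals):
        # A plain Python `if` is fine on a single element, but on a
        # multi-element array it raises "The truth value of an array ... is ambiguous".
        if xx > 0:
            y[i] = math.log(xx)
        else:
            y[i] = 0.0
    if scalar_in and not aaout:
        return np.squeeze(y)
    return y

print(log_or_zero(-2.0))
print(log_or_zero([-1, 1, math.e]))
print(log_or_zero(1.0, aaout=True))
```

For 2-D input, enumerate would hand each row to the `if`, which raises the same "ambiguous" error. Iterating over xvals.ravel() is one way to process arbitrary shapes element by element.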
    Of course, for exponentiation you shouldn't iterate explicitly as here, but a more complex operation may not be expressible with NumPy ufuncs. If you don't need to iterate but still want control over whether scalar inputs produce single-element array outputs, the middle of the function gets simpler, and the np.atleast_1d call moves into the return statement:

    def scalar_or_iter_in(x, aaout=False):
        """
        Gather function evaluations over scalar or iterable `x` values.
    
        aaout :: boolean
            "Always array output" flag:  If True, scalar input produces
            a 1-D, single-element array output.  If False, scalar input
            produces scalar output.
        """
        x = np.asarray(x)
        scalar_in = x.ndim==0
    
        y = np.exp(x)  # YOUR OPERATIONS HERE!
    
        if scalar_in and not aaout:
            return np.squeeze(y)
        else:
            return np.atleast_1d(y)
    

    I suspect the aaout flag is unnecessary in most cases, and that you'd usually want scalar outputs for scalar inputs. In that case, the return statement can simply be:

        if scalar_in:
            return np.squeeze(y)
        else:
            return y
    

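A detail worth knowing about the "scalar output" here: in the loop-based version, y has shape (1,) for scalar input, so np.squeeze(y) returns a 0-d ndarray. It prints like a scalar but is still an array (for example, it isn't hashable). The snippet below only inspects the types, using standard NumPy calls:

```python
import numpy as np

y = np.exp(np.atleast_1d(1.0))  # shape (1,), like the loop version's output
s = np.squeeze(y)               # shape (): still an ndarray, just 0-d

print(type(s).__name__, s.shape)   # ndarray ()
print(type(s[()]).__name__)        # float64: a NumPy scalar
print(type(s.item()).__name__)     # float: a plain Python float
```

If callers need a genuine scalar, one option is to return np.squeeze(y)[()] (a NumPy scalar) or y.item() (a Python scalar) instead of np.squeeze(y).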