np.ndarray with Periodic Boundary conditions

忘掉有多难 2021-02-06 12:26

Problem

Impose periodic boundary conditions on an np.ndarray, as laid out below.

Details

  • Wrap the indexing of a Python np.ndarray around its boundaries
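To make the problem concrete, here is a short sketch of how a plain np.ndarray behaves at its boundaries (the 4x4 `arr` is only illustrative). Negative indices wrap around once, but anything past the upper edge raises `IndexError`. Reducing each index modulo its axis length is what produces periodic behaviour.

```python
import numpy as np

arr = np.arange(16).reshape(4, 4)

# Negative indices already wrap, but only by one period:
print(arr[-1, -1])   # 15, the same element as arr[3, 3]

# Anything past the upper boundary is an error rather than a wrap:
try:
    arr[4, 4]
except IndexError as err:
    print("IndexError:", err)

# Reducing each index modulo the axis length gives the periodic behaviour:
idx = (4, 4)
wrapped = tuple(i % n for i, n in zip(idx, arr.shape))
print(wrapped, arr[wrapped])   # (0, 0) 0
```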
1 Answer
  • 2021-02-06 13:25

    Wrap function

    A simple function can be written with Python's modulo operator, %, and generalised to operate on an n-dimensional index tuple for a given lattice shape.

    import numbers
    
    def latticeWrapIdx(index, lattice_shape):
        """returns periodic lattice index 
        for a given iterable index
        
        Required Inputs:
            index :: iterable :: one integer for each axis
            lattice_shape :: the shape of the lattice to index to
        """
        if not hasattr(index, '__iter__'): return index     # single (non-tuple) index: passed through unwrapped
        if len(index) != len(lattice_shape): return index   # partial index: passed through unwrapped
        if not all(isinstance(i, numbers.Integral) for i in index):
            return index                                     # slices, Ellipsis, None, arrays: not supported
        # periodic indexing of a single point: for s > 0, Python's % maps any
        # integer i (including negatives) into range(0, s)
        return tuple(i % s for i, s in zip(index, lattice_shape))
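Python's `%` follows floor division, so the remainder takes the sign of the divisor. For a positive axis length `s`, `i % s` therefore already lands in `range(0, s)` even for negative `i`. The `(i % s + s) % s` form is the idiom from C, where the remainder can be negative. A quick check in plain Python, with no NumPy needed:

```python
import math

s = 4  # axis length (positive)

# Python's % follows floor division: the remainder has the sign of the divisor,
# so negative indices are mapped into range(0, s) directly.
samples = [i % s for i in (-1, -5, 10)]
print(samples)                   # [3, 3, 2]

# Both forms agree for every integer i when s > 0, so the extra "+ s" is
# harmless but unnecessary in Python.
assert all(i % s == (i % s + s) % s for i in range(-20, 21))

# math.fmod mimics C's truncated remainder, which keeps the sign of the dividend:
print(math.fmod(-1, s))          # -1.0
```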
    

    This is tested as:

    import numpy as np

    arr = np.array([[ 11.,  12.,  13.,  14.],
                    [ 21.,  22.,  23.,  24.],
                    [ 31.,  32.,  33.,  34.],
                    [ 41.,  42.,  43.,  44.]])
    test_vals = [[(1,1), 22.], [(3,3), 44.], [( 4, 4), 11.], # [index, expected value]
                 [(3,4), 41.], [(4,3), 14.], [(10,10), 33.]]
    
    passed = all(arr[latticeWrapIdx(idx, (4,4))] == act for idx, act in test_vals)
    print("Iterating test values. Result: {}".format(passed))
    

    which prints:

    Iterating test values. Result: True
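When many lattice points are needed at once, the same modulo can be applied to a whole array of indices with NumPy, and the result used for integer-array (fancy) indexing. This is a minimal sketch; the helper name `wrap_points` is hypothetical and not part of the answer above.

```python
import numpy as np

def wrap_points(points, shape):
    """Hypothetical helper: wrap an (N, ndim) array of integer indices
    periodically and return a tuple usable for fancy indexing."""
    wrapped = np.mod(np.asarray(points), shape)  # shape broadcasts across the rows
    return tuple(wrapped.T)                      # one index array per axis

arr = np.array([[11., 12., 13., 14.],
                [21., 22., 23., 24.],
                [31., 32., 33., 34.],
                [41., 42., 43., 44.]])
points = [(1, 1), (3, 3), (4, 4), (3, 4), (4, 3), (10, 10), (-1, 0)]
print(arr[wrap_points(points, arr.shape)])
# [22. 44. 11. 41. 14. 33. 41.]
```

All points are gathered in a single indexing call instead of a Python loop over `latticeWrapIdx`.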
    

    Subclassing NumPy

    The wrapping function can be incorporated into a subclassed np.ndarray, following the pattern in the NumPy subclassing documentation:

    import numpy as np

    class Periodic_Lattice(np.ndarray):
        """Creates an n-dimensional ring that joins on boundaries w/ numpy
        
        Required Inputs
            array :: np.array :: n-dim numpy array to use wrap with
        
        Only currently supports single point selections wrapped around the boundary
        """
        def __new__(cls, input_array, lattice_spacing=None):
            """__new__ is called by numpy when an explicit constructor is used:
            obj = MySubClass(params); otherwise we must rely on __array_finalize__
            """
            # Input array may be any array-like (list, ndarray, ...)
            # We first cast it to an ndarray and view it as our class type
            obj = np.asarray(input_array).view(cls)
            
            # add the new attributes to the created instance
            # (use obj.shape, not input_array.shape, so plain lists also work)
            obj.lattice_shape = obj.shape
            obj.lattice_dim = obj.ndim
            obj.lattice_spacing = lattice_spacing
            
            # Finally, we must return the newly created object:
            return obj
        
        def __getitem__(self, index):
            index = self.latticeWrapIdx(index)
            return super(Periodic_Lattice, self).__getitem__(index)
        
        def __setitem__(self, index, item):
            index = self.latticeWrapIdx(index)
            return super(Periodic_Lattice, self).__setitem__(index, item)
        
        def __array_finalize__(self, obj):
            """ ndarray.__new__ passes __array_finalize__ the new object, 
            of our own class (self) as well as the object from which the view has been taken (obj). 
            See
            http://docs.scipy.org/doc/numpy/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
            for more info
            """
            # ``self`` is a new object resulting from
            # ndarray.__new__(Periodic_Lattice, ...), therefore it only has
            # attributes that the ndarray.__new__ constructor gave it -
            # i.e. those of a standard ndarray.
            #
            # We could have got to the ndarray.__new__ call in 3 ways:
            # From an explicit constructor - e.g. ndarray.__new__(Periodic_Lattice, shape):
            #   1. obj is None
            #       (there is nothing to copy from; note that Periodic_Lattice(arr)
            #       itself goes through view casting, case 2 below, and its
            #       __new__ then sets the lattice_* attributes)
            if obj is None: return
            #   2. From view casting - e.g arr.view(Periodic_Lattice):
            #       obj is arr
            #       (type(obj) can be Periodic_Lattice)
            #   3. From new-from-template - e.g lattice[:3]
            #       type(obj) is Periodic_Lattice
            # 
            # Note that it is here, rather than in the __new__ method,
            # that we set the default value for 'lattice_spacing', because this
            # method sees all creation of default objects - with the
            # Periodic_Lattice.__new__ constructor, but also with
            # arr.view(Periodic_Lattice).
            #
            # The shape attributes are read from the new object itself rather
            # than inherited from obj: a slice such as lattice[:3] has a
            # different shape from its parent and must wrap at its own boundaries.
            self.lattice_shape = self.shape
            self.lattice_dim = self.ndim
            self.lattice_spacing = getattr(obj, 'lattice_spacing', None)
        
        def latticeWrapIdx(self, index):
            """returns periodic lattice index 
            for a given iterable index
            
            Required Inputs:
                index :: iterable :: one integer for each axis
            
            This is NOT compatible with slicing
            """
            if not hasattr(index, '__iter__'): return index          # single (non-tuple) index: passed through unwrapped
            if len(index) != len(self.lattice_shape): return index   # partial index: passed through unwrapped
            if not all(isinstance(i, (int, np.integer)) for i in index):
                return index                                          # slices, Ellipsis, None, arrays: not supported
            # periodic indexing of a single point: for s > 0, Python's % maps any
            # integer i (including negatives) into range(0, s)
            return tuple(i % s for i, s in zip(index, self.lattice_shape))
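The attribute propagation described in the `__array_finalize__` comments can be observed with a stripped-down subclass. This is a sketch only: the class name `Tagged` and its `tag` attribute are hypothetical and simply mirror how `lattice_spacing` is carried over from one array to the next.

```python
import numpy as np

class Tagged(np.ndarray):
    """Hypothetical minimal subclass carrying one extra attribute."""
    def __new__(cls, input_array, tag=None):
        obj = np.asarray(input_array).view(cls)  # view casting -> __array_finalize__(obj=base array)
        obj.tag = tag
        return obj

    def __array_finalize__(self, obj):
        if obj is None:  # explicit ndarray.__new__(Tagged, ...) call
            return
        # view casting or new-from-template: copy the attribute if the source has one
        self.tag = getattr(obj, 'tag', None)

t = Tagged(np.arange(16.0).reshape(4, 4), tag='lattice')
sub = t[:3]                                 # new-from-template: slicing
v = np.zeros((2, 2)).view(Tagged)           # view casting from a plain ndarray
print(type(sub).__name__, sub.tag, sub.shape)  # Tagged lattice (3, 4)
print(v.tag)                                # None
```

The slice keeps both the subclass type and the attribute, but its shape is `(3, 4)`. That is why shape-dependent attributes are best read from the new object rather than copied from its parent.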
    

    Testing confirms that the lattice's overloaded indexing behaves correctly:

    arr = np.array([[ 11.,  12.,  13.,  14.],
                    [ 21.,  22.,  23.,  24.],
                    [ 31.,  32.,  33.,  34.],
                    [ 41.,  42.,  43.,  44.]])
    test_vals = [[(1,1), 22.], [(3,3), 44.], [( 4, 4), 11.], # [index, expected value]
                 [(3,4), 41.], [(4,3), 14.], [(10,10), 33.]]
    
    periodic_arr = Periodic_Lattice(arr)
    passed = bool((periodic_arr == arr).all())
    passed = passed and all(periodic_arr[idx] == act for idx, act in test_vals)
    print("Iterating test values. Result: {}".format(passed))
    

    which prints:

    Iterating test values. Result: True
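Because the wrapped indexing above only handles single points, periodic ranges are easier to take with `np.take`. Its `mode='wrap'` option reduces out-of-range indices modulo the axis length. The sketch below selects a block that straddles the corner of the lattice. Note that `np.take` returns a copy, so writing into the result does not modify `arr`.

```python
import numpy as np

arr = np.array([[11., 12., 13., 14.],
                [21., 22., 23., 24.],
                [31., 32., 33., 34.],
                [41., 42., 43., 44.]])

# Rows 3, 4, 5 and columns 3, 4 wrap to rows 3, 0, 1 and columns 3, 0.
rows = np.take(arr, range(3, 6), axis=0, mode='wrap')
block = np.take(rows, range(3, 5), axis=1, mode='wrap')
print(block)
# [[44. 41.]
#  [14. 11.]
#  [24. 21.]]
```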
    

    Finally, running the code provided in the original question gives:

    True
    error
    error
    