How can an almost arbitrary plane in a 3D dataset be plotted with matplotlib?

慢半拍i 2021-01-18 15:11

There is an array containing 3D data of shape, e.g., (64, 64, 64). How do you plot a plane through this dataset that is given by a point and a normal (similar to hkl planes in crystallography)?

5 answers
  •  滥情空心
    2021-01-18 15:37

    The other answers here do not appear to be very efficient: they either use explicit loops over pixels or scipy.interpolate.griddata, which is designed for unstructured input data. Here is an efficient (vectorized) and generic solution.
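    A minimal sketch of what "vectorized" means here (an illustration, not code from the answer): each slice point (X, Y, 0) is mapped to data coordinates center + R @ (X, Y, 0), where the columns of R are eX, eY and eZ. The sketch uses a random orthonormal R and made-up center and grid sizes, and checks that a single np.einsum call gives the same coordinates as an explicit loop over pixels.

```python
import numpy as np

rng = np.random.default_rng(0)

# Orthonormal matrix whose columns play the role of eX, eY, eZ (made up)
R = np.linalg.qr(rng.normal(size=(3, 3)))[0]
center = np.array([10.0, 20.0, 30.0])

# Slice-plane points (X, Y, 0) on a small mX x mY grid, shape (3, mX, mY)
mX, mY = 4, 5
PP = np.zeros((3, mX, mY))
PP[0] = np.arange(mX).reshape(mX, 1)
PP[1] = np.arange(mY).reshape(1, mY)

# Vectorized: one einsum maps every grid point to data coordinates
idx = np.einsum('il,ljk->ijk', R, PP) + center.reshape(3, 1, 1)

# Equivalent explicit loop over pixels (slow for large slices)
idx_loop = np.empty_like(idx)
for j in range(mX):
    for k in range(mY):
        idx_loop[:, j, k] = center + R @ PP[:, j, k]

assert np.allclose(idx, idx_loop)

# Every point lies in the plane through `center` with normal R[:, 2] (eZ)
offsets = idx - center.reshape(3, 1, 1)
assert np.allclose(np.einsum('i,ijk->jk', R[:, 2], offsets), 0)
```

    Since the Z coordinate of every slice point is 0, the sampled points have no component along eZ, which is why they all lie in the plane through center with normal eZ.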

    There is a pure numpy implementation (for nearest-neighbor "interpolation") and one for linear interpolation, which delegates the interpolation to scipy.ndimage.map_coordinates. (The latter function probably didn't exist in 2013, when this question was asked.)
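    For reference, a small sketch of how scipy.ndimage.map_coordinates behaves with the options used in the linear-interpolation path (order, mode='constant', cval): the coordinates array has one row per data axis, order=1 gives trilinear interpolation, order=0 a nearest-neighbour lookup, and points outside the array receive cval. The 2×2×2 cube and the sample points are made up for illustration.

```python
import numpy as np
from scipy.ndimage import map_coordinates

# Tiny 2x2x2 cube: cube[i, j, k] = 4*i + 2*j + k, i.e. the values 0..7
cube = np.arange(8, dtype=float).reshape(2, 2, 2)

# Coordinates have shape (3, npoints): row d holds the positions along axis d
coords = np.array([
    [0.5, 0.4, -1.0],  # axis-0 coordinates of three sample points
    [0.5, 0.6,  0.0],  # axis-1 coordinates
    [0.5, 0.6,  0.0],  # axis-2 coordinates
])

# order=1: trilinear interpolation; the third point lies outside -> cval
linear = map_coordinates(cube, coords, order=1, mode='constant', cval=np.nan)

# order=0: nearest-neighbour lookup with the same out-of-range handling
nearest = map_coordinates(cube, coords, order=0, mode='constant', cval=np.nan)

print(linear)   # centre point gives 3.5, the mean of all 8 corner values
print(nearest)  # second point rounds to voxel (0, 1, 1), value 3.0
```

    Because the cube values are a linear function of the indices, trilinear interpolation reproduces them exactly, which makes the results easy to check by hand.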

    import numpy as np
    from scipy.ndimage import map_coordinates
         
    def slice_datacube(cube, center, eXY, mXY, fill=np.nan, interp=True):
        """Get a 2D slice from a 3-D array.
        
        Copyright: Han-Kwang Nienhuys, 2020.
        License: any of CC-BY-SA, CC-BY, BSD, GPL, LGPL
        Reference: https://stackoverflow.com/a/62733930/6228891
        
        Parameters:
        
        - cube: 3D array, assumed shape (nx, ny, nz).
        - center: shape (3,) with coordinates of the center (may be
          non-integer).
        - eXY: unit vectors, shape (2, 3) - for X and Y axes of the slice.
          (unit vectors must be orthogonal; normalization is optional).
        - mXY: size tuple of output array (mX, mY) - int.
        - fill: value to use for out-of-range points (with NaN, cube must
          have a float dtype).
        - interp: if True, use linear interpolation; if False, use the
          nearest voxel.
        
        Return:
            
        - slice: array, shape (mX, mY).
        """
        
        center = np.array(center, dtype=float)
        assert center.shape == (3,)
        
        eXY = np.array(eXY)/np.linalg.norm(eXY, axis=1)[:, np.newaxis]
        if not np.isclose(eXY[0] @ eXY[1], 0, atol=1e-6):
            raise ValueError('eX and eY are not orthogonal.')
    
        # R: rotation matrix: data_coords = center + R @ slice_coords
        eZ = np.cross(eXY[0], eXY[1])
        R = np.array([eXY[0], eXY[1], eZ], dtype=np.float32).T
        
        # setup slice points P with coordinates (X, Y, 0)
        mX, mY = int(mXY[0]), int(mXY[1])    
        Xs = np.arange(0.5-mX/2, 0.5+mX/2)
        Ys = np.arange(0.5-mY/2, 0.5+mY/2)
        PP = np.zeros((3, mX, mY), dtype=np.float32)
        PP[0, :, :] = Xs.reshape(mX, 1)
        PP[1, :, :] = Ys.reshape(1, mY)
            
        # Transform to data coordinates (x, y, z) - idx.shape == (3, mX, mY)
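        if interp:
            idx = np.einsum('il,ljk->ijk', R, PP) + center.reshape(3, 1, 1)
            # Trilinear interpolation; points outside the cube get `fill`
            out = map_coordinates(cube, idx, order=1, mode='constant', cval=fill)
        else:
            idx = np.einsum('il,ljk->ijk', R, PP) + (0.5 + center.reshape(3, 1, 1))
            # Round to the nearest voxel via floor(x + 0.5). np.floor is needed
            # because astype() alone truncates toward zero, so e.g. -0.2 would
            # become index 0 instead of being flagged as out of range.
            idx = np.floor(idx).astype(np.intp)
            # Find out which coordinates are out of range - shape (mX, mY)
            badpoints = np.any([
                idx[0, :, :] < 0,
                idx[0, :, :] >= cube.shape[0], 
                idx[1, :, :] < 0,
                idx[1, :, :] >= cube.shape[1], 
                idx[2, :, :] < 0,
                idx[2, :, :] >= cube.shape[2], 
                ], axis=0)
            
            # Use a valid dummy index for bad points, then overwrite with fill
            idx[:, badpoints] = 0
            out = cube[idx[0], idx[1], idx[2]]
            out[badpoints] = fill
            
        return out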
        
    # Demonstration
    nx, ny, nz = 50, 70, 100
    cube = np.full((nx, ny, nz), np.float32(1))
    
    cube[nx//4:nx*3//4, :, :] += 1
    cube[:, ny//2:ny*3//4, :] += 3
    cube[:, :, nz//4:nz//2] += 7
    cube[nx//3-2:nx//3+2, ny//2-2:ny//2+2, :] = 0 # black dot
         
    Rz, Rx = np.pi/6, np.pi/4 # rotation angles around z and x
    cz, sz = np.cos(Rz), np.sin(Rz)
    cx, sx = np.cos(Rx), np.sin(Rx)
    
    Rmz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    Rmx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    eXY = (Rmx @ Rmz).T[:2]
      
    slice = slice_datacube(
        cube, 
        center=[nx/3, ny/2, nz*0.7], 
        eXY=eXY,
        mXY=[80, 90],
        fill=np.nan,
        interp=False
        )
    
    import matplotlib.pyplot as plt
    plt.close('all')
    plt.imshow(slice.T)  # imshow expects shape (mY, mX)
    plt.colorbar()
    plt.show()
    

    Output (for interp=False): [plot of the 80x90 slice; image not preserved]

    For this test case (50x70x100 datacube, 80x90 slice size) the run time is 376 µs (interp=False) and 550 µs (interp=True) on my laptop.
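    The question specifies the plane by a point and a normal, whereas slice_datacube expects two orthogonal in-plane vectors eXY. A minimal sketch of one way to bridge that gap, using a hypothetical helper basis_from_normal (not part of the answer): cross the normal with the coordinate axis it is least aligned with, then cross again. Assuming a cubic lattice whose axes coincide with the array axes and equal voxel spacing, the normal of the (hkl) plane is parallel to the direction [h, k, l]; for other lattices it is parallel to the reciprocal-lattice vector of (hkl) instead.

```python
import numpy as np

def basis_from_normal(normal):
    """Return two orthonormal in-plane unit vectors, shape (2, 3).

    Hypothetical helper (not from the answer); its output can be used as
    the eXY argument of slice_datacube.
    """
    n = np.asarray(normal, dtype=float)
    n = n / np.linalg.norm(n)
    # Use the coordinate axis least aligned with n, so the cross product
    # below is never close to zero.
    helper = np.zeros(3)
    helper[np.argmin(np.abs(n))] = 1.0
    eX = np.cross(n, helper)
    eX /= np.linalg.norm(eX)
    # n and eX are orthonormal, so eY is already a unit vector
    eY = np.cross(n, eX)
    return np.array([eX, eY])

# Example: normal of a (1 1 1) plane in a cubic lattice
eXY = basis_from_normal([1, 1, 1])
print(eXY)
```

    With this construction np.cross(eX, eY) equals the unit normal, so the eZ computed inside slice_datacube points along the requested normal. The point on the plane is passed as center, e.g. slice_datacube(cube, center=point, eXY=basis_from_normal(normal), mXY=[80, 90]), where point and normal are your own arrays.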
