Test if an array is broadcastable to a shape?

灰色年华 2021-01-21 16:39

What is the best way to test whether an array can be broadcast to a given shape?

The "pythonic" approach of trying doesn't work for my case, because the

5 answers
  • 2021-01-21 17:01

    If you just want to avoid materializing an array with a given shape, you can use as_strided:

    import numpy as np
    from numpy.lib.stride_tricks import as_strided
    
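    def is_broadcastable(shp1, shp2):
        x = np.array([1])
        # Zero strides: both views report the requested shape but reuse x's one element.
        a = as_strided(x, shape=shp1, strides=[0] * len(shp1))
        b = as_strided(x, shape=shp2, strides=[0] * len(shp2))
        try:
            # Raises ValueError when the two shapes cannot be broadcast together.
            np.broadcast_arrays(a, b)
            return True
        except ValueError:
            return False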
    
    is_broadcastable((1000, 1000, 1000), (1000, 1, 1000))  # True
    is_broadcastable((1000, 1000, 1000), (3,))  # False
    

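    This is memory-efficient, since a and b are both backed by a single element: with every stride set to zero, all indices refer to the same stored value, so no large array is ever allocated.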
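
    As a minimal sketch of why the zero-stride trick costs no memory, the snippet below builds one such view and inspects it. The helper name zero_stride_view is hypothetical and only for illustration; the view should be treated as read-only, since writing through it would change the one shared element.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def zero_stride_view(shape):
    """Hypothetical helper: a view of `shape` backed by one stored element."""
    x = np.array([1])
    # Zero strides mean every index maps to the same memory location.
    return as_strided(x, shape=shape, strides=(0,) * len(shape))


a = zero_stride_view((1000, 1000, 1000))
print(a.shape)            # logical shape with a billion entries
print(a.strides)          # every stride is zero
print(a[999, 999, 999])   # any index reads the single stored value
```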

  • 2021-01-21 17:03

    You could use np.broadcast. For example:

    In [47]: x = np.ones([2,2,2])
    
    In [48]: y = np.ones([2,3])
    
    In [49]: try:
       ....:     b = np.broadcast(x, y)
       ....:     print("Result has shape", b.shape)
       ....: except ValueError:
       ....:     print("Not compatible for broadcasting")
       ....:     
    Not compatible for broadcasting
    
    In [50]: y = np.ones([2,2])
    
    In [51]: try:
       ....:     b = np.broadcast(x, y)
       ....:     print("Result has shape", b.shape)
       ....: except ValueError:
       ....:     print("Not compatible for broadcasting")
       ....:
    Result has shape (2, 2, 2)
    

    For your implementation of lazy evaluation, you might also find np.broadcast_arrays useful.
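
    A rough sketch of how the two calls might be combined: np.broadcast gives the result shape without producing data, and np.broadcast_arrays returns views rather than copies. The helper name broadcast_shape_or_none is made up for this example.

```python
import numpy as np


def broadcast_shape_or_none(x, y):
    """Hypothetical helper: result shape, or None if x and y are incompatible."""
    try:
        return np.broadcast(x, y).shape
    except ValueError:
        return None


x = np.ones([2, 2, 2])
y = np.ones([2, 2])
print(broadcast_shape_or_none(x, y))
print(broadcast_shape_or_none(x, np.ones([2, 3])))

# broadcast_arrays returns views: the new leading axis of yb has stride 0,
# so yb reuses y's memory instead of copying it.
xb, yb = np.broadcast_arrays(x, y)
print(yb.shape, yb.strides)
```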

  • 2021-01-21 17:04

    The shape-based check can be generalized to arbitrarily many shapes as follows:

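    def is_broadcast_compatible(*shapes):
        # Align the shapes on their trailing dimensions. zip stops at the
        # shortest shape, which matches treating missing leading dimensions as 1.
        for dim in zip(*[shape[::-1] for shape in shapes]):
            # Compatible only if at most one distinct size other than 1 appears.
            if len(set(dim).union({1})) > 2:
                return False
        return True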
    

    The corresponding test case is as follows:

    import unittest
    
    
    class TestBroadcastCompatibility(unittest.TestCase):
        def check_true(self, *shapes):
            self.assertTrue(is_broadcast_compatible(*shapes), msg=shapes)
    
        def check_false(self, *shapes):
            self.assertFalse(is_broadcast_compatible(*shapes), msg=shapes)
    
        def test(self):
            self.check_true((1, 2, 3), (1, 2, 3))
            self.check_true((3, 1, 3), (3, 3, 3))
            self.check_true((1,), (2,), (2,))
    
            self.check_false((1, 2, 3), (1, 2, 2))
            self.check_false((1, 2, 3), (1, 2, 3, 4))
            self.check_false((1,), (2,), (3,))
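
    As an extra sanity check beyond the hand-written cases, the function can be compared against NumPy's own answer. This sketch assumes NumPy 1.20 or later, where np.broadcast_shapes is available; numpy_says_compatible is a hypothetical name, and the function under test is restated so the snippet runs on its own.

```python
import itertools

import numpy as np


def is_broadcast_compatible(*shapes):
    # Same rule as the answer: per trailing-aligned dimension, at most one
    # distinct size other than 1 is allowed.
    for dim in zip(*[shape[::-1] for shape in shapes]):
        if len(set(dim).union({1})) > 2:
            return False
    return True


def numpy_says_compatible(*shapes):
    # Reference result from NumPy itself (assumes NumPy >= 1.20).
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False


# Every shape with 0 to 2 dimensions and sizes drawn from 1..3.
candidates = [s for n in range(3) for s in itertools.product((1, 2, 3), repeat=n)]
mismatches = [
    (s1, s2)
    for s1, s2 in itertools.product(candidates, repeat=2)
    if is_broadcast_compatible(s1, s2) != numpy_says_compatible(s1, s2)
]
print(len(mismatches))
```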
    
  • 2021-01-21 17:10

    I really think you guys are overthinking this. Why not just keep it simple?

    def is_broadcastable(shp1, shp2):
        # Compare dimensions from the trailing end; zip stops at the shorter shape.
        for a, b in zip(shp1[::-1], shp2[::-1]):
            if a != 1 and b != 1 and a != b:
                return False
        return True
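
    If the resulting shape is needed as well as the yes/no answer, the same loop can be extended. This is only a sketch: broadcast_result_shape is a hypothetical name, and itertools.zip_longest is used here (instead of zip) so that the extra leading dimensions of the longer shape are kept.

```python
from itertools import zip_longest


def broadcast_result_shape(shp1, shp2):
    """Hypothetical helper: the broadcast shape, or None if incompatible."""
    result = []
    # Walk both shapes from the trailing dimension; pad the shorter one with 1.
    for a, b in zip_longest(reversed(shp1), reversed(shp2), fillvalue=1):
        if a == 1:
            result.append(b)
        elif b == 1 or a == b:
            result.append(a)
        else:
            return None
    return tuple(reversed(result))


print(broadcast_result_shape((1000, 1000, 1000), (1000, 1, 1000)))
print(broadcast_result_shape((2, 2, 2), (2, 3)))
print(broadcast_result_shape((3, 1), (4,)))
```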
    
  • 2021-01-21 17:15

    When you want to check any number of array-like objects (as opposed to passing shapes), you can leverage np.nditer, which broadcasts its operands against each other when it is constructed.

    import numpy as np

    def is_broadcastable(*arrays):
        try:
            # Constructing the iterator raises ValueError on incompatible shapes.
            np.nditer(arrays)
            return True
        except ValueError:
            return False
    

    Be aware that this only works for np.ndarray or classes that define __array__ (which will be called).
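
    One caveat worth hedging: as far as I know, np.nditer also raises ValueError for zero-sized operands unless the zerosize_ok flag is passed, so empty arrays could be reported as incompatible. The variant below is a sketch under that assumption and is not part of the original answer.

```python
import numpy as np


def is_broadcastable(*arrays):
    try:
        # zerosize_ok is assumed to be needed so that empty operands are not
        # rejected for a reason unrelated to broadcasting.
        np.nditer(arrays, flags=["zerosize_ok"])
        return True
    except ValueError:
        return False


print(is_broadcastable(np.ones((2, 3)), np.ones(3)))
print(is_broadcastable(np.ones((2, 3)), np.ones(2)))
print(is_broadcastable(np.empty((0, 3)), np.ones(3)))
```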
