Test if an array is broadcastable to a shape?

灰色年华 2021-01-21 16:39

What is the best way to test whether an array can be broadcast to a given shape?

The "pythonic" approach of trying doesn't work for my case, because the

5 Answers
  •  佛祖请我去吃肉
    2021-01-21 17:04

    To generalize the check to arbitrarily many shapes, compare the shapes axis by axis, starting from the trailing end:

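    def is_broadcast_compatible(*shapes):
        # Zero or one shape is trivially compatible.
        if len(shapes) < 2:
            return True
        # Reverse each shape so zip pairs up trailing axes first. zip stops at
        # the shortest shape, so missing leading axes are never compared.
        for dim in zip(*[shape[::-1] for shape in shapes]):
            # An axis is compatible if it holds at most one size other than 1.
            if len(set(dim).union({1})) > 2:
                return False
        return True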
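
The same axis-by-axis rule can also report what the broadcast shape would be. The following is a minimal sketch, not part of the original answer: the name `broadcast_result_shape` is hypothetical, and it assumes missing leading axes behave like size 1, as in NumPy's broadcasting rules.

```python
from itertools import zip_longest


def broadcast_result_shape(*shapes):
    """Return the broadcast shape of `shapes`, or None if they are incompatible.

    Hypothetical helper for illustration; assumes NumPy-style trailing-axis
    alignment where a missing leading axis counts as size 1.
    """
    result = []
    # Walk the axes from the trailing end, padding shorter shapes with 1.
    for dims in zip_longest(*[shape[::-1] for shape in shapes], fillvalue=1):
        sizes = set(dims) - {1}
        if len(sizes) > 1:
            return None  # two different non-1 sizes on the same axis
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


print(broadcast_result_shape((3, 1, 3), (3, 3, 3)))  # (3, 3, 3)
print(broadcast_result_shape((1, 2, 3), (1, 2, 2)))  # None
```

Returning `None` instead of raising keeps the check usable in a plain `if` statement, which fits cases where wrapping the operation in `try`/`except` is not an option.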
    

    The corresponding test case for is_broadcast_compatible is as follows:

    import unittest
    
    
    class TestBroadcastCompatibility(unittest.TestCase):
        def check_true(self, *shapes):
            self.assertTrue(is_broadcast_compatible(*shapes), msg=shapes)
    
        def check_false(self, *shapes):
            self.assertFalse(is_broadcast_compatible(*shapes), msg=shapes)
    
        def test(self):
            self.check_true((1, 2, 3), (1, 2, 3))
            self.check_true((3, 1, 3), (3, 3, 3))
            self.check_true((1,), (2,), (2,))
    
            self.check_false((1, 2, 3), (1, 2, 2))
            self.check_false((1, 2, 3), (1, 2, 3, 4))
            self.check_false((1,), (2,), (3,))
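
The question asks about broadcasting one array to a given target shape, which is a one-directional check: only the source may be stretched, never the target. A minimal sketch of that variant follows. The name `can_broadcast_to` is hypothetical, and the rule it encodes (each source axis is 1 or equal to the target axis, and the source has no more axes than the target) is an assumption meant to mirror how NumPy's `np.broadcast_to` behaves.

```python
def can_broadcast_to(shape, target):
    """Return True if an array of `shape` could be broadcast to `target`.

    Hypothetical helper for illustration; works on plain tuples, so no array
    has to be allocated and no exception has to be caught.
    """
    if len(shape) > len(target):
        return False  # broadcasting can add leading axes but never drop them
    # Compare trailing axes: each source size must be 1 or match the target.
    return all(s == 1 or s == t for s, t in zip(shape[::-1], target[::-1]))


print(can_broadcast_to((1, 3), (2, 3)))  # True
print(can_broadcast_to((2, 3), (1, 3)))  # False
```

Unlike `is_broadcast_compatible`, this check is not symmetric: `(1, 3)` can be broadcast to `(2, 3)`, but `(2, 3)` cannot be broadcast to `(1, 3)`.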
    
