I'm doing the online Computer Vision course by UMich and am new to PyTorch. One of the assignment questions is on batch matrix multiplication, where we have to find the batch m