I wrote a zero_mask2 function (copied below) in the middle of my PyTorch file. However, it is too slow, so I'm looking for a faster way.
zero_mask2
First, let me e