I am implementing a triplet network in PyTorch where the three instances (sub-networks) share the same weights. Since the weights are shared, I implemented it as a singl
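
For illustration, one common way to get shared weights in PyTorch is to keep one module instance and call it on each of the three inputs. The following is only a minimal sketch under stated assumptions, not the code from this post: the names `EmbeddingNet` and `TripletNet`, the layer sizes, the batch shape, and the use of `nn.TripletMarginLoss` are all hypothetical choices made for the example.

```python
import torch
import torch.nn as nn


class EmbeddingNet(nn.Module):
    """Hypothetical embedding network; the layer sizes are placeholders."""

    def __init__(self, in_dim: int = 16, emb_dim: int = 8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 32),
            nn.ReLU(),
            nn.Linear(32, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class TripletNet(nn.Module):
    """Applies one shared embedding network to anchor, positive, and negative."""

    def __init__(self, embedding_net: nn.Module):
        super().__init__()
        # A single module instance: all three branches use the same parameters.
        self.embedding_net = embedding_net

    def forward(self, anchor, positive, negative):
        return (
            self.embedding_net(anchor),
            self.embedding_net(positive),
            self.embedding_net(negative),
        )


torch.manual_seed(0)
model = TripletNet(EmbeddingNet())

# Assumed toy inputs: a batch of 4 samples with 16 features per branch.
anchor, positive, negative = (torch.randn(4, 16) for _ in range(3))

emb_a, emb_p, emb_n = model(anchor, positive, negative)
loss = nn.TripletMarginLoss(margin=1.0)(emb_a, emb_p, emb_n)

# Gradients from all three forward passes accumulate in the same parameters.
loss.backward()
```

Because there is only one set of parameters, no explicit weight tying or copying between branches is needed; autograd sums the contributions of the three forward passes during `backward()`.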