I have a problem reshaping my array when it is passed to the Prototypical Network. The code can be found below:
class PrototypicalNetwork(LossFunction