I wrote the following PyTorch module:
class Partition(nn.Module):
    def __init__(self, decider, left_child, right_child):
        super().__init__()