I have the following subclassed model:
```python
class MyModel(tf.keras.Model):
    def __init__(self, dropout_ratio=0.25, activat
```