Can Numba be used with TensorFlow?

Submitted by 笑着哭i on 2020-06-10 02:53:26

Question


Can Numba be used to compile Python code that interfaces with TensorFlow? That is, computations outside of the TensorFlow universe would run through Numba for speed. I have not found any resources on how to do this.


Answer 1:


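You can use tf.numpy_function (or the older tf.py_func, which TensorFlow 2 only exposes as tf.compat.v1.py_func) to wrap a Python function and use it as a TensorFlow op. Keep in mind that TensorFlow cannot compute gradients through tf.numpy_function, so the wrapped computation contributes nothing to backpropagation. Here is an example that I used: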

import numpy as np
import tensorflow as tf
from numba import jit

@jit
def dice_coeff_nb(y_true, y_pred):
    "Calculates dice coefficient"
    smooth = np.float32(1)
    # ravel() flattens the arrays so the sums run over every element
    y_true_f = y_true.ravel()
    y_pred_f = y_pred.ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (np.sum(y_true_f) +
                                            np.sum(y_pred_f) + smooth)
    return score

@jit
def dice_loss_nb(y_true, y_pred):
    "Calculates dice loss"
    loss = 1 - dice_coeff_nb(y_true, y_pred)
    return loss

def bce_dice_loss_nb(y_true, y_pred):
    "Adds dice_loss to categorical_crossentropy"
    # Run the Numba-compiled function as a TensorFlow op
    dice = tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64)
    # Cast to the prediction dtype so the two terms can be added
    loss = tf.cast(dice, y_pred.dtype) + \
           tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return loss

Then I used this loss function in training a tf.keras model:

...
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss=bce_dice_loss_nb)
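
To sanity-check the Dice formula independently of Numba and TensorFlow, it can help to evaluate it by hand on a tiny input. The following is only a sketch in plain NumPy: the name dice_coeff_np and the two 2x2 masks are made up for illustration and are not part of the answer above.

```python
import numpy as np

def dice_coeff_np(y_true, y_pred, smooth=1.0):
    """Plain-NumPy Dice coefficient using the same formula as dice_coeff_nb."""
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return float((2.0 * intersection + smooth) /
                 (np.sum(y_true_f) + np.sum(y_pred_f) + smooth))

# Hypothetical masks: they overlap in exactly one position.
y_true = np.array([[1, 1], [0, 0]])
y_pred = np.array([[1, 0], [1, 0]])

partial = dice_coeff_np(y_true, y_pred)  # (2*1 + 1) / (2 + 2 + 1)
perfect = dice_coeff_np(y_true, y_true)  # (2*2 + 1) / (2 + 2 + 1)
print(partial, perfect)
```

The Dice loss is then 1 minus these scores, so identical masks give a loss of 0.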



Answer 2:


I know that this does not directly answer your question, but it might be a good alternative. Numba works by just-in-time (JIT) compilation, and TensorFlow has its own JIT compiler (XLA), which is separate from the Numba ecosystem. You can follow the instructions in the official TensorFlow documentation on how to enable JIT compilation in TensorFlow.
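
As a rough sketch of that alternative, assuming TensorFlow 2.5 or later (where tf.function accepts jit_compile=True), the same Dice formula can be written with TensorFlow ops and compiled by XLA. The function name dice_coeff_tf and the sample tensors are hypothetical, chosen only for illustration.

```python
import tensorflow as tf

# jit_compile=True asks TensorFlow to compile this function with XLA
# (assumes TensorFlow 2.5+).
@tf.function(jit_compile=True)
def dice_coeff_tf(y_true, y_pred):
    smooth = 1.0
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

# Hypothetical inputs: two 2x2 masks that overlap in one position.
y_true = tf.constant([[1.0, 1.0], [0.0, 0.0]])
y_pred = tf.constant([[1.0, 0.0], [1.0, 0.0]])
score = dice_coeff_tf(y_true, y_pred)
print(float(score))
```

Because everything here is built from TensorFlow ops, gradients can flow through the function, which is not the case for code wrapped with tf.numpy_function.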



Source: https://stackoverflow.com/questions/41132633/can-numba-be-used-with-tensorflow
