Keras (TensorFlow backend) conditional assignment with K.switch()

Submitted by 断了今生、忘了曾经 on 2021-01-01 13:34:30

Question


I'm trying to implement something like

if np.max(subgrid) == np.min(subgrid):
    middle_middle = cur_subgrid + 1
else:
    middle_middle = cur_subgrid

Since the condition can only be determined at run time, I'm using the Keras backend as follows:

middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

But I'm getting this error:

<ipython-input-112-0504ce070e71> in col_loop(j, gray_map, mask_A)
     56 
     57 
---> 58             middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
     59 
     60             print ('ml',middle_left.shape)

/nfs/isicvlnas01/share/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in switch(condition, then_expression, else_expression)
   2561         The selected tensor.
   2562     """
-> 2563     if condition.dtype != tf.bool:
   2564         condition = tf.cast(condition, 'bool')
   2565     if not callable(then_expression):

AttributeError: 'bool' object has no attribute 'dtype'

middle_middle, cur_subgrid, and subgrid are all NxN tensors. Any help is appreciated.


Answer 1:


I think the problem is that K.max(subgrid) == K.min(subgrid) does not build a TensorFlow boolean tensor holding the result of comparing the two values. Here, == compares the two tensor objects themselves and returns a plain Python bool (False, since K.max(subgrid) and K.min(subgrid) are two different objects), which is why K.switch fails as soon as it reads condition.dtype.

In other words, what you have written will be evaluated as

K.switch(False, lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

instead of

comparison = ...  # Some tensor that at run time will contain True if min and max are the same, False otherwise.
K.switch(comparison, lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

So what you need to do is to use keras.backend.equal() instead of ==:

K.switch(K.equal(K.max(subgrid),K.min(subgrid)), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
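
To see the mechanism in isolation, here is a minimal pure-Python sketch that reproduces the same AttributeError without TensorFlow. ToyTensor, toy_equal and toy_switch are hypothetical stand-ins invented for illustration, not Keras or TensorFlow APIs. The sketch assumes that == on two graph tensors falls back to comparing object identity and yields a Python bool, which is consistent with the 'bool' object named in the traceback.

```python
class ToyTensor:
    """Hypothetical symbolic tensor: its value is unknown until run time."""

    def __init__(self, name, dtype="float32"):
        self.name = name
        self.dtype = dtype

    # Assumption: == compares object identity, not the values the
    # tensors will hold at run time, so it returns a plain Python bool.
    def __eq__(self, other):
        return self is other

    __hash__ = object.__hash__


def toy_equal(a, b):
    """Hypothetical analogue of K.equal: returns a new symbolic bool tensor."""
    return ToyTensor("equal(%s, %s)" % (a.name, b.name), dtype="bool")


def toy_switch(condition, then_expression, else_expression):
    """Hypothetical analogue of K.switch; it reads condition.dtype first,
    like the line marked in the traceback."""
    if condition.dtype != "bool":
        raise TypeError("condition must be a boolean tensor")
    return "switch(%s)" % condition.name


max_t = ToyTensor("max")
min_t = ToyTensor("min")

python_bool = (max_t == min_t)      # plain Python bool, decided immediately
symbolic = toy_equal(max_t, min_t)  # symbolic condition, decided at run time

try:
    toy_switch(python_bool, None, None)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
print(toy_switch(symbolic, None, None))
```

The first call fails because a Python bool has no dtype attribute; the second succeeds because toy_equal hands the switch an object that still represents a deferred comparison, which is the role K.equal plays in the fix above.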


Source: https://stackoverflow.com/questions/52854179/keras-tensorflow-backend-conditional-assignment-with-k-switch
