TensorFlow median value

Submitted by 一个人想着一个人 on 2019-12-05 07:05:01

edit: This answer is outdated, use Lucas Venezian Povoa's solution instead. It is simpler and faster.

You can calculate the median inside TensorFlow using:

def get_median(v):
    v = tf.reshape(v, [-1])
    mid = v.get_shape()[0]//2 + 1
    return tf.nn.top_k(v, mid).values[-1]

If v is already a vector you can skip the reshaping.

If you care about the median value being the mean of the two middle elements for vectors of even size, you should use this instead:

def get_real_median(v):
    v = tf.reshape(v, [-1])
    l = v.get_shape()[0]
    mid = l//2 + 1
    val = tf.nn.top_k(v, mid).values
    if l % 2 == 1:
        return val[-1]
    else:
        return 0.5 * (val[-1] + val[-2])
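
To see why mid = l//2 + 1 is the right cut-off, here is a minimal sketch of the same selection logic in plain Python, without TensorFlow. The name median_via_top_k is hypothetical, and heapq.nlargest only stands in for tf.nn.top_k; it is an illustration of the indexing, not a replacement for the graph code above.

```python
import heapq
import statistics


def median_via_top_k(values):
    """Median computed with the same indexing as get_real_median."""
    n = len(values)
    mid = n // 2 + 1
    # heapq.nlargest returns the `mid` largest items in descending order,
    # which is the role tf.nn.top_k(v, mid).values plays above.
    top = heapq.nlargest(mid, values)
    if n % 2 == 1:
        # Odd size: the last of the top `mid` values is the middle element.
        return top[-1]
    # Even size: average the two middle elements.
    return 0.5 * (top[-1] + top[-2])


odd = [5.0, 1.0, 4.0, 2.0, 3.0]
even = [4.0, 1.0, 3.0, 2.0]
print(median_via_top_k(odd), statistics.median(odd))
print(median_via_top_k(even), statistics.median(even))
```

For both inputs the two printed numbers should agree, since statistics.median also averages the two middle elements for even sizes.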

To calculate the median of an array with TensorFlow you can use the percentile function, since the 50th percentile is the median.

import tensorflow as tf
import numpy as np 

np.random.seed(0)   
x = np.random.normal(3.0, .1, 100)

median = tf.contrib.distributions.percentile(x, 50.0)

tf.Session().run(median)

This code does not behave the same as np.median, because the interpolation parameter rounds the result to the lower, higher, or nearest sample value.

If you want the same behavior, you could use:

median = tf.contrib.distributions.percentile(x, 50.0, interpolation='lower')
median += tf.contrib.distributions.percentile(x, 50.0, interpolation='higher')
median /= 2.
tf.Session().run(median)

The code above is also equivalent to np.percentile(x, 50, interpolation='midpoint').
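
The effect of the interpolation modes can be sketched in plain Python. This is only an illustration under the assumption that the 50th percentile sits at the fractional index 0.5 * (n - 1) of the sorted data, which is NumPy's convention; percentile_50 is a hypothetical helper, not part of TensorFlow or NumPy.

```python
import math
import statistics


def percentile_50(values, interpolation):
    """50th percentile with 'lower', 'higher' or 'midpoint' interpolation."""
    s = sorted(values)
    pos = 0.5 * (len(s) - 1)      # fractional index of the 50th percentile
    lower = s[math.floor(pos)]    # sample at or just below that index
    higher = s[math.ceil(pos)]    # sample at or just above that index
    if interpolation == "lower":
        return lower
    if interpolation == "higher":
        return higher
    if interpolation == "midpoint":
        return 0.5 * (lower + higher)
    raise ValueError("unsupported interpolation: %r" % interpolation)


data = [4.0, 1.0, 3.0, 2.0]
print(percentile_50(data, "lower"))     # lower of the two middle samples
print(percentile_50(data, "higher"))    # higher of the two middle samples
print(percentile_50(data, "midpoint"))  # their average
print(statistics.median(data))
```

For an even number of samples, 'lower' and 'higher' each return one of the two middle samples, and only their average matches the usual median. For an odd number of samples all three modes return the same middle element.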

We can modify BlueSun's solution to be much faster on GPUs:

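def get_median(v):
    v = tf.reshape(v, [-1])
    # Keep the n//2 + 1 largest elements; the smallest of them is the median
    # (for even sizes, the lower of the two middle elements).
    m = v.get_shape()[0]//2 + 1
    return tf.reduce_min(tf.nn.top_k(v, m, sorted=False).values)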

This is as fast as (in my experience) using tf.contrib.distributions.percentile(v, 50.0), and returns one of the actual elements.
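
This variant relies on the fact that the minimum of the k largest values is the k-th largest value, so the top-k output does not need to be sorted. A minimal plain-Python sketch of that identity follows, with k = n//2 + 1 as in the first answer; kth_largest is a hypothetical helper and heapq.nlargest only stands in for tf.nn.top_k.

```python
import heapq


def kth_largest(values, k):
    """k-th largest value, taken as the minimum of the k largest values."""
    # The order of the k selected values does not matter to min().
    return min(heapq.nlargest(k, values))


data = [7.0, 1.0, 5.0, 3.0, 9.0]
k = len(data) // 2 + 1  # 3 for five elements
print(kth_largest(data, k))
```

With five elements, the three largest are 9.0, 7.0 and 5.0, and their minimum is the middle element of the sorted data.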

There is currently no median function in TF. The only way to use a NumPy operation in TF is to apply it after you run your graph:

import tensorflow as tf
import numpy as np

a = tf.random_uniform(shape=(5, 5))

with tf.Session() as sess:
    np_matrix = sess.run(a)
    print(np.median(np_matrix))