How would I implement k-means with TensorFlow?

深忆病人 2021-01-31 05:12

The intro tutorial, which uses the built-in gradient descent optimizer, makes a lot of sense. However, k-means isn't just something I can plug into gradient descent. It seems l…

3 Answers
  •  粉色の甜心
    2021-01-31 05:54

    Most of the answers I have seen so far focus just on the 2D version (when you need to cluster points in 2 dimensions). Here is my implementation of clustering in an arbitrary number of dimensions.


    Basic idea of the k-means algorithm in n dimensions (a plain NumPy sketch of this loop follows the list):

    • generate k random starting points (centroids)
    • repeat until you exceed the patience or the cluster assignments stop changing:
      • assign each point to the closest centroid
      • recalculate the location of each centroid as the average of its cluster
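
    Before the TensorFlow version, here is a minimal plain-NumPy sketch of exactly that loop. The kmeans_numpy name and its arguments are illustrative only and are not reused by the TensorFlow code below:

    import numpy as np

    def kmeans_numpy(data, k, patience=300, seed=0):
        rng = np.random.RandomState(seed)
        # step 1: pick k distinct random data points as the starting centroids
        centroids = data[rng.choice(data.shape[0], size=k, replace=False)].astype(np.float64)
        assignments = np.full(data.shape[0], -1, dtype=np.int64)
        for _ in range(patience):
            # step 2a: squared distances via ||x||^2 + ||c||^2 - 2 x.c, shape (n, k)
            dists = (
                (data ** 2).sum(axis=1, keepdims=True)
                + (centroids ** 2).sum(axis=1)
                - 2.0 * (data @ centroids.T)
            )
            new_assignments = dists.argmin(axis=1)
            if np.array_equal(new_assignments, assignments):
                break  # assignments stopped changing
            assignments = new_assignments
            # step 2b: move each centroid to the mean of its cluster (skip empty clusters)
            for c in range(k):
                members = data[assignments == c]
                if len(members):
                    centroids[c] = members.mean(axis=0)
        return centroids, assignments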

    To be able to somehow validate the results I will attempt to cluster MNIST images.

    import numpy as np
    import tensorflow as tf
    from random import randint
    from collections import Counter
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets("MNIST_data/")
    X, y, k = mnist.test.images, mnist.test.labels, 10
    

    So here X is my data to cluster, of shape (10000, 784), y holds the true digit labels, and k is the number of clusters (which is the same as the number of digits). Now the actual algorithm:

    # select random rows of X as the starting position. You could do better by sampling k distinct points (np.random.randint may repeat indices).
    start_pos = tf.Variable(X[np.random.randint(X.shape[0], size=k), :], dtype=tf.float32)
    centroids = tf.Variable(start_pos.initialized_value(), name='S', dtype=tf.float32)
    
    # populate points
    points           = tf.Variable(X, name='X', dtype=tf.float32)
    ones_like        = tf.ones((points.get_shape()[0], 1))
    prev_assignments = tf.Variable(tf.zeros((points.get_shape()[0], ), dtype=tf.int64))
    
    # find the distance between all points: http://stackoverflow.com/a/43839605/1090562
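    # this uses the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, so the whole
    # (num_points, k) distance matrix is built from a couple of matmuls, no Python loop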
    p1 = tf.matmul(
        tf.expand_dims(tf.reduce_sum(tf.square(points), 1), 1),
        tf.ones(shape=(1, k))
    )
    p2 = tf.transpose(tf.matmul(
        tf.reshape(tf.reduce_sum(tf.square(centroids), 1), shape=[-1, 1]),
        ones_like,
        transpose_b=True
    ))
    distance = tf.sqrt(tf.add(p1, p2) - 2 * tf.matmul(points, centroids, transpose_b=True))
    
    # assign each point to a closest centroid
    point_to_centroid_assignment = tf.argmin(distance, axis=1)
    
    # recalculate the centers
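    # unsorted_segment_sum adds up the rows of `points` (and of the ones vector) grouped by
    # cluster id, so total[c] is the coordinate sum and count[c] the size of cluster c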
    total = tf.unsorted_segment_sum(points, point_to_centroid_assignment, k)
    count = tf.unsorted_segment_sum(ones_like, point_to_centroid_assignment, k)
    means = total / count
    
    # continue if there is any difference between the current and previous assignment
    is_continue = tf.reduce_any(tf.not_equal(point_to_centroid_assignment, prev_assignments))
    
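    # the control dependency forces is_continue (which compares against the old
    # prev_assignments) to run before the assign ops below overwrite prev_assignments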
    with tf.control_dependencies([is_continue]):
        loop = tf.group(centroids.assign(means), prev_assignments.assign(point_to_centroid_assignment))
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    # do many iterations. Hopefully you will stop early because has_changed becomes False
    has_changed, cnt = True, 0
    while has_changed and cnt < 300:
        cnt += 1
        has_changed, _ = sess.run([is_continue, loop])
    
    # see how the data is assigned
    res = sess.run(point_to_centroid_assignment)
    

    Now it is time to check how good our clusters are. To do this we group together the real digit labels that appear in each cluster and look at the most popular labels in that cluster. With perfect clustering there would be just one value in each group; with random clustering every value would be roughly equally represented in each group.

    nums_in_clusters = [[] for i in range(10)]
    for cluster, real_num in zip(list(res), list(y)):
        nums_in_clusters[cluster].append(real_num)
    
    for i in range(10):
        print(Counter(nums_in_clusters[i]).most_common(3))
    

    This gives me something like this:

    [(0, 738), (6, 18), (2, 11)]
    [(1, 641), (3, 53), (2, 51)]
    [(1, 488), (2, 115), (7, 56)]
    [(4, 550), (9, 533), (7, 280)]
    [(7, 634), (9, 400), (4, 302)]
    [(6, 649), (4, 27), (0, 14)]
    [(5, 269), (6, 244), (0, 161)]
    [(8, 646), (5, 164), (3, 125)]
    [(2, 698), (3, 34), (7, 14)]
    [(3, 712), (5, 290), (8, 110)]
    

    This is pretty good, because the majority of each cluster's counts comes from a single digit. You can see that the clustering confuses, for example, 7 and 9, and 4 and 9, but 0 is clustered pretty nicely.

    A few ways to improve this:

    • run the algorithm a few times and select the best result (based on the distance of points to their clusters)
    • handle the case where nothing is assigned to a cluster: as written, count is 0 for that cluster and means will contain NaN (see the sketch below)
    • better initialization of the random starting points
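
    For the empty-cluster case specifically, one possible fix (a sketch, not part of the original answer) is to keep the previous centroid whenever a cluster receives no points, so means never turns into NaN. It would replace the means = total / count line in the graph above:

    # guard against empty clusters: count can be 0, which would give NaN means
    safe_count = tf.maximum(count, 1.0)                               # never divide by zero
    has_points = tf.squeeze(count, axis=1) > 0                        # shape (k,), True for non-empty clusters
    means      = tf.where(has_points, total / safe_count, centroids)  # empty clusters keep their old centroid

    With this guard, centroids.assign(means) stays finite even if a cluster temporarily loses all of its points.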
