How would I implement k-means with TensorFlow?

深忆病人 2021-01-31 05:12

The intro tutorial, which uses the built-in gradient descent optimizer, makes a lot of sense. However, k-means isn't just something I can plug into gradient descent. It seems l

3 Answers
  • 2021-01-31 05:54

    Most of the answers I have seen so far focus only on the 2D version (clustering points in two dimensions). Here is my implementation of clustering in arbitrary dimensions.


    Basic idea of the k-means algorithm in n dimensions:

    • generate k random starting points (the initial centroids)
    • repeat until you exceed the patience (an iteration limit) or the cluster assignment stops changing:
      • assign each point to the closest centroid
      • move each centroid to the average of the points in its cluster
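
The steps above can be sketched in plain NumPy before moving to TensorFlow. This is a minimal illustration, not the TensorFlow implementation that follows: the function name `kmeans_numpy`, the `max_iters` patience value, and the seeded random generator are assumptions made for the example.

```python
import numpy as np

def kmeans_numpy(X, k, max_iters=300, seed=0):
    """Plain k-means on an (n, d) array; returns (centroids, assignments)."""
    rng = np.random.default_rng(seed)
    # Pick k distinct data points as the starting centroids.
    centroids = X[rng.choice(X.shape[0], size=k, replace=False)].astype(float)
    assignments = np.full(X.shape[0], -1)
    for _ in range(max_iters):
        # Squared distance from every point to every centroid, shape (n, k).
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_assignments = dists.argmin(axis=1)
        if np.array_equal(new_assignments, assignments):
            break  # assignments are stable, so stop early
        assignments = new_assignments
        for j in range(k):
            members = X[assignments == j]
            if len(members) > 0:  # an empty cluster keeps its old centroid
                centroids[j] = members.mean(axis=0)
    return centroids, assignments

# Two well-separated blobs in 3 dimensions (synthetic data for the example).
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0.0, 0.1, (50, 3)), rng.normal(5.0, 0.1, (50, 3))])
centroids, assignments = kmeans_numpy(X, k=2)
```

Nothing in the function depends on the number of columns in `X`, which is the point of the n-dimensional formulation.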

    To have some way of validating the results, I will try clustering MNIST images.

    import numpy as np
    import tensorflow as tf
    from random import randint
    from collections import Counter
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets("MNIST_data/")
    X, y, k = mnist.test.images, mnist.test.labels, 10
    

    So here X is my data to cluster, with shape (10000, 784), y holds the true digit labels, and k is the number of clusters (which is the same as the number of digits). Now the actual algorithm:

    # Select k random data points as the starting centroids. Sampling here is with
    # replacement, so duplicates are possible; sampling k distinct points is better.
    start_pos = tf.Variable(X[np.random.randint(X.shape[0], size=k),:], dtype=tf.float32)
    centroids = tf.Variable(start_pos.initialized_value(), name='S', dtype=tf.float32)
    
    # populate points
    points           = tf.Variable(X, name='X', dtype=tf.float32)
    ones_like        = tf.ones((points.get_shape()[0], 1))
    prev_assignments = tf.Variable(tf.zeros((points.get_shape()[0], ), dtype=tf.int64))
    
    # find the distance between every point and every centroid: http://stackoverflow.com/a/43839605/1090562
    p1 = tf.matmul(
        tf.expand_dims(tf.reduce_sum(tf.square(points), 1), 1),
        tf.ones(shape=(1, k))
    )
    p2 = tf.transpose(tf.matmul(
        tf.reshape(tf.reduce_sum(tf.square(centroids), 1), shape=[-1, 1]),
        ones_like,
        transpose_b=True
    ))
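    # clamp at zero first: rounding can make the expanded form slightly negative, and sqrt would then give NaN
    distance = tf.sqrt(tf.maximum(tf.add(p1, p2) - 2 * tf.matmul(points, centroids, transpose_b=True), 0.0))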
    
    # assign each point to a closest centroid
    point_to_centroid_assignment = tf.argmin(distance, axis=1)
    
    # recalculate the centers
    total = tf.unsorted_segment_sum(points, point_to_centroid_assignment, k)
    count = tf.unsorted_segment_sum(ones_like, point_to_centroid_assignment, k)
    means = total / count
    
    # continue if there is any difference between the current and previous assignment
    is_continue = tf.reduce_any(tf.not_equal(point_to_centroid_assignment, prev_assignments))
    
    with tf.control_dependencies([is_continue]):
        loop = tf.group(centroids.assign(means), prev_assignments.assign(point_to_centroid_assignment))
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    # iterate until the assignments stop changing (has_changed is False) or 300 iterations are reached
    has_changed, cnt = True, 0
    while has_changed and cnt < 300:
        cnt += 1
        has_changed, _ = sess.run([is_continue, loop])
    
    # see how the data is assigned
    res = sess.run(point_to_centroid_assignment)
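
The `p1`, `p2` and `matmul` terms in the graph above come from expanding the squared Euclidean distance, ||p - c||^2 = ||p||^2 + ||c||^2 - 2 p·c, which avoids building an (n, k, d) tensor. The NumPy sketch below checks that identity on random data; the array sizes and values are made up for the illustration. Clamping at zero before the square root is a precaution, since rounding can make the expanded form slightly negative.

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(6, 4))      # n=6 points in d=4 dimensions
centroids = rng.normal(size=(3, 4))   # k=3 centroids

# Expanded form: ||p||^2 + ||c||^2 - 2 p.c, computed for all pairs at once.
p_sq = (points ** 2).sum(axis=1)[:, None]        # shape (n, 1)
c_sq = (centroids ** 2).sum(axis=1)[None, :]     # shape (1, k)
sq_dist = p_sq + c_sq - 2.0 * (points @ centroids.T)
distance = np.sqrt(np.maximum(sq_dist, 0.0))     # clamp guards against tiny negatives

# Direct form for comparison: subtract every centroid from every point.
direct = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
```

Since only the argmin is needed for the assignment step, the square root could be skipped altogether; it does not change which centroid is closest.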
    

    Now it is time to check how good our clusters are. To do this, we group the true digits of all points that landed in each cluster and look at the most common digits in that cluster. With perfect clustering, each group would contain just one digit. With random clustering, every digit would be represented about equally in each group.

    nums_in_clusters = [[] for i in range(10)]
    for cluster, real_num in zip(list(res), list(y)):
        nums_in_clusters[cluster].append(real_num)
    
    # print the three most common true digits in each cluster
    for i in range(10):
        print(Counter(nums_in_clusters[i]).most_common(3))
    

    This gives me something like this:

    [(0, 738), (6, 18), (2, 11)]
    [(1, 641), (3, 53), (2, 51)]
    [(1, 488), (2, 115), (7, 56)]
    [(4, 550), (9, 533), (7, 280)]
    [(7, 634), (9, 400), (4, 302)]
    [(6, 649), (4, 27), (0, 14)]
    [(5, 269), (6, 244), (0, 161)]
    [(8, 646), (5, 164), (3, 125)]
    [(2, 698), (3, 34), (7, 14)]
    [(3, 712), (5, 290), (8, 110)]
    

    This is pretty good, because in most clusters the majority of the counts fall on a single digit. The clustering confuses 4, 7 and 9 with one another, and 5 is spread over several clusters, but 0 is clustered pretty nicely.
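
One way to condense a table like this into a single number is cluster purity: label each cluster with its most common digit and count the fraction of points that match. The helper below is a sketch; `cluster_purity` is a name chosen for this example, and the toy lists stand in for `res` and `y`.

```python
from collections import Counter

def cluster_purity(assignments, labels):
    """Fraction of points whose label equals their cluster's majority label."""
    per_cluster = {}
    for cluster, label in zip(assignments, labels):
        per_cluster.setdefault(cluster, Counter())[label] += 1
    # For each cluster, count the points carrying its most common label.
    majority_total = sum(c.most_common(1)[0][1] for c in per_cluster.values())
    return majority_total / len(labels)

# Toy data: cluster 0 is pure, cluster 1 mixes two labels evenly.
toy_assignments = [0, 0, 0, 1, 1, 1, 1]
toy_labels      = [7, 7, 7, 4, 4, 9, 9]
purity = cluster_purity(toy_assignments, toy_labels)
```

A purity of 1.0 corresponds to the perfect case described earlier, where every cluster holds a single digit.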

    A few ways to improve this:

    • run the algorithm a few times and select the best run (based on the distances of the points to their cluster centers)
    • handle the case where nothing is assigned to a cluster. In my code you will get NaN in the means variable because count is 0.
    • random points initialization.
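
The first two items can be sketched in NumPy. Both helpers are illustrative and their names are assumptions: `safe_means` keeps the previous centroid when a cluster is empty instead of dividing by zero, and `inertia` is the within-cluster sum of squared distances, a common score for choosing among several runs (lower is better).

```python
import numpy as np

def safe_means(points, assignments, old_centroids):
    """Per-cluster means; an empty cluster keeps its previous centroid."""
    k = old_centroids.shape[0]
    totals = np.zeros_like(old_centroids, dtype=float)
    np.add.at(totals, assignments, points)              # per-cluster sums
    counts = np.bincount(assignments, minlength=k)      # per-cluster sizes
    means = old_centroids.astype(float).copy()
    nonempty = counts > 0
    means[nonempty] = totals[nonempty] / counts[nonempty, None]
    return means

def inertia(points, assignments, centroids):
    """Sum of squared distances from each point to its assigned centroid."""
    return float(((points - centroids[assignments]) ** 2).sum())

points = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]])
assignments = np.array([0, 0, 2])            # cluster 1 is empty
old = np.array([[1.0, 1.0], [5.0, 5.0], [9.0, 9.0]])
new = safe_means(points, assignments, old)
score = inertia(points, assignments, new)
```

Keeping the old centroid is only one possible policy; re-seeding an empty cluster from a random data point is another.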
  • 2021-01-31 05:59

    (Note: you can now get a more polished version of this code as a gist on GitHub.)

    You can definitely do it, but you need to define your own stopping criteria (for k-means, usually a maximum iteration count and the point at which the assignments stabilize). Here's an example of how you might do it (there are probably more efficient ways to implement it, and definitely better ways to select the initial points). It's basically how you'd do it in NumPy if you were trying really hard to avoid iterating in Python:

    import tensorflow as tf
    import numpy as np
    import time
    
    N=10000
    K=4
    MAX_ITERS = 1000
    
    start = time.time()
    
    points = tf.Variable(tf.random_uniform([N,2]))
    cluster_assignments = tf.Variable(tf.zeros([N], dtype=tf.int64))
    
    # Silly initialization:  Use the first K points as the starting
    # centroids.  In the real world, do this better.
    centroids = tf.Variable(tf.slice(points.initialized_value(), [0,0], [K,2]))
    
    # Replicate to N copies of each centroid and K copies of each                    
    # point, then subtract and compute the sum of squared distances.                 
    rep_centroids = tf.reshape(tf.tile(centroids, [N, 1]), [N, K, 2])
    rep_points = tf.reshape(tf.tile(points, [1, K]), [N, K, 2])
    sum_squares = tf.reduce_sum(tf.square(rep_points - rep_centroids),
                                axis=2)
    
    # Use argmin to select the lowest-distance centroid for each point
    best_centroids = tf.argmin(sum_squares, 1)
    did_assignments_change = tf.reduce_any(tf.not_equal(best_centroids,
                                                        cluster_assignments))
    
    def bucket_mean(data, bucket_ids, num_buckets):
        total = tf.unsorted_segment_sum(data, bucket_ids, num_buckets)
        count = tf.unsorted_segment_sum(tf.ones_like(data), bucket_ids, num_buckets)
        return total / count
    
    means = bucket_mean(points, best_centroids, K)
    
    # Do not write to the assigned clusters variable until after
    # computing whether the assignments have changed - hence control_dependencies
    with tf.control_dependencies([did_assignments_change]):
        do_updates = tf.group(
            centroids.assign(means),
            cluster_assignments.assign(best_centroids))
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    changed = True
    iters = 0
    
    while changed and iters < MAX_ITERS:
        iters += 1
        [changed, _] = sess.run([did_assignments_change, do_updates])
    
    [centers, assignments] = sess.run([centroids, cluster_assignments])
    end = time.time()
    print("Found in %.2f seconds" % (end - start), iters, "iterations")
    print("Centroids:")
    print(centers)
    print("Cluster assignments:", assignments)
    

    (Note that a real implementation would need to be more careful about initial cluster selection, avoiding problem cases with all points going to one cluster, etc. This is just a quick demo. I've updated my answer from earlier to make it a bit more clear and "example-worthy".)
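
For comparison, the tile-and-reshape step and `bucket_mean` from the code above map onto NumPy broadcasting and `np.add.at`. This is a sketch of the same arithmetic on a tiny hand-made input, not a replacement for the TensorFlow graph; the point and centroid values are invented for the example.

```python
import numpy as np

points = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]])  # N=4
centroids = np.array([[0.0, 1.0], [5.0, 5.0]])                        # K=2

# Broadcasting (N, 1, 2) against (1, K, 2) replaces the explicit tiling.
sum_squares = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
best_centroids = sum_squares.argmin(axis=1)

# Equivalent of bucket_mean: per-bucket totals divided by per-bucket counts.
K = centroids.shape[0]
total = np.zeros((K, points.shape[1]))
count = np.zeros((K, points.shape[1]))
np.add.at(total, best_centroids, points)
np.add.at(count, best_centroids, np.ones_like(points))
means = total / count  # like the original, this divides by zero for an empty bucket
```

Here the first two points fall to the first centroid and the last two to the second, so `means` holds the midpoint of each pair.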

  • Nowadays you could directly use (or take inspiration from) the KMeansClustering Estimator. You can take a look at its implementation on GitHub.
