Shut down server in TensorFlow

小蘑菇 2021-02-02 02:10

To use distributed TensorFlow, we create a parameter server and keep it running by blocking on

tf.train.Server.join()

However, I can't find any way to sh

3 answers
  •  隐瞒了意图╮
    2021-02-02 02:40

    You can make parameter server processes exit on demand by calling session.run(dequeue_op) instead of server.join(), and having another process enqueue something onto that queue when you want the server process to exit.

    So for k parameter server shards, you could create k queues, each with a unique shared_name, and have each shard try to dequeue from its own queue. To bring down the servers, loop over all the queues and enqueue a token onto each one. This unblocks session.run, so the Python process runs to the end and exits, bringing down the server.

    Below is a self-contained example with 2 shards taken from: https://gist.github.com/yaroslavvb/82a5b5302449530ca5ff59df520c369e

    (for multi worker/multi shard example, see https://gist.github.com/yaroslavvb/ea1b1bae0a75c4aae593df7eca72d9ca)

    import subprocess
    import tensorflow as tf
    import time
    import sys
    
    flags = tf.flags
    flags.DEFINE_string("port1", "12222", "port of worker1")
    flags.DEFINE_string("port2", "12223", "port of worker2")
    flags.DEFINE_string("task", "", "internal use")
    FLAGS = flags.FLAGS
    
    # setup local cluster from flags
    host = "127.0.0.1:"
    cluster = {"worker": [host+FLAGS.port1, host+FLAGS.port2]}
    clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def()
    
    if __name__ == '__main__':
        if not FLAGS.task:  # no --task given: start servers and run client

            # launch distributed service: one subprocess per server task
            def runcmd(cmd):
                subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT)
            runcmd("python %s --task=0" % sys.argv[0])
            runcmd("python %s --task=1" % sys.argv[0])
            time.sleep(1)  # give the servers a moment to start

            # connect to the first server and build the client graph
            sess = tf.Session("grpc://" + host + FLAGS.port1)
            # shared_name must match the name each server dequeues from
            queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0")
            queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1")
            with tf.device("/job:worker/task:0"):
                add_op0 = tf.add(tf.ones(()), tf.ones(()))
            with tf.device("/job:worker/task:1"):
                add_op1 = tf.add(tf.ones(()), tf.ones(()))

            print("Running computation on server 0")
            print(sess.run(add_op0))
            print("Running computation on server 1")
            print(sess.run(add_op1))

            # bring down distributed service: one token per server queue
            print("Bringing down server 0")
            sess.run(queue0.enqueue(1))
            print("Bringing down server 1")
            sess.run(queue1.enqueue(1))

        else:  # launch TensorFlow server
            server = tf.train.Server(clusterspec, config=None,
                                     job_name="worker",
                                     task_index=int(FLAGS.task))
            print("Starting server " + FLAGS.task)
            sess = tf.Session(server.target)
            queue = tf.FIFOQueue(1, tf.int32, shared_name="queue" + FLAGS.task)
            # blocks here (in place of server.join()) until a token is enqueued
            sess.run(queue.dequeue())
            print("Terminating server " + FLAGS.task)
