When we want to use distributed TensorFlow, we create a parameter server using tf.train.Server.join(). However, I can't find any way to shut the server down other than killing its process. Is there a clean way to shut down a parameter server?

You can have parameter server processes die on demand by blocking on session.run(dequeue_op) instead of server.join(), and having another process enqueue something onto that queue when you want the process to die.
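Concretely, the parameter server script blocks on a dequeue instead of join(). Here is a minimal sketch of just that server side; the single-worker cluster spec and the queue name "queue0" are placeholders:

import tensorflow as tf

# Placeholder single-worker cluster; substitute your real cluster spec.
clusterspec = tf.train.ClusterSpec({"worker": ["127.0.0.1:12222"]})
server = tf.train.Server(clusterspec, job_name="worker", task_index=0)

sess = tf.Session(server.target)
# A queue that other processes can look up by its shared_name.
queue = tf.FIFOQueue(1, tf.int32, shared_name="queue0")
sess.run(queue.dequeue())  # blocks here instead of server.join()
# once a token arrives, run() returns, the script ends, and the process exits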
So for k parameter server shards you could create k queues, each with a unique shared_name property, and have shard i dequeue from queue i. When you want to bring the servers down, you loop over all the queues and enqueue a token onto each one. This causes each session.run to unblock; the Python process then runs to the end and quits, bringing down the server.
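The shutdown side is then just a loop over the k queues. A sketch of that driver, assuming shard i blocks on a queue with shared_name "queue<i>" as in the example below, and an illustrative session target:

import tensorflow as tf

k = 2  # number of parameter server shards (illustrative)
sess = tf.Session("grpc://127.0.0.1:12222")  # any server in the cluster
for i in range(k):
  # Must match the shared_name that shard i is dequeuing from.
  queue = tf.FIFOQueue(1, tf.int32, shared_name="queue%d" % i)
  sess.run(queue.enqueue(1))  # unblocks shard i's dequeue; its process exits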
Below is a self-contained example with 2 shards, taken from https://gist.github.com/yaroslavvb/82a5b5302449530ca5ff59df520c369e (for a multi-worker/multi-shard example, see https://gist.github.com/yaroslavvb/ea1b1bae0a75c4aae593df7eca72d9ca).
import subprocess
import sys
import time

import tensorflow as tf

flags = tf.flags
flags.DEFINE_string("port1", "12222", "port of worker1")
flags.DEFINE_string("port2", "12223", "port of worker2")
flags.DEFINE_string("task", "", "internal use")
FLAGS = flags.FLAGS

# setup local cluster from flags
host = "127.0.0.1:"
cluster = {"worker": [host+FLAGS.port1, host+FLAGS.port2]}
clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def()

if __name__ == '__main__':
  if not FLAGS.task:  # start servers and run client
    # launch distributed service
    def runcmd(cmd): subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT)
    runcmd("python %s --task=0" % (sys.argv[0]))
    runcmd("python %s --task=1" % (sys.argv[0]))
    time.sleep(1)

    # bring down distributed service
    sess = tf.Session("grpc://" + host + FLAGS.port1)
    queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0")
    queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1")
    with tf.device("/job:worker/task:0"):
      add_op0 = tf.add(tf.ones(()), tf.ones(()))
    with tf.device("/job:worker/task:1"):
      add_op1 = tf.add(tf.ones(()), tf.ones(()))

    print("Running computation on server 0")
    print(sess.run(add_op0))
    print("Running computation on server 1")
    print(sess.run(add_op1))

    print("Bringing down server 0")
    sess.run(queue0.enqueue(1))
    print("Bringing down server 1")
    sess.run(queue1.enqueue(1))
  else:  # launch a TensorFlow server and block until told to quit
    server = tf.train.Server(clusterspec, config=None,
                             job_name="worker",
                             task_index=int(FLAGS.task))
    print("Starting server " + FLAGS.task)
    sess = tf.Session(server.target)
    queue = tf.FIFOQueue(1, tf.int32, shared_name="queue" + FLAGS.task)
    sess.run(queue.dequeue())  # blocks until the client enqueues a token
    print("Terminating server " + FLAGS.task)