I'm running a Spark Streaming task in a cluster using YARN. Each node in the cluster runs multiple Spark workers. Before the streaming starts I want to execute a "setup"
This is a typical use case for Spark's broadcast variables. Assuming fetch_models returns the models rather than saving them locally, you would do something like:
bc_models = sc.broadcast(fetch_models())
spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
stream.union(*create_kafka_streams())\
    .repartition(spark_partitions)\
    .foreachRDD(lambda rdd: rdd.foreachPartition(lambda partition: spam.on_partition(config, partition, bc_models.value)))
This does assume that your models fit in memory, on the driver and the executors.
You may be worried that broadcasting the models from the single driver to all the executors is inefficient, but Spark uses 'efficient broadcast algorithms' that, according to this analysis, can significantly outperform distribution through HDFS.
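As a rough sketch of the worker side: the function passed to foreachPartition receives an iterator over the records of one partition, and bc_models.value gives the executor an ordinary Python object. The body of on_partition below is hypothetical (the original spam.on_partition is not shown), as are the dict-of-callables shape of the models and the "threshold" config key. It only illustrates how the broadcast value might be consumed.

```python
def on_partition(config, partition, models):
    """Score every record of one partition with every model.

    Hypothetical stand-in for spam.on_partition: `partition` is an iterator
    of records, `models` is the plain object obtained from bc_models.value
    (assumed here to be a dict mapping a name to a scoring callable).
    """
    threshold = config.get("threshold", 0.5)  # assumed config key
    flagged = []
    for record in partition:
        scores = {name: model(record) for name, model in models.items()}
        # Keep the record if any model scores it at or above the threshold.
        if scores and max(scores.values()) >= threshold:
            flagged.append((record, scores))
    return flagged


# Local check without Spark: a list iterator plays the role of a partition.
models = {"length": lambda text: min(len(text) / 10.0, 1.0)}
result = on_partition({"threshold": 0.5}, iter(["hi", "a long message"]), models)
```

Because the function only needs an iterator and a plain object, it can be exercised locally like this before being wired into foreachPartition.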
If all you want is to distribute a file between worker machines, the simplest approach is to use the SparkFiles mechanism:
some_path = ... # local file, a file in DFS, an HTTP, HTTPS or FTP URI.
sc.addFile(some_path)
and retrieve it on the workers using SparkFiles.get and standard IO tools:
import os
from pyspark import SparkFiles
# SparkFiles.get expects the file name, not the original path or URI
with open(SparkFiles.get(os.path.basename(some_path))) as fw:
    ...  # Do something
If you want to make sure that the model is actually loaded, the simplest approach is to load it on module import. Assuming config can be used to retrieve the model path:
model.py:
import os
from pyspark import SparkFiles
config = ...
class MyClassifier:
    clf = None
    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None
    @staticmethod
    def load_models(config):
        # SparkFiles.get expects the file name under which the file was added
        path = SparkFiles.get(os.path.basename(config.get("model_file")))
        # load_from_file is a placeholder for the model's own deserializer
        MyClassifier.clf = load_from_file(path)
# Executed once per interpreter
MyClassifier.load_models(config)
main.py:
from pyspark import SparkContext
config = ...
sc = SparkContext("local", "foo")
# Executed before StreamingContext starts
sc.addFile(config.get("model_file"))
sc.addPyFile("model.py")
import model
ssc = ...
stream = ...
stream.map(model.MyClassifier.do_something).pprint()
ssc.start()
ssc.awaitTermination()
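Since main.py imports model on the driver as well, loading at import time also runs there. If that is not wanted, one possible variant is to defer loading until first use. The sketch below is an assumption and not part of the answer above: LazyClassifier, its get method and the loader argument are hypothetical names, and the loader would wrap whatever SparkFiles.get and load_from_file call applies in your setup.

```python
class LazyClassifier:
    """Variant of MyClassifier that loads the model on first use."""

    clf = None

    @classmethod
    def get(cls, loader):
        # `loader` is a hypothetical zero-argument callable that returns
        # the loaded model; it runs at most once per interpreter.
        if cls.clf is None:
            cls.clf = loader()
        return cls.clf


# Local check without Spark: count how often the loader actually runs.
calls = []


def fake_loader():
    calls.append(1)
    return {"weights": [1, 2, 3]}


first = LazyClassifier.get(fake_loader)
second = LazyClassifier.get(fake_loader)
```

Both calls return the same object and the loader runs only once, which mirrors the "executed once per interpreter" behavior of the import-time version.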