How to run a function on all Spark workers before processing data in PySpark?

南方客 2020-12-03 01:24

I'm running a Spark Streaming task in a cluster using YARN. Each node in the cluster runs multiple Spark workers. Before the streaming starts I want to execute a "setup"

2 Answers
  • 2020-12-03 01:53

    This is a typical use case for Spark's broadcast variables. Assuming fetch_models returns the models rather than saving them locally, you would do something like:

    bc_models = sc.broadcast(fetch_models())
    
    spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
    stream.union(*create_kafka_streams())\
        .repartition(spark_partitions)\
        .foreachRDD(lambda rdd: rdd.foreachPartition(lambda partition: spam.on_partition(config, partition, bc_models.value)))
    

    This does assume that your models fit in memory, on the driver and the executors.
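
    Whether the models fit can be gauged before broadcasting. PySpark pickles a broadcast value on the driver, so the pickled size gives a rough indication of the payload each executor receives (the unpickled objects may take more memory). A minimal sketch, where pickled_size_mb is a hypothetical helper and a plain dict stands in for the result of fetch_models():

```python
import pickle


def pickled_size_mb(obj):
    """Return the size of obj in MiB once pickled (a rough proxy for broadcast size)."""
    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return len(payload) / (1024 * 1024)


# Stand-in for fetch_models(); real models are usually far larger.
models = {"spam": [0.1] * 100_000}
size_mb = pickled_size_mb(models)
print(f"broadcast payload is about {size_mb:.2f} MiB")
```

    If the figure approaches the executor memory budget, shipping a file and loading it on the workers is likely the safer route.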

    You may be worried that broadcasting the models from the single driver to all the executors is inefficient, but Spark uses 'efficient broadcast algorithms' that, according to this analysis, can significantly outperform distribution through HDFS.
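
    The handler passed to foreachPartition in the snippet above is plain Python, so it can be tried locally without a cluster. The body of spam.on_partition is not shown in this answer; the sketch below is a hypothetical version, assuming config is a dict with a "model_name" key and each model is a callable:

```python
def on_partition(config, partition, models):
    """Hypothetical partition handler: label each record with the configured model.

    `models` is the already-unpacked broadcast value (bc_models.value),
    so nothing is fetched or loaded inside the per-record loop.
    """
    classify = models[config["model_name"]]  # look the model up once per partition
    return [(record, classify(record)) for record in partition]


# Local check with a toy "model"; no Spark needed because the handler is plain Python.
toy_models = {"spam": lambda text: "spam" if "free" in text.lower() else "ham"}
results = on_partition(
    {"model_name": "spam"},
    iter(["Free money", "Lunch at noon"]),
    toy_models,
)
print(results)
```

    foreachPartition discards the return value, so a real handler would write its results to some sink; with mapPartitions the returned list would become the records of the output RDD.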

  • 2020-12-03 01:57

    If all you want is to distribute a file to the worker machines, the simplest approach is to use the SparkFiles mechanism:

    some_path = ...  # local file, a file in DFS, an HTTP, HTTPS or FTP URI.
    sc.addFile(some_path)
    

    and retrieve it on the workers using SparkFiles.get and standard IO tools:

    import os
    from pyspark import SparkFiles
    
    # SparkFiles.get expects the file name, not the original path or URI.
    with open(SparkFiles.get(os.path.basename(some_path))) as fh:
        ...  # Do something
    

    If you want to make sure that the model is actually loaded, the simplest approach is to load it on module import. Assuming config can be used to retrieve the model path:

    • model.py:

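      import os
      
      from pyspark import SparkFiles
      
      config = ...
      
      class MyClassifier:
          clf = None  # one shared model per Python interpreter
      
          @staticmethod
          def is_loaded():
              return MyClassifier.clf is not None
      
          @staticmethod
          def load_models(config):
              # SparkFiles.get expects the bare name of a file shipped with sc.addFile
              path = SparkFiles.get(os.path.basename(config.get("model_file")))
              MyClassifier.clf = load_from_file(path)
      
      # Executed once per interpreter, when the module is first imported
      MyClassifier.load_models(config)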
      
    • main.py:

      from pyspark import SparkContext
      
      config = ...
      
      sc = SparkContext("local", "foo")
      
      # Executed before StreamingContext starts
      sc.addFile(config.get("model_file"))
      sc.addPyFile("model.py")
      
      import model
      
      ssc = ...
      stream = ...
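      # do_something is not defined in model.py above; it is assumed to be
      # a static method that applies MyClassifier.clf to a single record.
      stream.map(model.MyClassifier.do_something).pprint()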
      
      ssc.start()
      ssc.awaitTermination()
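
    Loading at import time means that a missing or unreadable model file makes `import model` itself fail. A variant, not part of the answer above, defers the load to first use while still loading only once per interpreter. A minimal sketch, assuming a JSON file stands in for the model and json.load replaces the unspecified load_from_file; on a worker the path would come from SparkFiles.get:

```python
import json
import os
import tempfile


class MyClassifier:
    clf = None  # shared by every task that runs in this Python interpreter

    @classmethod
    def get(cls, path):
        """Load the model on first use, then reuse the cached object."""
        if cls.clf is None:
            with open(path) as fh:
                cls.clf = json.load(fh)  # stand-in for load_from_file(path)
        return cls.clf


# Local demonstration with a throwaway JSON file acting as the model.
with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.json")
    with open(model_path, "w") as fh:
        json.dump({"threshold": 0.5}, fh)
    first = MyClassifier.get(model_path)
    second = MyClassifier.get(model_path)  # served from the cache, no second read

print(first is second)
```

    Because the cache lives in a class attribute, it persists for as long as the Python worker process is reused, so later batches handled by the same worker skip the load.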
      