Is there a way to skip/throw-out/ignore records in Spark during a map?

忘了有多久 2021-02-14 04:13

We have a very standard Spark job which reads log files from S3 and then does some processing over them. Very basic Spark stuff...

val logs = sc.textFile(somePat         


        
2 Answers
  •  我在风中等你
    2021-02-14 04:51

    One approach is to use the one-parameter overload of collect (instead of map or flatMap) together with a PartialFunction: records for which the function is not defined are simply dropped. This gets a little tricky when the partial function you need isn't trivial, and yours probably won't be, because you need to both parse and validate. I'll model that below with two partial functions (although the first one happens to be defined for all inputs).

    // this doesn't really need to be a partial function but we'll 
    // want to compose it with one and end up with a partial function
    val split: PartialFunction[String, Array[String]] = {
      case log => log.split("\t")
    }
    
    // this really needs to be a partial function
    val validate: PartialFunction[Array[String], Array[String]] = {
      case lines if lines.length > 2 => lines
    }
    
    val splitAndValidate = split andThen validate
    
    val logs = sc.parallelize(Seq("a\tb", "u\tv\tw", "a", "x\ty\tz"), 4)
    
    // only accept the logs with more than two entries
    val validRows = logs.collect(splitAndValidate)
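
    For intuition about what collect does with a partial function, here is a minimal sketch on a plain Scala Seq rather than an RDD, so it runs without a cluster. The names CollectSketch and parse are made up for this illustration, and it rests on the assumption that RDD.collect(f) has the same meaning as the collections method: keep the inputs where f is defined, then apply f to them.

    ```scala
    object CollectSketch {
      // Hypothetical parser: defined only for logs with more than two tab-separated fields.
      val parse: PartialFunction[String, Array[String]] = {
        case log if log.split("\t").length > 2 => log.split("\t")
      }

      def main(args: Array[String]): Unit = {
        val sample = Seq("a\tb", "u\tv\tw", "a", "x\ty\tz")

        // collect drops every input for which parse is not defined ...
        val viaCollect = sample.collect(parse).map(_.toList)

        // ... which gives the same rows as an explicit filter followed by a map.
        val viaFilterMap = sample.filter(parse.isDefinedAt).map(parse).map(_.toList)

        assert(viaCollect == viaFilterMap)
        viaCollect.foreach(fields => println(fields.mkString(",")))
      }
    }
    ```

    Only the two three-field rows survive; the short rows are skipped without throwing.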
    

    This is perfectly good Scala, but it fails under Spark because splitAndValidate isn't serializable, and Spark has to serialize the function in order to ship it to the executors. (Note that split and validate are each serializable on their own: the problem lies with the composition!) So we need to make a PartialFunction that is serializable:

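    class LogValidator extends PartialFunction[String, Array[String]] with Serializable {
    
      // defined only for rows with more than two fields
      private val validate: PartialFunction[Array[String], Array[String]] = {
        case fields if fields.length > 2 => fields
      }
    
      // splits the log and returns its fields; throws a MatchError if the row is invalid
      override def apply(log: String): Array[String] = {
        validate(log.split("\t"))
      }
    
      // collect keeps a log only when this returns true
      override def isDefinedAt(log: String): Boolean = {
        validate.isDefinedAt(log.split("\t"))
      }
    
    }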
    

    Then we can call

    val validRows = logs.collect(new LogValidator())
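
    Before submitting the job, it can help to check locally that the validator survives Java serialization, which is how Spark serializes the functions it sends to executors. The sketch below is not Spark API: SerializationCheck and isSerializable are hypothetical helper names, and LogValidator is repeated from above only so that the snippet stands alone. It assumes Scala 2.x, where partial function literals compile to serializable classes.

    ```scala
    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    // Repeated from the answer above so that this snippet stands alone.
    class LogValidator extends PartialFunction[String, Array[String]] with Serializable {
      private val validate: PartialFunction[Array[String], Array[String]] = {
        case fields if fields.length > 2 => fields
      }
      override def apply(log: String): Array[String] = validate(log.split("\t"))
      override def isDefinedAt(log: String): Boolean = validate.isDefinedAt(log.split("\t"))
    }

    object SerializationCheck {
      // Hypothetical helper: true if `value` can be written with Java serialization.
      def isSerializable(value: AnyRef): Boolean = {
        val out = new ObjectOutputStream(new ByteArrayOutputStream())
        try { out.writeObject(value); true }
        catch { case _: NotSerializableException => false }
        finally out.close()
      }

      def main(args: Array[String]): Unit = {
        val validator = new LogValidator
        println(s"serializable: ${isSerializable(validator)}")

        // The validator also works with collect on a plain Seq, without a cluster.
        val sample = Seq("a\tb", "u\tv\tw", "a", "x\ty\tz")
        sample.collect(validator).foreach(fields => println(fields.mkString(",")))
      }
    }
    ```

    The same helper can be pointed at split andThen validate to see whether the composed function is serializable on the Scala version in use.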
    
