We have a very standard Spark job which reads log files from S3 and then does some processing over them. Very basic Spark stuff...
val logs = sc.textFile(somePath)
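For context, a minimal sketch of that kind of job might look like the following; the S3 path, app name, and tab-separated record format are assumptions for illustration, not details from the original post.

// minimal sketch of the job described above; the bucket/prefix and the
// tab-separated format are assumed for illustration only
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().setAppName("log-processing")
val sc = new SparkContext(conf)

// read raw log lines from S3 (hypothetical location)
val logs = sc.textFile("s3a://some-bucket/logs/*.log")

// "some processing": e.g. split each line into tab-separated fields
val fields = logs.map(_.split("\t"))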
One approach is to use the one-parameter overload of collect (instead of map or flatMap) and a PartialFunction. This is a little tricky if the partial function you need isn't completely trivial. In fact yours probably won't be, because you need to parse and validate, which I'll model below with two partial functions (although the first one happens to be defined for all inputs).
// this doesn't really need to be a partial function but we'll
// want to compose it with one and end up with a partial function
val split: PartialFunction[String, Array[String]] = {
  case log => log.split("\t")
}

// this really needs to be a partial function
val validate: PartialFunction[Array[String], Array[String]] = {
  case lines if lines.length > 2 => lines
}
val splitAndValidate = split andThen validate
val logs = sc.parallelize(Seq("a\tb", "u\tv\tw", "a", "x\ty\tz"), 4)
// only accept the logs with more than two entries
val validRows = logs.collect(splitAndValidate)
This is perfectly good Scala but it doesn't work because splitAndValidate isn't serializable and we're using Spark. (Note that split and validate are serializable: the problem lies with composition!) So, we need to make a PartialFunction that is serializable:
class LogValidator extends PartialFunction[String, Array[String]] with Serializable {

  private val validate: PartialFunction[Array[String], Array[String]] = {
    case lines if lines.length > 2 => lines
  }

  override def apply(log: String): Array[String] = {
    validate(log.split("\t"))
  }

  override def isDefinedAt(log: String): Boolean = {
    validate.isDefinedAt(log.split("\t"))
  }
}
Then we can call
val validRows = logs.collect(new LogValidator())
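With the sample data above, this keeps only the rows that split into more than two fields. A quick check on the driver, just as a sketch, would be:

// inspect the surviving rows on the driver; with the sample input above
// this should print "u,v,w" and "x,y,z" (the other two lines are dropped)
validRows.map(_.mkString(",")).collect().foreach(println)

The nice thing about this design is that apply and isDefinedAt each call split and validate directly inside the class, so no composed PartialFunction has to be shipped to the executors, only the Serializable LogValidator itself.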