Question:
I'm getting an error when running flatMap() on a list of objects of my own class. It works fine for regular Python data types like int and list, but fails when the list contains objects of my class. Here's the entire code:
```python
from pyspark import SparkContext

sc = SparkContext("local", "WordCountBySparkKeyword")

def func(x):
    if x == 2:
        return [2, 3, 4]
    return [1]

rdd = sc.parallelize([2])
rdd = rdd.flatMap(func)  # rdd.collect() now has [2, 3, 4]
rdd = rdd.flatMap(func)  # rdd.collect() now has [2, 3, 4, 1, 1]
print(rdd.collect())  # gives expected output

# Class I'm defining
class node(object):
    def __init__(self, value):
        self.value = value

    # Representation, for printing node
    def __repr__(self):
        return self.value

def foo(x):
    if x.value == 2:
        return [node(2), node(3), node(4)]
    return [node(1)]

rdd = sc.parallelize([node(2)])
rdd = rdd.flatMap(foo)  # marker 2
print(rdd.collect())  # should contain nodes with values [2, 3, 4]
```
The first part of the code (the integer example) works fine. The problem arises at marker 2 (commented in the code). The specific error message I'm getting is: AttributeError: 'module' object has no attribute 'node'
How do I resolve this error?
I'm working on Ubuntu, running PySpark 1.4.1.
Answer 1:
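The error you get is completely unrelated to flatMap. If you define the node class in your main script, it is accessible on the driver, but it is not distributed to the workers. The RDD elements are pickled, and pickle stores an instance's class only as a reference (module name plus class name, here `__main__.node`), so a worker whose `__main__` module has no `node` attribute fails with exactly this AttributeError. To make it work, place the node definition in a separate module and make sure that module is distributed to the workers.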
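The mechanism described above can be reproduced without Spark, using only the standard pickle module. This is a minimal sketch under the assumption that PySpark ships RDD elements to workers in pickled form (its default serializer pickles them); `unpickle_without_class` is a hypothetical helper that imitates a worker by temporarily removing the class from the module it was defined in.

```python
import pickle
import sys


class node(object):
    def __init__(self, value):
        self.value = value


def unpickle_without_class(payload, cls):
    """Hypothetical helper: imitate a worker whose module lacks `cls`.

    Temporarily deletes the class from its defining module, tries to
    unpickle `payload`, then puts the class back. Returns the name of the
    exception raised, or None if unpickling succeeded.
    """
    module = sys.modules[cls.__module__]
    saved = getattr(module, cls.__name__)
    delattr(module, cls.__name__)
    try:
        pickle.loads(payload)
        return None
    except AttributeError as exc:
        return type(exc).__name__
    finally:
        setattr(module, cls.__name__, saved)


payload = pickle.dumps(node(2), protocol=2)

# The pickle holds only a reference to the class: its module and its name.
print(node.__module__.encode() in payload, b"node" in payload)  # True True

# With the class missing from that module, unpickling fails like on a worker.
print(unpickle_without_class(payload, node))  # AttributeError
```

When run as a script, `node.__module__` is `__main__`, which is why the worker looks for `node` in its own `__main__` module and does not find it.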
- Create a separate module with the node definition; let's call it `node.py`.
- Import this node class inside your main script: `from node import node`
- Make sure the module is distributed to the workers: `sc.addPyFile("node.py")`
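The steps above can be sketched end to end. This is an illustration under stated assumptions, not a full Spark job: a scratch directory created with `tempfile` stands in for the folder holding the main script, and `NODE_PY` is a hypothetical name for the module's source text. The Spark calls are left as comments because they need a running `SparkContext`; the runnable part only checks that an imported `node` instance now pickles as a reference to module `node` rather than `__main__`, which is what a worker can resolve once `node.py` has been shipped to it.

```python
import importlib
import os
import pickle
import sys
import tempfile

# Hypothetical contents of node.py: the class moved into its own module,
# with a __repr__ that returns a string.
NODE_PY = '''\
class node(object):
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return "node({0})".format(repr(self.value))
'''

# Step 1: write node.py into a scratch directory and make it importable.
workdir = tempfile.mkdtemp()
node_path = os.path.join(workdir, "node.py")
with open(node_path, "w") as handle:
    handle.write(NODE_PY)
sys.path.insert(0, workdir)
importlib.invalidate_caches()  # make sure the import system sees the new file

# Step 2: import the class instead of defining it in the main script.
from node import node

# The pickled instance now points at module "node", not "__main__".
payload = pickle.dumps(node(2), protocol=2)
print(node.__module__)         # node
print(b"__main__" in payload)  # False
print(pickle.loads(payload))   # node(2)

# Step 3 (needs a running SparkContext `sc` and foo from the question):
#   sc.addPyFile(node_path)
#   print(sc.parallelize([node(2)]).flatMap(foo).collect())
```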
Now everything should work as expected.
On a side note:
- PEP 8 recommends CapWords for class names. It is not a hard requirement, but it makes life easier.
- The `__repr__` method should return a string representation of an object. At the very least make sure it returns a string, but a proper representation is even better:

```python
def __repr__(self):
    return "node({0})".format(repr(self.value))
```
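Both side notes can be checked in plain Python. In this sketch the class is renamed `Node` purely to illustrate the CapWords convention (if you rename it, the `from node import ...` line changes accordingly), and `BadNode` is a hypothetical counterexample whose `__repr__` returns the raw integer the way the question's class does.

```python
class Node(object):
    """CapWords name per PEP 8; __repr__ returns a real string."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return "Node({0})".format(repr(self.value))


class BadNode(object):
    """Hypothetical counterexample: __repr__ returns a non-string."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return self.value  # an int, not a str


nodes = [Node(2), Node(3), Node(4)]
print(nodes)  # [Node(2), Node(3), Node(4)]

# A proper repr can be evaluated back into an equivalent object.
rebuilt = eval(repr(Node(2)), {"Node": Node})
print(rebuilt.value)  # 2

# Returning a non-string from __repr__ makes repr() raise TypeError.
bad_repr_error = None
try:
    repr(BadNode(2))
except TypeError as exc:
    bad_repr_error = exc
    print("TypeError:", exc)
```

Printing a collected list calls `repr` on every element, so with the question's `__repr__` the final `print` would fail with this TypeError even after the pickling problem is fixed.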
Source: https://stackoverflow.com/questions/32792271/flatmap-over-list-of-custom-objects-in-pyspark