I know the method rdd.first(), which gives me the first element in an RDD.
There is also the method rdd.take(num), which gives me the first "num" elements.
I tried writing this class to fetch an item by index. First, when you construct new IndexedFetcher(rdd, itemClass), it counts the number of elements in each partition of the RDD. Then, when you call indexedFetcher.get(n), it runs a job on only the partition that contains that index.
Note that I needed to compile this using Java 1.7 instead of 1.8; as of Spark 1.1.0, the bundled org.objectweb.asm within com.esotericsoftware.reflectasm cannot read Java 1.8 classes yet (it throws IllegalStateException when you try to call runJob with a function compiled for Java 1.8).
import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import scala.reflect.ClassTag;

// Nested class: declare it inside an enclosing class (or drop "static" for a top-level file).
public static class IndexedFetcher<E> implements Serializable {
    private static final long serialVersionUID = 1L;
    public final RDD<E> rdd;
    public Integer[] elementsPerPartitions;
    private Class<?> clazz;

    public IndexedFetcher(RDD<E> rdd, Class<?> clazz) {
        this.rdd = rdd;
        this.clazz = clazz;
        SparkContext context = this.rdd.context();
        ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
        // One job over every partition; each task returns the number of elements it holds.
        elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
    }

    // Counts the elements of one partition by draining its iterator.
    public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
        private static final long serialVersionUID = 1L;
        @Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                iterator.next();
                count++;
            }
            return count;
        }
    }

    static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
        scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
        return function;
    }

    // Walks the partition sizes until the remaining offset falls inside one partition.
    public E get(long index) {
        long remaining = index;
        long totalCount = 0;
        for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
            if (remaining < elementsPerPartitions[partition]) {
                return getWithinPartition(partition, remaining);
            }
            remaining -= elementsPerPartitions[partition];
            totalCount += elementsPerPartitions[partition];
        }
        throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
    }

    // Returns the element at a given position inside one partition.
    public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
        private static final long serialVersionUID = 1L;
        private final long indexWithinPartition;

        public FetchWithinPartitionFunction(long indexWithinPartition) {
            this.indexWithinPartition = indexWithinPartition;
        }

        @Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                E element = iterator.next();
                if (count == indexWithinPartition)
                    return element;
                count++;
            }
            throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
        }
    }

    public E getWithinPartition(int partition, long indexWithinPartition) {
        System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
        SparkContext context = rdd.context();
        scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
        // Run the job on this single partition only (allowLocal = true).
        scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
        ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
        @SuppressWarnings("unchecked")
        E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
        return result[0];
    }
}
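Stripped of Spark, the bookkeeping in get() is a walk over the per-partition counts, subtracting each partition's size until the remaining offset falls inside one partition. The plain-Java sketch below models only that arithmetic; the class name PartitionLocator and the sample partition sizes are made up for illustration, and nothing here talks to a cluster. Note that an empty partition is simply skipped.

```java
import java.util.Arrays;

public class PartitionLocator {
    // Hypothetical helper modeling IndexedFetcher.get(): given the element count of each
    // partition, return {partition, offsetWithinPartition} for a global index.
    static long[] locate(long[] elementsPerPartition, long index) {
        if (index < 0) {
            throw new IllegalArgumentException("negative index " + index);
        }
        long remaining = index;
        for (int partition = 0; partition < elementsPerPartition.length; partition++) {
            if (remaining < elementsPerPartition[partition]) {
                return new long[] {partition, remaining};
            }
            // Skip this whole partition (including empty ones) and keep walking.
            remaining -= elementsPerPartition[partition];
        }
        long total = Arrays.stream(elementsPerPartition).sum();
        throw new IllegalArgumentException(
            String.format("Get %d within RDD that has only %d elements", index, total));
    }

    public static void main(String[] args) {
        long[] counts = {3, 0, 2, 4}; // made-up partition sizes
        System.out.println(Arrays.toString(locate(counts, 4))); // [2, 1]
    }
}
```

Only the target partition's offset is needed afterwards, which is why the real class can then run a job on that one partition.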
This should be possible by first indexing the RDD. The transformation zipWithIndex provides a stable indexing, numbering each element in its original order.
Given: rdd = (a,b,c)
val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2))
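zipWithIndex orders indices by partition first and then by position within each partition, so the first item of partition 0 gets index 0 and the last item of the last partition gets the largest index. The plain-Java model below (no Spark; partitions are assumed to be a list of lists, and ZipWithIndexSketch is a made-up name) shows why this takes two passes: each partition's starting index is the sum of the sizes of all partitions before it, so those sizes must be known before any element can be numbered.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class ZipWithIndexSketch {
    // Models RDD.zipWithIndex on in-memory "partitions": returns (element, index) pairs.
    static <T> List<Map.Entry<T, Long>> zipWithIndex(List<List<T>> partitions) {
        // Pass 1: starting index of each partition = total size of the partitions before it.
        long[] startIndices = new long[partitions.size()];
        long running = 0;
        for (int p = 0; p < partitions.size(); p++) {
            startIndices[p] = running;
            running += partitions.get(p).size();
        }
        // Pass 2: number each element as partition start + position within the partition.
        List<Map.Entry<T, Long>> result = new ArrayList<>();
        for (int p = 0; p < partitions.size(); p++) {
            long i = startIndices[p];
            for (T element : partitions.get(p)) {
                result.add(new SimpleEntry<>(element, i++));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // rdd = (a,b,c), split into two partitions
        List<List<String>> rdd = Arrays.asList(Arrays.asList("a", "b"), Arrays.asList("c"));
        System.out.println(zipWithIndex(rdd)); // [a=0, b=1, c=2]
    }
}
```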
To look up an element by index, this form is not useful. First we need to use the index as the key:
val indexKey = withIndex.map{case (k,v) => (v,k)} //((0,a),(1,b),(2,c))
Now it's possible to use the lookup action in PairRDDFunctions to find an element by key:
val b = indexKey.lookup(1) // Seq(b)
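lookup returns every value stored under the key, which is why the result is a collection even though zipWithIndex indices are unique. Because map discards the partitioner, indexKey has none, and in that case lookup scans all partitions rather than a single one. Here is a plain-Java model of the swap and that scan; it does not use Spark, and the Pair class is a hypothetical stand-in for Scala's Tuple2.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class LookupSketch {
    // Hypothetical minimal pair type standing in for scala.Tuple2.
    static final class Pair<A, B> {
        final A first;
        final B second;
        Pair(A first, B second) { this.first = first; this.second = second; }
    }

    // Models withIndex.map{case (k,v) => (v,k)}: (element, index) becomes (index, element).
    static <T> List<Pair<Long, T>> swap(List<Pair<T, Long>> withIndex) {
        List<Pair<Long, T>> indexKey = new ArrayList<>();
        for (Pair<T, Long> p : withIndex) {
            indexKey.add(new Pair<>(p.second, p.first));
        }
        return indexKey;
    }

    // Models lookup(key) without a partitioner: scan every pair, keep all matching values.
    static <T> List<T> lookup(List<Pair<Long, T>> indexKey, long key) {
        List<T> matches = new ArrayList<>();
        for (Pair<Long, T> p : indexKey) {
            if (p.first == key) {
                matches.add(p.second);
            }
        }
        return matches;
    }

    public static void main(String[] args) {
        List<Pair<String, Long>> withIndex = Arrays.asList(
            new Pair<>("a", 0L), new Pair<>("b", 1L), new Pair<>("c", 2L));
        System.out.println(lookup(swap(withIndex), 1)); // [b]
    }
}
```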
If you expect to call lookup often on the same RDD, I'd recommend caching the indexKey RDD to improve performance.
How to do this using the Java API is an exercise left for the reader.
I got stuck on this for a while as well, so here is an expansion of Maasg's answer for Java that looks up a range of values by index (you'll need to define the 4 variables at the top):
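import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import scala.Tuple2;
DataFrame df;
SQLContext sqlContext;
Long start; // inclusive
Long end;   // exclusive
JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex();
// filter keeps indices in [start, end); keys() drops the index so createDataFrame gets a JavaRDD<Row>
JavaRDD<Row> filteredRDD = indexedRDD.filter((Tuple2<Row, Long> v1) -> v1._2() >= start && v1._2() < end).keys();
DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema());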
Remember that to run this code your cluster needs Java 8, since it uses a lambda expression.
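Also, zipWithIndex can be expensive: when the RDD has more than one partition, it triggers an extra Spark job to count the elements in each partition before it can assign any indices.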
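The filter in the Java answer above keeps the half-open range [start, end), so start is included and end is not. A small plain-Java model of that slice, with no Spark involved (partitions are assumed to be a list of lists and IndexRangeSketch is a made-up name), makes the bounds concrete. It numbers elements with a single sequential counter, which yields the same indices zipWithIndex would assign, though Spark computes them per partition in parallel.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class IndexRangeSketch {
    // Models zipWithIndex + filter(start <= index < end) + keys() on in-memory partitions.
    static <T> List<T> slice(List<List<T>> partitions, long start, long end) {
        List<T> kept = new ArrayList<>();
        long index = 0; // global index in partition order, as zipWithIndex assigns it
        for (List<T> partition : partitions) {
            for (T element : partition) {
                if (index >= start && index < end) {
                    kept.add(element);
                }
                index++;
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        List<List<String>> rows = Arrays.asList(
            Arrays.asList("r0", "r1"), Arrays.asList("r2", "r3", "r4"));
        System.out.println(slice(rows, 1, 4)); // [r1, r2, r3]
    }
}
```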