Я знаю метод rdd.first(), который дает мне первый элемент в RDD.
Также существует метод rdd.take(num), который дает мне первые "num" элементы.
Но разве нет возможности получить элемент по индексу?
Спасибо.
Я знаю метод rdd.first(), который дает мне первый элемент в RDD.
Также существует метод rdd.take(num), который дает мне первые "num" элементы.
Но разве нет возможности получить элемент по индексу?
Спасибо.
Это должно быть возможно, сначала индексируя RDD. Преобразование zipWithIndex
обеспечивает стабильную индексацию, нумерацию каждого элемента в исходном порядке.
Учитывая: rdd = (a,b,c)
val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2))
Чтобы найти элемент по индексу, эта форма не полезна. Сначала нам нужно использовать индекс как ключ:
val indexKey = withIndex.map{case (k,v) => (v,k)} //((0,a),(1,b),(2,c))
Теперь можно использовать действие lookup
в PairRDD для поиска элемента по ключу:
val b = indexKey.lookup(1) // Array(b)
Если вы ожидаете часто использовать lookup
на одном и том же RDD, я бы рекомендовал кэшировать RDD indexKey
для повышения производительности.
Как это сделать с помощью API Java - это упражнение, оставшееся для читателя.
Я попробовал этот класс для выбора элемента по индексу. Во-первых, при построении new IndexedFetcher(rdd, itemClass)
он подсчитывает количество элементов в каждом разделе RDD. Затем, когда вы вызываете indexedFetcher.get(n)
, он запускает задание только для раздела, содержащего этот индекс.
Обратите внимание, что мне нужно было скомпилировать это с использованием Java 1.7 вместо 1.8; с Spark 1.1.0 связанный org.objectweb.asm в com.esotericsoftware.reflectasm еще не может читать классы Java 1.8 (бросает IllegalStateException при попытке запустить функцию Java 1.8).
import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import scala.reflect.ClassTag;
public static class IndexedFetcher<E> implements Serializable {
private static final long serialVersionUID = 1L;
public final RDD<E> rdd;
public Integer[] elementsPerPartitions;
private Class<?> clazz;
public IndexedFetcher(RDD<E> rdd, Class<?> clazz){
this.rdd = rdd;
this.clazz = clazz;
SparkContext context = this.rdd.context();
ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
}
public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
private static final long serialVersionUID = 1L;
@Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
int count = 0;
while (iterator.hasNext()) {
count++;
iterator.next();
}
return count;
}
}
static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
return function;
}
public E get(long index) {
long remaining = index;
long totalCount = 0;
for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
if (remaining < elementsPerPartitions[partition]) {
return getWithinPartition(partition, remaining);
}
remaining -= elementsPerPartitions[partition];
totalCount += elementsPerPartitions[partition];
}
throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
}
public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
private static final long serialVersionUID = 1L;
private final long indexWithinPartition;
public FetchWithinPartitionFunction(long indexWithinPartition) {
this.indexWithinPartition = indexWithinPartition;
}
@Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
int count = 0;
while (iterator.hasNext()) {
E element = iterator.next();
if (count == indexWithinPartition)
return element;
count++;
}
throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
}
}
public E getWithinPartition(int partition, long indexWithinPartition) {
System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
SparkContext context = rdd.context();
scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
return result[0];
}
}
Я тоже застрял на этом, поэтому, чтобы расширить ответ Maasg, но отвечая на поиск диапазона значений по индексу для Java (вам нужно будет определить 4 переменные вверху):
DataFrame df;
SQLContext sqlContext;
Long start;
Long end;
JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex();
JavaRDD filteredRDD = indexedRDD.filter((Tuple2<Row,Long> v1) -> v1._2 >= start && v1._2 < end);
DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema());
Помните, что при запуске этого кода в вашем кластере должен быть Java 8 (поскольку используется выражение лямбда).
Кроме того, zipWithIndex, вероятно, дорогой!