Есть ли способ передать предельный параметр functions.collect_set в Spark?

Я имею дело с столбцом чисел в большом искровом DataFrame, и я хотел бы создать новый столбец, который хранит объединенный список уникальных чисел, которые появляются в этом столбце.

В основном именно то, что выполняет функция .collect_set. Тем не менее, мне нужно всего до 1000 элементов в агрегированном списке. Можно ли каким-либо образом передать этот параметр функции functions.collect_set() или любым другим способом получить только до 1000 элементов в агрегированном списке без использования UDAF?

Поскольку столбец настолько велик, я хотел бы избежать сбора всех элементов и последующего обрезки списка.

Спасибо!

Ответ 1

Мое решение очень похоже на ответ Локи с collect_set_limit.


Я бы использовал UDF, который будет делать то, что вы хотите после collect_set (или collect_list) или гораздо более сложного UDAF.

Учитывая больший опыт работы с UDF, я бы пошел с этим первым. Хотя пользовательские функции не оптимизированы, для этого варианта использования это нормально.

val limitUDF = udf { (nums: Seq[Long], limit: Int) => nums.take(limit) }
val sample = spark.range(50).withColumn("key", $"id" % 5)

scala> sample.groupBy("key").agg(collect_set("id") as "all").show(false)
+---+--------------------------------------+
|key|all                                   |
+---+--------------------------------------+
|0  |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|
|1  |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|
|3  |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|
|2  |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|
|4  |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|
+---+--------------------------------------+

scala> sample.
  groupBy("key").
  agg(collect_set("id") as "all").
  withColumn("limit(3)", limitUDF($"all", lit(3))).
  show(false)
+---+--------------------------------------+------------+
|key|all                                   |limit(3)    |
+---+--------------------------------------+------------+
|0  |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] |
|1  |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] |
|3  |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]|
|2  |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]|
|4  |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] |
+---+--------------------------------------+------------+

См функции объекта (для udf функции Docs).

Ответ 2

Я использую измененную копию функций collect_set и collect_list; из-за областей кода измененные копии должны быть в том же пути пакета, что и оригиналы. Связанный код работает для Spark 2.1.0; если вы используете предыдущую версию, сигнатуры методов могут отличаться.

Бросьте этот файл (https://gist.github.com/lokkju/06323e88746c85b2ce4de3ea9cdef9bc) в свой проект как src/main/org/apache/spark/sql/catal/expression/collect_limit.scala

используйте его как:

import org.apache.spark.sql.catalyst.expression.collect_limit._
df.groupBy('set_col).agg(collect_set_limit('set_col,1000)

Ответ 3

 scala> df.show
    +---+-----+----+--------+
    | C0|   C1|  C2|      C3|
    +---+-----+----+--------+
    | 10| Name|2016| Country|
    | 11|Name1|2016|country1|
    | 10| Name|2016| Country|
    | 10| Name|2016| Country|
    | 12|Name2|2017|Country2|
    +---+-----+----+--------+

scala> df.groupBy("C1").agg(sum("C0"))
res36: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]

scala> res36.show
+-----+-------+
|   C1|sum(C0)|
+-----+-------+
|Name1|     11|
|Name2|     12|
| Name|     30|
+-----+-------+

scala> df.limit(2).groupBy("C1").agg(sum("C0"))
    res33: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]

    scala> res33.show
    +-----+-------+
    |   C1|sum(C0)|
    +-----+-------+
    | Name|     10|
    |Name1|     11|
    +-----+-------+



    scala> df.groupBy("C1").agg(sum("C0")).limit(2)
res2: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]

scala> res2.show
+-----+-------+
|   C1|sum(C0)|
+-----+-------+
|Name1|     11|
|Name2|     12|
+-----+-------+

scala> df.distinct
res8: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]

scala> res8.show
+---+-----+----+--------+
| C0|   C1|  C2|      C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+

scala> df.dropDuplicates(Array("c1"))
res11: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]

scala> res11.show
+---+-----+----+--------+                                                       
| C0|   C1|  C2|      C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 12|Name2|2017|Country2|
| 10| Name|2016| Country|
+---+-----+----+--------+

Ответ 4

использовать

val firstThousand = rdd.take(1000)

Вернет первый 1000. Сбор также имеет функцию фильтра, которая может быть предоставлена. Это позволит вам быть более конкретным относительно того, что возвращается.

Ответ 5

Как уже упоминалось в других ответах, эффективный способ сделать это - написать UDAF. К сожалению, API UDAF на самом деле не так расширяемо, как агрегатные функции, которые поставляются с искрой. Однако вы можете использовать их внутренние API-интерфейсы, чтобы использовать внутренние функции и делать то, что вам нужно.

Вот реализация для collect_set_limit которая в основном является копией прошлой внутренней Spark-функции CollectSet. Я бы просто расширил его, но это класс дела. На самом деле все, что нужно, это переопределить методы update и merge для соблюдения переданного ограничения:

case class CollectSetLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {

  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty

  override def update(buffer: mutable.HashSet[Any], input: InternalRow): mutable.HashSet[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.HashSet[Any], other: mutable.HashSet[Any]): mutable.HashSet[Any] = {
    if( buffer.size >= limit ) buffer
    else buffer ++= other.take( limit - buffer.size )
  }

  override def prettyName: String = "collect_set_limit"
}

И чтобы фактически зарегистрировать это, мы можем сделать это через внутреннюю FunctionRegistry Spark FunctionRegistry которая берет имя и конструктор, который фактически является функцией, которая создает CollectSetLimit используя предоставленные выражения:

val collectSetBuilder = (args: Seq[Expression]) => CollectSetLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_set_limit", collectSetBuilder )

Редактировать:

Оказывается, добавление его во встроенную систему работает только в том случае, если вы еще не создали SparkContext, поскольку при запуске он становится неизменным клоном. Если у вас есть существующий контекст, то это должно работать, чтобы добавить его с отражением:

val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_set_limit", collectSetBuilder )