collect_list, сохраняя порядок на основе другой переменной

Я пытаюсь создать новый столбец списков в Pyspark, используя агрегацию groupby в существующем наборе столбцов. Примерный входной фрейм данных приведен ниже:

------------------------
id | date        | value
------------------------
1  |2014-01-03   | 10 
1  |2014-01-04   | 5
1  |2014-01-05   | 15
1  |2014-01-06   | 20
2  |2014-02-10   | 100   
2  |2014-03-11   | 500
2  |2014-04-15   | 1500

Ожидаемый результат:

id | value_list
------------------------
1  | [10, 5, 15, 20]
2  | [100, 500, 1500]

Значения в списке сортируются по дате.

Я попытался использовать collect_list следующим образом:

from pyspark.sql import functions as F
ordered_df = input_df.orderBy(['id','date'],ascending = True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

Но collect_list не гарантирует порядок, даже если я сортирую кадр входных данных по дате перед агрегацией.

Может ли кто-нибудь помочь в том, как сделать агрегацию, сохраняя заказ на основе второй (даты) переменной?

Ответ 1

Если вы собираете обе даты и значения в виде списка, вы можете отсортировать полученный столбец в соответствии с использованием даты и udf, а затем сохранить только значения в результате.

import operator
import pyspark.sql.functions as F

# create list column
grouped_df = input_df.groupby("id") \
               .agg(F.collect_list(F.struct("date", "value")) \
               .alias("list_col"))

# define udf
def sorter(l):
  res = sorted(l, key=operator.itemgetter(0))
  return [item[1] for item in res]

sort_udf = F.udf(sorter)

# test
grouped_df.select("id", sort_udf("list_col") \
  .alias("sorted_list")) \
  .show(truncate = False)
+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+

Ответ 2

from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('id').orderBy('date')

sorted_list_df = input_df.withColumn(
            'sorted_list', F.collect_list('value').over(w)
        )\
        .groupBy('id')\
        .agg(F.max('sorted_list').alias('sorted_list'))

Примеры Window предоставленные пользователями, часто не объясняют, что происходит, поэтому позвольте мне проанализировать его для вас.

Как вы знаете, использование collect_list вместе с groupBy приведет к неупорядоченному списку значений. Это связано с тем, что в зависимости от того, как ваши данные разбиты на разделы, Spark добавит значения в ваш список, как только найдет строку в группе. Порядок тогда зависит от того, как Spark планирует вашу агрегацию над исполнителями.

Функция " Window позволяет вам контролировать эту ситуацию, группируя строки на определенное значение, чтобы вы могли выполнять операцию over каждой из результирующих групп:

w = Window.partitionBy('id').orderBy('date')
  • partitionBy - вам нужны группы/разделы строк с одинаковым id
  • orderBy - вы хотите, чтобы каждая строка в группе сортировалась по date

После того, как вы определили область своего окна - "строки с одним и тем же id, отсортированные по date ", вы можете использовать его для выполнения над ним операции, в этом случае - collect_list:

F.collect_list('value').over(w)

На этом этапе вы создали новый столбец sorted_list с упорядоченным списком значений, отсортированным по дате, но у вас все еще есть дублированные строки на один id. Чтобы обрезать дублированные строки, вы хотите groupBy id и сохранить max значение для каждой группы:

.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))

Ответ 3

Вопрос был в PySpark, но может быть полезен и для Scala Spark.

Позвольте подготовить тестовый файл:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{ Window, UserDefinedFunction}

import java.sql.Date
import java.time.LocalDate

val spark: SparkSession = ...

// Out test data set
val data: Seq[(Int, Date, Int)] = Seq(
  (1, Date.valueOf(LocalDate.parse("2014-01-03")), 10),
  (1, Date.valueOf(LocalDate.parse("2014-01-04")), 5),
  (1, Date.valueOf(LocalDate.parse("2014-01-05")), 15),
  (1, Date.valueOf(LocalDate.parse("2014-01-06")), 20),
  (2, Date.valueOf(LocalDate.parse("2014-02-10")), 100),
  (2, Date.valueOf(LocalDate.parse("2014-02-11")), 500),
  (2, Date.valueOf(LocalDate.parse("2014-02-15")), 1500)
)

// Create dataframe
val df: DataFrame = spark.createDataFrame(data)
  .toDF("id", "date", "value")
df.show()
//+---+----------+-----+
//| id|      date|value|
//+---+----------+-----+
//|  1|2014-01-03|   10|
//|  1|2014-01-04|    5|
//|  1|2014-01-05|   15|
//|  1|2014-01-06|   20|
//|  2|2014-02-10|  100|
//|  2|2014-02-11|  500|
//|  2|2014-02-15| 1500|
//+---+----------+-----+

Использовать UDF

// Group by id and aggregate date and value to new column date_value
val grouped = df.groupBy(col("id"))
  .agg(collect_list(struct("date", "value")) as "date_value")
grouped.show()
grouped.printSchema()
// +---+--------------------+
// | id|          date_value|
// +---+--------------------+
// |  1|[[2014-01-03,10],...|
// |  2|[[2014-02-10,100]...|
// +---+--------------------+

// udf to extract data from Row, sort by needed column (date) and return value
val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => {
  rows.map { case Row(date: Date, value: Int) => (date, value) }
    .sortBy { case (date, value) => date }
    .map { case (date, value) => value }
})

// Select id and value_list
val r1 = grouped.select(col("id"), sortUdf(col("date_value")).alias("value_list"))
r1.show()
// +---+----------------+
// | id|      value_list|
// +---+----------------+
// |  1| [10, 5, 15, 20]|
// |  2|[100, 500, 1500]|
// +---+----------------+

Использовать окно

val window = Window.partitionBy(col("id")).orderBy(col("date"))
val sortedDf = df.withColumn("values_sorted_by_date", collect_list("value").over(window))
sortedDf.show()
//+---+----------+-----+---------------------+
//| id|      date|value|values_sorted_by_date|
//+---+----------+-----+---------------------+
//|  1|2014-01-03|   10|                 [10]|
//|  1|2014-01-04|    5|              [10, 5]|
//|  1|2014-01-05|   15|          [10, 5, 15]|
//|  1|2014-01-06|   20|      [10, 5, 15, 20]|
//|  2|2014-02-10|  100|                [100]|
//|  2|2014-02-11|  500|           [100, 500]|
//|  2|2014-02-15| 1500|     [100, 500, 1500]|
//+---+----------+-----+---------------------+

val r2 = sortedDf.groupBy(col("id"))
  .agg(max("values_sorted_by_date").as("value_list")) 
r2.show()
//+---+----------------+
//| id|      value_list|
//+---+----------------+
//|  1| [10, 5, 15, 20]|
//|  2|[100, 500, 1500]|
//+---+----------------+

Ответ 4

Чтобы убедиться, что сортировка выполняется для каждого идентификатора, мы можем использовать sortWithinPartitions:

from pyspark.sql import functions as F
ordered_df = (
    input_df
        .repartition(input_df.id)
        .sortWithinPartitions(['date'])


)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

Ответ 5

В дополнение к тому, что сказал ShadyStego, я тестировал использование sortWithinPartitions и GroupBy в Spark, обнаружив, что он работает намного лучше, чем функции Window или UDF. Тем не менее, при использовании этого метода существует проблема с неправильным порядком раз для каждого раздела, но ее можно легко решить. Я показываю это здесь Spark (pySpark) groupBy неправильно упорядочивает первый элемент в collect_list.

Этот метод особенно полезен для больших DataFrames, но может потребоваться большое количество разделов, если у вас недостаточно памяти драйвера.