Как получить другие столбцы при использовании Spark DataFrame groupby?

когда я использую группу DataFrame следующим образом:

df.groupBy(df("age")).agg(Map("id"->"count"))

Я получаю только DataFrame с столбцами "age" и "count (id)", но в df есть много других столбцов, таких как "name".

В целом, я хочу получить результат как в MySQL,

"выберите имя, возраст, счетчик (id) из группы df по возрасту"

Что делать, если вы используете groupby в Spark?

Ответ 1

Короче говоря, вам нужно объединить результаты с исходной таблицей. Spark SQL следует тому же пред-SQL: 1999 как большинство основных баз данных (PostgreSQL, Oracle, MS SQL Server), который не позволяет добавлять дополнительные столбцы в запросы агрегации.

Так как для агрегаций, таких как результаты подсчета, недостаточно четко определены, и поведение, как правило, меняется в системах, поддерживающих этот тип запросов, вы можете просто добавить дополнительные столбцы, используя произвольный агрегат, например first или last.

В некоторых случаях вы можете заменить agg на select на функции окна и последующие where, но в зависимости от контекста это может быть довольно дорого.

Ответ 2

Один из способов получить все столбцы после выполнения groupBy - использовать функцию соединения.

feature_group = ['name', 'age']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)

data_joined теперь будет иметь все столбцы, включая значения count.

Ответ 3

Может быть, это решение будет полезным.

from pyspark.sql import SQLContext
from pyspark import SparkContext, SparkConf
from pyspark.sql import functions as F
from pyspark.sql import Window

    name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
                 (105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
                 (109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]

age_w = Window.partitionBy("age")
name_age_df = sqlContext.createDataFrame(name_list, ['id', 'name', 'age'])

name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()

Выход:

+---+----+---+-----+
| id|name|age|count|
+---+----+---+-----+
|109| rst| 26|    1|
|113| tuv| 29|    1|
|110| tuv| 27|    1|
|106| klm| 19|    1|
|103| efg| 22|    1|
|104| ghi| 21|    1|
|105| ijk| 20|    1|
|112| rst| 28|    1|
|101| abc| 24|    2|
|102| cde| 24|    2|
|107| mno| 18|    3|
|111| pqr| 18|    3|
|108| pqr| 18|    3|
+---+----+---+-----+

Ответ 4

Агрегатные функции уменьшают значения строк для указанных столбцов в группе. Если вы хотите сохранить другие значения строк, вам необходимо реализовать логику сокращения, которая определяет строку, из которой происходит каждое значение. Например, сохраните все значения первого ряда с максимальным значением возраста. Для этого вы можете использовать UDAF (пользовательскую статистическую функцию), чтобы уменьшить количество строк в группе.

import org.apache.spark.sql._
import org.apache.spark.sql.functions._


object AggregateKeepingRowJob {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1L, "Moe",  "Slap",  2.0, 18),
      (2L, "Larry",  "Spank",  3.0, 15),
      (3L, "Curly",  "Twist", 5.0, 15),
      (4L, "Laurel", "Whimper", 3.0, 15),
      (5L, "Hardy", "Laugh", 6.0, 15),
      (6L, "Charley",  "Ignore",   5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    rawDf.show(false)
    rawDf.printSchema

    val maxAgeUdaf = new KeepRowWithMaxAge

    val aggDf = rawDf
      .groupBy("age")
      .agg(
        count("id"),
        max(col("money")),
        maxAgeUdaf(
          col("id"),
          col("name"),
          col("requisite"),
          col("money"),
          col("age")).as("KeepRowWithMaxAge")
      )

    aggDf.printSchema
    aggDf.show(false)

  }


}

UDAF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
// This is the input fields for your aggregate function.
override def inputSchema: org.apache.spark.sql.types.StructType =
  StructType(
    StructField("store", StringType) ::
    StructField("prod", StringType) ::
    StructField("amt", DoubleType) ::
    StructField("units", IntegerType) :: Nil
  )

// This is the internal fields you keep for computing your aggregate.
override def bufferSchema: StructType = StructType(
  StructField("store", StringType) ::
  StructField("prod", StringType) ::
  StructField("amt", DoubleType) ::
  StructField("units", IntegerType) :: Nil
)


// This is the output type of your aggregation function.
override def dataType: DataType =
  StructType((Array(
    StructField("store", StringType),
    StructField("prod", StringType),
    StructField("amt", DoubleType),
    StructField("units", IntegerType)
  )))

override def deterministic: Boolean = true

// This is the initial value for your buffer schema.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
  buffer(0) = ""
  buffer(1) = ""
  buffer(2) = 0.0
  buffer(3) = 0
}

// This is how to update your buffer schema given an input.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

  val amt = buffer.getAs[Double](2)
  val candidateAmt = input.getAs[Double](2)

  amt match {
    case a if a < candidateAmt =>
      buffer(0) = input.getAs[String](0)
      buffer(1) = input.getAs[String](1)
      buffer(2) = input.getAs[Double](2)
      buffer(3) = input.getAs[Int](3)
    case _ =>
  }
}

// This is how to merge two objects with the bufferSchema type.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

  buffer1(0) = buffer2.getAs[String](0)
  buffer1(1) = buffer2.getAs[String](1)
  buffer1(2) = buffer2.getAs[Double](2)
  buffer1(3) = buffer2.getAs[Int](3)
}

// This is where you output the final value, given the final value of your bufferSchema.
override def evaluate(buffer: Row): Any = {
  buffer
}
}

Ответ 5

Вот пример, с которым я столкнулся в искрового мастерской

val populationDF = spark.read
                .option("infer-schema", "true")
                .option("header", "true")
                .format("csv").load("file:///databricks/driver/population.csv")
                .select('name, regexp_replace(col("population"), "\\s", "").cast("integer").as("population"))

val maxPopulationDF = populationDF.agg(max('population).as("populationmax"))

Чтобы получить другие столбцы, я делаю простое соединение между исходным DF и агрегированным

populationDF.join(maxPopulationDF,populationDF.col("population") === maxPopulationDF.col("populationmax")).select('name, 'populationmax).show()

Ответ 6

Вы должны помнить, что агрегатные функции сокращают строки, и поэтому вам нужно указать, какое из названий строк вы хотите использовать с помощью сокращающей функции. Если вы хотите сохранить все строки группы (предупреждение! Это может привести к взрывам или перекосам), вы можете собрать их в виде списка. Затем вы можете использовать UDF (пользовательскую функцию), чтобы уменьшить их по вашим критериям, в моем примере деньги. А затем разверните столбцы из одной уменьшенной строки с помощью другой пользовательской функции. Для целей этого ответа я предполагаю, что вы хотите сохранить имя человека, у которого больше всего денег.

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType

import scala.collection.mutable


object TestJob3 {

def main (args: Array[String]): Unit = {

val sparkSession = SparkSession
  .builder()
  .appName(this.getClass.getName.replace("$", ""))
  .master("local")
  .getOrCreate()

val sc = sparkSession.sparkContext

import sparkSession.sqlContext.implicits._

val rawDf = Seq(
  (1, "Moe",  "Slap",  2.0, 18),
  (2, "Larry",  "Spank",  3.0, 15),
  (3, "Curly",  "Twist", 5.0, 15),
  (4, "Laurel", "Whimper", 3.0, 9),
  (5, "Hardy", "Laugh", 6.0, 18),
  (6, "Charley",  "Ignore",   5.0, 5)
).toDF("id", "name", "requisite", "money", "age")

rawDf.show(false)
rawDf.printSchema

val rawSchema = rawDf.schema

val fUdf = udf(reduceByMoney, rawSchema)

val nameUdf = udf(extractName, StringType)

val aggDf = rawDf
  .groupBy("age")
  .agg(
    count(struct("*")).as("count"),
    max(col("money")),
    collect_list(struct("*")).as("horizontal")
  )
  .withColumn("short", fUdf($"horizontal"))
  .withColumn("name", nameUdf($"short"))
  .drop("horizontal")

aggDf.printSchema

aggDf.show(false)

}

def reduceByMoney= (x: Any) => {

val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

val red = d.reduce((r1, r2) => {

  val money1 = r1.getAs[Double]("money")
  val money2 = r2.getAs[Double]("money")

  val r3 = money1 match {
    case a if a >= money2 =>
      r1
    case _ =>
      r2
  }

  r3
})

red
}

def extractName = (x: Any) => {

  val d = x.asInstanceOf[GenericRowWithSchema]

  d.getAs[String]("name")
}
}

вот вывод

+---+-----+----------+----------------------------+-------+
|age|count|max(money)|short                       |name   |
+---+-----+----------+----------------------------+-------+
|5  |1    |5.0       |[6, Charley, Ignore, 5.0, 5]|Charley|
|15 |2    |5.0       |[3, Curly, Twist, 5.0, 15]  |Curly  |
|9  |1    |3.0       |[4, Laurel, Whimper, 3.0, 9]|Laurel |
|18 |2    |6.0       |[5, Hardy, Laugh, 6.0, 18]  |Hardy  |
+---+-----+----------+----------------------------+-------+

Ответ 7

Вы можете сделать следующее:

Пример данных:

name    age id
abc     24  1001
cde     24  1002
efg     22  1003
ghi     21  1004
ijk     20  1005
klm     19  1006
mno     18  1007
pqr     18  1008
rst     26  1009
tuv     27  1010
pqr     18  1012
rst     28  1013
tuv     29  1011
df.select("name","age","id").groupBy("name","age").count().show();

Вывод:

    +----+---+-----+
    |name|age|count|
    +----+---+-----+
    | efg| 22|    1|
    | tuv| 29|    1|
    | rst| 28|    1|
    | klm| 19|    1|
    | pqr| 18|    2|
    | cde| 24|    1|
    | tuv| 27|    1|
    | ijk| 20|    1|
    | abc| 24|    1|
    | mno| 18|    1|
    | ghi| 21|    1|
    | rst| 26|    1|
    +----+---+-----+