Кэширование промежуточных результатов в трубопроводе Spark ML

В последнее время я планирую перенести мой автономный код ML Python на искровой. Труба ML в spark.ml оказывается весьма удобной, с упрощенным API для объединения этапов алгоритма и поиска по гиперпараметрам.

Тем не менее, я нашел поддержку одной важной функции, которая неясна в существующих документах: кэширование промежуточных результатов. Важность этой функции возникает, когда трубопровод включает стадии интенсивного вычисления.

Например, в моем случае я использую огромную разреженную матрицу для выполнения нескольких скользящих средних по данным временных рядов для формирования входных функций. Структура матрицы определяется некоторым гиперпараметром. Этот шаг оказывается узким местом для всего конвейера, потому что я должен построить матрицу во время выполнения.

Во время поиска параметров у меня обычно есть другие параметры, кроме этого "структурного параметра". Поэтому, если я могу повторно использовать огромную матрицу, когда "структурный параметр" не изменяется, я могу сэкономить массу времени. По этой причине я намеренно сформировал свой код для кэширования и повторного использования этих промежуточных результатов.

Итак, мой вопрос: Может ли Spark ML конвейер обрабатывать промежуточное кэширование автоматически? Или мне нужно вручную создать код для этого? Если да, есть ли какая-нибудь лучшая практика, чтобы учиться?

P.S. Я просмотрел официальный документ и некоторые другие материалы, но ни один из них, похоже, не обсуждает эту тему.

Ответ 1

Итак, я столкнулся с той же проблемой, и так, как я решил, я реализовал свой собственный PipelineStage, который кэширует входной DataSet и возвращает его как есть.

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

class Cacher(val uid: String) extends Transformer with DefaultParamsWritable {
  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF.cache()

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = schema

  def this() = this(Identifiable.randomUID("CacherTransformer"))
}

Чтобы использовать его, вы сделали бы что-то вроде этого:

new Pipeline().setStages(Array(stage1, new Cacher(), stage2))