Фон
Мой первоначальный вопрос: почему использование функции DecisionTreeModel.predict
внутри карты вызывает исключение? и связано с Как генерировать кортежи (оригинальная маска, предсказанная метка) на Spark с MLlib?
Когда мы используем Scala API рекомендуемый способ получения прогнозов для RDD[LabeledPoint]
с помощью DecisionTreeModel
, это просто сопоставить более RDD
:
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
К сожалению, подобный подход в PySpark работает не так хорошо:
labelsAndPredictions = testData.map(
lambda lp: (lp.label, model.predict(lp.features))
labelsAndPredictions.first()
Исключение: похоже, что вы пытаетесь ссылаться на SparkContext из широковещательной переменной, действия или трансформерации. SparkContext можно использовать только в драйвере, а не в коде, который он запускает на рабочих. Для получения дополнительной информации см. SPARK-5063.
Вместо официальная документация рекомендует что-то вроде этого:
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
Итак, что здесь происходит? Здесь нет переменной широковещания, а Scala API определяет predict
следующим образом:
/**
* Predict values for a single data point using the model trained.
*
* @param features array representing a single data point
* @return Double prediction from the trained model
*/
def predict(features: Vector): Double = {
topNode.predict(features)
}
/**
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
* @return RDD of predictions for each of the given data points
*/
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
}
поэтому, по крайней мере, на первый взгляд, вызов от действия или преобразования не является проблемой, поскольку предсказание кажется локальной операцией.
Объяснение
После некоторого рытья я понял, что источником проблемы является метод JavaModelWrapper.call
, вызванный из DecisionTreeModel.predict. Это доступ SparkContext
, который требуется для вызова функции Java:
callJavaFunc(self._sc, getattr(self._java_model, name), *a)
Вопрос
В случае DecisionTreeModel.predict
существует рекомендуемое обходное решение, и весь требуемый код уже является частью API Scala, но есть ли элегантный способ справиться с такой проблемой вообще?
Единственные решения, о которых я сейчас думаю, - это тяжелый вес:
- подталкивая все к JVM либо путем расширения классов Spark через Implicit Conversions, либо добавления каких-то оболочек
- напрямую с помощью шлюза Py4j.