如何透视DataFrame?
我开始使用Spark Dataframes,我需要能够转换数据以创build具有多行的1列中的多个列。 在Scalding中有内置的function,我相信Python中的Pandas,但是我找不到新的Spark Dataframe。
我认为我可以编写自定义函数,但是我不知道如何开始,尤其是因为我是一个Spark新手。 我有人知道如何做到这一点内置的function或如何写在斯卡拉的东西的build议,不胜感激。
正如 @ user2000823 所提到的那样, Spark从版本1.6开始提供了pivot
函数。 一般语法如下所示:
df .groupBy(grouping_columns) .pivot(pivot_column, [values]) .agg(aggregate_expressions)
使用nycflights13
和spark-csv
用法示例:
Python :
from pyspark.sql.functions import avg flights = (sqlContext .read .format("com.databricks.spark.csv") .options(inferSchema="true", header="true") .load("flights.csv") .na.drop()) flights.registerTempTable("flights") sqlContext.cacheTable("flights") gexprs = ("origin", "dest", "carrier") aggexpr = avg("arr_delay") flights.count() ## 336776 %timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count() ## 10 loops, best of 3: 1.03 s per loop
斯卡拉 :
val flights = sqlContext .read .format("com.databricks.spark.csv") .options(Map("inferSchema" -> "true", "header" -> "true")) .load("flights.csv") flights .groupBy($"origin", $"dest", $"carrier") .pivot("hour") .agg(avg($"arr_delay"))
R :
library(magrittr) flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE) flights %>% groupBy("origin", "dest", "carrier") %>% pivot("hour") %>% agg(avg(column("arr_delay")))
性能考虑 :
一般来说,旋转是一个昂贵的操作。
-
如果你可以尝试提供
values
列表:vs = list(range(25)) %timeit -n10 flights.groupBy(*gexprs ).pivot("hour", vs).agg(aggexpr).count() ## 10 loops, best of 3: 392 ms per loop
-
在某些情况下,certificate有利于
repartition
和/或预先汇总数据 -
如果只是进行整形,可以使用Pyspark Dataframe上的
first
: 透视string列
我通过编写for循环来dynamic创build一个SQL查询来克服这个问题。 说我有:
id tag value 1 US 50 1 UK 100 1 Can 125 2 US 75 2 UK 150 2 Can 175
而且我要:
id US UK Can 1 50 100 125 2 75 150 175
我可以用我想要的值来创build一个列表,然后创build一个包含我需要的SQL查询的string。
val countries = List("US", "UK", "Can") val numCountries = countries.length - 1 var query = "select *, " for (i <- 0 to numCountries-1) { query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", " } query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable" myDataFrame.registerTempTable("myTable") val myDF1 = sqlContext.sql(query)
我可以创build类似的查询,然后做聚合。 不是一个非常优雅的解决scheme,但它的工作原理和灵活的值的任何列表,也可以作为parameter passing时调用您的代码。
一个pivot操作符已经被添加到Spark数据框API中,并且是Spark 1.6的一部分。
有关详细信息,请参阅https://github.com/apache/spark/pull/7841 。
我已经通过以下步骤解决了使用数据框的类似问题:
为所有国家创build列,值为“值”:
import org.apache.spark.sql.functions._ val countries = List("US", "UK", "Can") val countryValue = udf{(countryToCheck: String, countryInRow: String, value: Long) => if(countryToCheck == countryInRow) value else 0 } val countryFuncs = countries.map{country => (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), df("tag"), df("value"))) } val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value")
你的数据框“dfWithCountries”将如下所示:
+--+--+---+---+ |id|US| UK|Can| +--+--+---+---+ | 1|50| 0| 0| | 1| 0|100| 0| | 1| 0| 0|125| | 2|75| 0| 0| | 2| 0|150| 0| | 2| 0| 0|175| +--+--+---+---+
现在,您可以将所有的值汇总在一起:
dfWithCountries.groupBy("id").sum(countries: _*).show
结果:
+--+-------+-------+--------+ |id|SUM(US)|SUM(UK)|SUM(Can)| +--+-------+-------+--------+ | 1| 50| 100| 125| | 2| 75| 150| 175| +--+-------+-------+--------+
这不是一个非常优雅的解决scheme。 我必须创build一个函数链来添加所有的列。 另外,如果我有很多国家的话,我会把我的临时数据集扩大到很多很多的零。
最初我采用了Al M的解决scheme。 后来采取了同样的想法,并重写了这个function作为转置function。
此方法使用键和值列将任何df行转换为任何数据格式的列
inputcsv
id,tag,value 1,US,50a 1,UK,100 1,Can,125 2,US,75 2,UK,150 2,Can,175
输出中
+--+---+---+---+ |id| UK| US|Can| +--+---+---+---+ | 2|150| 75|175| | 1|100|50a|125| +--+---+---+---+
转置方法:
def transpose(hc : HiveContext , df: DataFrame,compositeId: List[String], key: String, value: String) = { val distinctCols = df.select(key).distinct.map { r => r(0) }.collect().toList val rdd = df.map { row => (compositeId.collect { case id => row.getAs(id).asInstanceOf[Any] }, scala.collection.mutable.Map(row.getAs(key).asInstanceOf[Any] -> row.getAs(value).asInstanceOf[Any])) } val pairRdd = rdd.reduceByKey(_ ++ _) val rowRdd = pairRdd.map(r => dynamicRow(r, distinctCols)) hc.createDataFrame(rowRdd, getSchema(df.schema, compositeId, (key, distinctCols))) } private def dynamicRow(r: (List[Any], scala.collection.mutable.Map[Any, Any]), colNames: List[Any]) = { val cols = colNames.collect { case col => r._2.getOrElse(col.toString(), null) } val array = r._1 ++ cols Row(array: _*) } private def getSchema(srcSchema: StructType, idCols: List[String], distinctCols: (String, List[Any])): StructType = { val idSchema = idCols.map { idCol => srcSchema.apply(idCol) } val colSchema = srcSchema.apply(distinctCols._1) val colsSchema = distinctCols._2.map { col => StructField(col.asInstanceOf[String], colSchema.dataType, colSchema.nullable) } StructType(idSchema ++ colsSchema) }
主要的片段
import java.util.Date import org.apache.spark.SparkConf import org.apache.spark.SparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.StructField ... ... def main(args: Array[String]): Unit = { val sc = new SparkContext(conf) val sqlContext = new org.apache.spark.sql.SQLContext(sc) val dfdata1 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true") .load("data.csv") dfdata1.show() val dfOutput = transpose(new HiveContext(sc), dfdata1, List("id"), "tag", "value") dfOutput.show }