在Spark RDD和/或Spark DataFrame中重塑/旋转数据

我有以下格式的数据(RDD或Spark DataFrame):

from pyspark.sql import SQLContext sqlContext = SQLContext(sc) rdd = sc.parallelize([('X01',41,'US',3), ('X01',41,'UK',1), ('X01',41,'CA',2), ('X02',72,'US',4), ('X02',72,'UK',6), ('X02',72,'CA',7), ('X02',72,'XX',8)]) # convert to a Spark DataFrame schema = StructType([StructField('ID', StringType(), True), StructField('Age', IntegerType(), True), StructField('Country', StringType(), True), StructField('Score', IntegerType(), True)]) df = sqlContext.createDataFrame(rdd, schema) 

我想做的是“重塑”数据,将国家(特别是美国,英国和加拿大)的某些行转换为列:

 ID Age US UK CA 'X01' 41 3 1 2 'X02' 72 4 6 7 

基本上,我需要沿着Python的pivot工作stream程的一些东西:

 categories = ['US', 'UK', 'CA'] new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', columns = 'Country', values = 'Score') 

我的数据集相当大,所以我不能真正collect()并将数据提取到内存中,以便在Python本身中进行重塑。 有没有办法将Python的.pivot()转换为可调用函数,同时映射RDD或Spark DataFrame? 任何帮助将不胜感激!

首先,这可能不是一个好主意,因为你没有得到任何额外的信息,但是你用一个固定的模式来约束自己(也就是说,你必须知道你期望有多less个国家,当然还有更多的国家手段改变代码)

话虽如此,这是一个SQL问题,如下所示。 但是假如你以为不是太“软件”(认真听说这个!!),那么你可以参考第一个解决scheme。

解决scheme1:

 def reshape(t): out = [] out.append(t[0]) out.append(t[1]) for v in brc.value: if t[2] == v: out.append(t[3]) else: out.append(0) return (out[0],out[1]),(out[2],out[3],out[4],out[5]) def cntryFilter(t): if t[2] in brc.value: return t else: pass def addtup(t1,t2): j=() for k,v in enumerate(t1): j=j+(t1[k]+t2[k],) return j def seq(tIntrm,tNext): return addtup(tIntrm,tNext) def comb(tP,tF): return addtup(tP,tF) countries = ['CA', 'UK', 'US', 'XX'] brc = sc.broadcast(countries) reshaped = calls.filter(cntryFilter).map(reshape) pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1) for i in pivot.collect(): print i 

现在,解决scheme2:当然,更好的SQL是适合这个的工具

 callRow = calls.map(lambda t: Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3])) callsDF = ssc.createDataFrame(callRow) callsDF.printSchema() callsDF.registerTempTable("calls") res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\ from (select userid,age,\ case when country='CA' then nbrCalls else 0 end ca,\ case when country='UK' then nbrCalls else 0 end uk,\ case when country='US' then nbrCalls else 0 end us,\ case when country='XX' then nbrCalls else 0 end xx \ from calls) x \ group by userid,age") res.show() 

数据设置:

 data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)] calls = sc.parallelize(data,1) countries = ['CA', 'UK', 'US', 'XX'] 

结果:

从第一个解决scheme

 (('X02', 72), (7, 6, 4, 8)) (('X01', 41), (2, 1, 3, 0)) 

从第二解决scheme:

 root |-- age: long (nullable = true) |-- country: string (nullable = true) |-- nbrCalls: long (nullable = true) |-- userid: string (nullable = true) userid age ca uk us xx X02 72 7 6 4 8 X01 41 2 1 3 0 

请让我知道如果这个工程,或不:)

最好的Ayan

由于Spark 1.6可以在GroupedData上使用pivot函数,并提供聚合expression式。

 pivoted = (df .groupBy("ID", "Age") .pivot( "Country", ['US', 'UK', 'CA']) # Optional list of levels .sum("Score")) # alternatively you can use .agg(expr)) pivoted.show() ## +---+---+---+---+---+ ## | ID|Age| US| UK| CA| ## +---+---+---+---+---+ ## |X01| 41| 3| 1| 2| ## |X02| 72| 4| 6| 7| ## +---+---+---+---+---+ 

级别可以省略,但如果提供可以提高性能和作为一个内部filter。

这个方法还是比较慢的,但是肯定会在JVM和Python之间手动传递数据。

这里有一个本地Spark方法,不硬连接列名。 它基于aggregateByKey ,并使用字典来收集每个键出现的列。 然后我们收集所有的列名称来创build最终的数据框。 [之前的版本在为每个logging发出一个字典后使用了jsonRDD,但这样做效率更高。]限制到特定的列列表,或者排除像XX这样的列将是一个简单的修改。

即使在相当大的桌子上,performance也不错。 我正在使用一种变化来计算每个ID的每个可变数量的事件发生的次数,每个事件types生成一个列。 代码基本上是相同的,只不过在seqFn使用了collections.Counter而不是dict来计算出现次数。

 from pyspark.sql.types import * rdd = sc.parallelize([('X01',41,'US',3), ('X01',41,'UK',1), ('X01',41,'CA',2), ('X02',72,'US',4), ('X02',72,'UK',6), ('X02',72,'CA',7), ('X02',72,'XX',8)]) schema = StructType([StructField('ID', StringType(), True), StructField('Age', IntegerType(), True), StructField('Country', StringType(), True), StructField('Score', IntegerType(), True)]) df = sqlCtx.createDataFrame(rdd, schema) def seqPivot(u, v): if not u: u = {} u[v.Country] = v.Score return u def cmbPivot(u1, u2): u1.update(u2) return u1 pivot = ( df .rdd .keyBy(lambda row: row.ID) .aggregateByKey(None, seqPivot, cmbPivot) ) columns = ( pivot .values() .map(lambda u: set(u.keys())) .reduce(lambda s,t: s.union(t)) ) result = sqlCtx.createDataFrame( pivot .map(lambda (k, u): [k] + [u.get(c) for c in columns]), schema=StructType( [StructField('ID', StringType())] + [StructField(c, IntegerType()) for c in columns] ) ) result.show() 

生产:

 ID CA UK US XX X02 7 6 4 8 X01 2 1 3 null 

所以首先,我必须对你的RDD进行修正(与你的实际输出相匹配):

 rdd = sc.parallelize([('X01',41,'US',3), ('X01',41,'UK',1), ('X01',41,'CA',2), ('X02',72,'US',4), ('X02',72,'UK',6), ('X02',72,'CA',7), ('X02',72,'XX',8)]) 

一旦我做了这个改正,这个诀窍就是:

 df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age") .join( df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"), $"ID" === $"usID" and $"C1" === "US" ) .join( df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"), $"ID" === $"ukID" and $"C2" === "UK" ) .join( df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), $"ID" === $"caID" and $"C3" === "CA" ) .select($"ID",$"Age",$"US",$"UK",$"CA") 

当然,不像你的支点一样优雅。

只是对patricksurry非常有帮助的回答一些意见:

  • 列Age缺失,所以只需将u [“Age”] = v.Age添加到函数seqPivot
  • 事实certificate,在列的元素上的两个循环以不同的顺序给元素。 列的值是正确的,但不是它们的名字。 为了避免这种行为,只需排列列表。

这里是稍微修改的代码:

 from pyspark.sql.types import * rdd = sc.parallelize([('X01',41,'US',3), ('X01',41,'UK',1), ('X01',41,'CA',2), ('X02',72,'US',4), ('X02',72,'UK',6), ('X02',72,'CA',7), ('X02',72,'XX',8)]) schema = StructType([StructField('ID', StringType(), True), StructField('Age', IntegerType(), True), StructField('Country', StringType(), True), StructField('Score', IntegerType(), True)]) df = sqlCtx.createDataFrame(rdd, schema) # u is a dictionarie # v is a Row def seqPivot(u, v): if not u: u = {} u[v.Country] = v.Score # In the original posting the Age column was not specified u["Age"] = v.Age return u # u1 # u2 def cmbPivot(u1, u2): u1.update(u2) return u1 pivot = ( rdd .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2], Score=row[3])) .keyBy(lambda row: row.ID) .aggregateByKey(None, seqPivot, cmbPivot) ) columns = ( pivot .values() .map(lambda u: set(u.keys())) .reduce(lambda s,t: s.union(t)) ) columns_ord = sorted(columns) result = sqlCtx.createDataFrame( pivot .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]), schema=StructType( [StructField('ID', StringType())] + [StructField(c, IntegerType()) for c in columns_ord] ) ) print result.show() 

最后,输出应该是

 +---+---+---+---+---+----+ | ID|Age| CA| UK| US| XX| +---+---+---+---+---+----+ |X02| 72| 7| 6| 4| 8| |X01| 41| 2| 1| 3|null| +---+---+---+---+---+----+ 

在Hive中有一个JIRA,可以在本地执行此操作,而不需要为每个值声明一个巨大的CASE语句:

https://issues.apache.org/jira/browse/HIVE-3776

请将JIRA投票,以便尽早实施。 一旦在Hive SQL中,Spark通常不会缺less太多的东西,最终也会在Spark中实现。