Spark:使用单键RDDjoin2元组密钥RDD的最佳策略是什么?

我有两个我想join的RDD,他们看起来像这样:

val rdd1:RDD[(T,U)] val rdd2:RDD[((T,W), V)] 

恰好是rdd1是唯一的,并且rdd1的元组键值是唯一的。 我想join两个数据集,以便获得以下rdd:

 val rdd_joined:RDD[((T,W), (U,V))] 

什么是最有效的方法来实现呢? 以下是我想到的一些想法。

选项1:

 val m = rdd1.collectAsMap val rdd_joined = rdd2.map({case ((t,w), u) => ((t,w), u, m.get(t))}) 

选项2:

 val distinct_w = rdd2.map({case ((t,w), u) => w}).distinct val rdd_joined = rdd1.cartesian(distinct_w).join(rdd2) 

选项1将收集所有的数据,掌握,对吧? 所以,如果rdd1很大(在我的情况下它相对较大,虽然比rdd2小一个数量级)似乎不是一个好的select。 选项2做了一个丑陋的独特的笛卡尔产品,这看起来效率很低。 另一种可能性是我想到的(但还没有尝试过),就是做选项1并播放地图,不过最好以“智能”的方式进行广播,以便地图的键与rdd2键。

有没有人遇到过这种情况? 我很乐意有你的想法。

谢谢!

一种select是通过将rdd1收集到驱动器并将其广播给所有映射器来执行广播连接; 这样可以避免昂贵的大型rdd2 RDD洗牌:

 val rdd1 = sc.parallelize(Seq((1, "A"), (2, "B"), (3, "C"))) val rdd2 = sc.parallelize(Seq(((1, "Z"), 111), ((1, "ZZ"), 111), ((2, "Y"), 222), ((3, "X"), 333))) val rdd1Broadcast = sc.broadcast(rdd1.collectAsMap()) val joined = rdd2.mapPartitions({ iter => val m = rdd1Broadcast.value for { ((t, w), u) <- iter if m.contains(t) } yield ((t, w), (u, m.get(t).get)) }, preservesPartitioning = true) 

preservesPartitioning = true告诉Spark这个map函数不会修改rdd2的键; 这将允许Spark避免为基于(t, w)键的任何后续操作重新分区rdd2

这种广播可能是低效的,因为它涉及到司机的沟通瓶颈。 原则上,可以在不涉及驾驶员的情况下将一个RDD广播给另一个。 我有一个这样的原型,我想概括并添加到Spark。

另一个select是重新映射rdd2的键并使用Spark rdd2方法; 这将涉及rdd2 (也可能是rdd1 )的全面洗牌:

 rdd1.join(rdd2.map { case ((t, w), u) => (t, (w, u)) }).map { case (t, (v, (w, u))) => ((t, w), (u, v)) }.collect() 

在我的示例input中,这两个方法产生相同的结果:

 res1: Array[((Int, java.lang.String), (Int, java.lang.String))] = Array(((1,Z),(111,A)), ((1,ZZ),(111,A)), ((2,Y),(222,B)), ((3,X),(333,C))) 

第三种select是重构rdd2 ,使得t是它的关键,然后执行上面的连接。

另一种方法是创build一个自定义的分区,然后使用zipPartition来join你的RDD。

 import org.apache.spark.HashPartitioner class RDD2Partitioner(partitions: Int) extends HashPartitioner(partitions) { override def getPartition(key: Any): Int = key match { case k: Tuple2[Int, String] => super.getPartition(k._1) case _ => super.getPartition(key) } } val numSplits = 8 val rdd1 = sc.parallelize(Seq((1, "A"), (2, "B"), (3, "C"))).partitionBy(new HashPartitioner(numSplits)) val rdd2 = sc.parallelize(Seq(((1, "Z"), 111), ((1, "ZZ"), 111), ((1, "AA"), 123), ((2, "Y"), 222), ((3, "X"), 333))).partitionBy(new RDD2Partitioner(numSplits)) val result = rdd2.zipPartitions(rdd1)( (iter2, iter1) => { val m = iter1.toMap for { ((t: Int, w), u) <- iter2 if m.contains(t) } yield ((t, w), (u, m.get(t).get)) } ).partitionBy(new HashPartitioner(numSplits)) result.glom.collect