Spark中的分层采样

我有数据集,其中包含用户和购买数据。 这里是一个例子,其中第一个元素是userId,第二个是productId,第三个是布尔值。

(2147481832,23355149,1) (2147481832,973010692,1) (2147481832,2134870842,1) (2147481832,541023347,1) (2147481832,1682206630,1) (2147481832,1138211459,1) (2147481832,852202566,1) (2147481832,201375938,1) (2147481832,486538879,1) (2147481832,919187908,1) ... 

我想确保每个用户数据只占80%,build立一个RDD,剩下的20%,build立另一个RDD。 让电话训练和testing。 我想远离使用groupBy开始,因为它可以创build内存问题,因为数据集很大。 最好的办法是做到这一点?

我可以做以下,但这不会给每个用户的80%。

 val percentData = data.map(x => ((math.random * 100).toInt, x._1. x._2, x._3) val train = percentData.filter(x => x._1 < 80).values.repartition(10).cache() 

霍尔登的答案中有一种可能性,这是另一种可能性:

您可以使用PairRDDFunctions类中的sampleByKeyExact转换。

sampleByKeyExact (布尔withReplacement,scala.collection.Map分数,长种子)返回这个RDD的一个子集,通过密钥(通过分层采样),包含准确的math.ceil(numItems * samplingRate)为每个层(具有相同的密钥)。

这就是我将要做的事情:

考虑以下列表:

 val list = List((2147481832,23355149,1),(2147481832,973010692,1),(2147481832,2134870842,1),(2147481832,541023347,1),(2147481832,1682206630,1),(2147481832,1138211459,1),(2147481832,852202566,1),(2147481832,201375938,1),(2147481832,486538879,1),(2147481832,919187908,1),(214748183,919187908,1),(214748183,91187908,1)) 

我会创build一个RDD Pair,将所有用户映射为键:

 val data = sc.parallelize(list.toSeq).map(x => (x._1,(x._2,x._3))) 

然后,我将为每个键设置fractions ,如下所示,因为您已经注意到sampleByKeyExact为每个键获取分数的Map:

 val fractions = data.map(_._1).distinct.map(x => (x,0.8)).collectAsMap 

我在这里所做的是,实际上,映射的关键find不同的,然后将每个关键的关联分数等于0.8,然后我收集整体作为一个地图。

现在,我要做的就是:

 import org.apache.spark.rdd.PairRDDFunctions val sampleData = data.sampleByKeyExact(false, fractions, 2L) 

要么

 val sampleData = data.sampleByKeyExact(withReplacement = false, fractions = fractions,seed = 2L) 

您可以检查您的钥匙或数据或数据样本的数量:

 scala > data.count // [...] // res10: Long = 12 scala > sampleData.count // [...] // res11: Long = 10 

编辑:我决定添加一个部分来执行DataFrame的分层采样。

所以我们会考虑上面例子中的相同数据( list )。

 val df = list.toDF("keyColumn","value1","value2") df.show // +----------+----------+------+ // | keyColumn| value1|value2| // +----------+----------+------+ // |2147481832| 23355149| 1| // |2147481832| 973010692| 1| // |2147481832|2134870842| 1| // |2147481832| 541023347| 1| // |2147481832|1682206630| 1| // |2147481832|1138211459| 1| // |2147481832| 852202566| 1| // |2147481832| 201375938| 1| // |2147481832| 486538879| 1| // |2147481832| 919187908| 1| // | 214748183| 919187908| 1| // | 214748183| 91187908| 1| // +----------+----------+------+ 

我们将需要底层的RDD来做到这一点,我们通过定义我们的密钥成为第一列来创buildRDD中元素的元组:

 val data: RDD[(Int, Row)] = df.rdd.keyBy(_.getInt(0)) val fractions: Map[Int, Double] = data.map(_._1) .distinct .map(x => (x, 0.8)) .collectAsMap val sampleData: RDD[Row] = data.sampleByKeyExact(withReplacement = false, fractions, 2L) .values val sampleDataDF: DataFrame = spark.createDataFrame(sampleData, df.schema) // you can use sqlContext.createDataFrame(...) instead for spark 1.6) 

您现在可以检查您的密钥或df或数据示例的计数:

 scala > df.count // [...] // res9: Long = 12 scala > sampleDataDF.count // [...] // res10: Long = 10 

编辑2:自Spark 1.5.0以来,您可以使用DataFrameStatFunctions.sampleBy方法:

 df.stat.sampleBy("keyColumn", fractions, seed) 

像这样的东西可能非常适合像“Blink DB”,但让我们看看这个问题。 有两种方法来解释你所问的是:

1)你想要80%的用户,而且你想得到所有的数据。 2)你想要每个用户数据的80%

对于#1你可以做一个地图来获得用户id,调用不同的,然后对其中的80%进行采样(你可能想在MLUtils或者BernoulliCellSampler查看kFold )。 然后,您可以将input数据过滤为所需的一组ID。

对于#2你可以看看BernoulliCellSampler ,直接应用它。