直接上代码:
package part
import org.apache.log4j.{
Level, Logger}
import org.apache.spark.{
Partitioner, SparkConf, SparkContext}
object Spark_Part {
def main(args: Array[String]): Unit = {
//屏蔽日志信息
Logger.getLogger("org").setLevel(Level.ERROR)
//创建sparkconf
val conf = new SparkConf().setMaster("local[2]").setAppName("wc")
//创建spark程序入口
val sc = new SparkContext(conf)
//创建集合对象
val list = List(("nba","************"),("cba","************"),
("wnba","************"),("nba","************"))
//将集合对象写进RDD里 并创建三个分区
val inputRDD = sc.makeRDD(list,3)
//将新的RDD使用partitionby方法自定义分区
val value = inputRDD.partitionBy(new Mypartitioner)
//保存到文件里
value.saveAsTextFile("output")
sc.stop()
}
/**
* 第一 : 自定义分区器
* 第二 : 重写方法
*/
class Mypartitioner extends Partitioner{
//分区数量
override def numPartitions: Int = 3
//根据数据的key值 返回数据所在的分区索引 (从0开始)
override def getPartition(key: Any): Int = {
//方式一 : 用if做判断
// if(key == "nba"){
// 0
// }else if (key == "cba"){
// 1
// }else{
// 2
// }
//方式二 : 用模式匹配
//如果是nba 放到0号分区,如果是cba 放到1号分区,如果是其他,放到2号分区
key match {
case "nba" => 0
case "cba" => 1
case _ => 2
}
}
}
}