MapReduce with MongoDB (Reading Data from and Writing Data to MongoDB)

Copyright notice: this is an original article by the author, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.
Original link: https://blog.csdn.net/qq_18505209/article/details/100627912
Requirement

Count the number of students in each age group of the students collection and write the result to the res collection.

Data

students collection: (screenshot of the source documents omitted)
res collection: (screenshot of the job's output omitted)
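
The original post showed the collection contents only as screenshots. As a stand-in, the following minimal sketch seeds equivalent test data with the MongoDB Java driver; the name and age fields match what the mapper below reads, while the concrete values and the SeedStudents class itself are made up for illustration.

SeedStudents.java (hypothetical):

package MapReduce07;

import com.mongodb.MongoClient;
import com.mongodb.client.MongoCollection;
import org.bson.Document;

import java.util.Arrays;

//Hypothetical helper: seeds mydb.students with a few sample documents
//so the job below has data to read. Names and ages are made up.
public class SeedStudents {
    public static void main(String[] args) {
        MongoClient client = new MongoClient("localhost", 27017);
        try {
            MongoCollection<Document> students =
                    client.getDatabase("mydb").getCollection("students");
            students.insertMany(Arrays.asList(
                    new Document("name", "tom").append("age", 18),
                    new Document("name", "jack").append("age", 18),
                    new Document("name", "lucy").append("age", 19),
                    new Document("name", "mike").append("age", 20)));
        } finally {
            client.close();
        }
    }
}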

Code

MongoDBTest.java:

package MapReduce07;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.log4j.BasicConfigurator;
import org.bson.Document;

import java.io.IOException;
import java.util.Map;
import java.util.TreeMap;

public class MongoDBTest {
    public static class MyMapper extends Mapper<LongWritable, Document, IntWritable, IntWritable>{
        @Override
        protected void map(LongWritable key, Document value, Context context) throws IOException, InterruptedException {
            //Skip documents that have no age field
            if(value.get("age") == null){
                return;
            }
            //The age field may arrive as Integer or Double; normalize it to int
            int age = Double.valueOf(value.get("age").toString()).intValue();
            context.write(new IntWritable(age), new IntWritable(1));
        }
    }
    public static class MyReducer extends Reducer<IntWritable, IntWritable, Document, NullWritable>{
        //Accumulates one entry per age group; written out once in cleanup()
        Document doc = new Document();
        @Override
        protected void reduce(IntWritable key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {
            int num = 0;
            for(IntWritable value:values){
                num+=value.get();
            }
            Map<String, Integer> map = new TreeMap<String, Integer>();
            map.put("age is " + key, num);
            doc.putAll(map);
        }
        @Override
        protected void cleanup(Context context) throws IOException, InterruptedException {
            //Write the accumulated result document once, at the end of the reduce task
            if(doc.isEmpty()){
                return;
            }
            context.write(doc, NullWritable.get());
        }
    }
    public static void main(String[] args) throws Exception {
        BasicConfigurator.configure();
        Configuration conf = new Configuration();
        //Input/output locations as pseudo-URIs of the form host://database.collection,
        //parsed by the custom formats below (the port is hard-coded to 27017)
        conf.set("input", "localhost://mydb.students");
        conf.set("output", "localhost://mydb.res");

        Job job = Job.getInstance(conf);
        job.setJarByClass(MongoDBTest.class);

        job.setMapperClass(MyMapper.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(IntWritable.class);

        job.setReducerClass(MyReducer.class);
        job.setOutputKeyClass(Document.class);
        job.setOutputValueClass(NullWritable.class);

        job.setInputFormatClass(MongoDBInputFormat.class);
        job.setOutputFormatClass(MongoDBOutputFormat.class);

        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}
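
After the job finishes, the output can be checked straight from the res collection. A minimal sketch of such a check (the PrintResult class is hypothetical; the database and collection names follow the configuration above):

PrintResult.java (hypothetical):

package MapReduce07;

import com.mongodb.MongoClient;
import org.bson.Document;

//Hypothetical check: prints every document the job wrote to mydb.res.
public class PrintResult {
    public static void main(String[] args) {
        MongoClient client = new MongoClient("localhost", 27017);
        try {
            for (Document doc : client.getDatabase("mydb").getCollection("res").find()) {
                System.out.println(doc.toJson());
            }
        } finally {
            client.close();
        }
    }
}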

MongoDBInputFormat.java:

package MapReduce07;

import com.mongodb.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoDatabase;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.*;
import org.bson.Document;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class MongoDBInputFormat  extends InputFormat <LongWritable, Document>{

    public MongoDBInputFormat(){}
    //Custom input split: a half-open range [start, end) of record positions
    public static class MongoDBInputSplit extends InputSplit implements Writable{
        private long start;
        private long end;
        public MongoDBInputSplit(){}
        public MongoDBInputSplit(long start, long end) {
            this.start = start;
            this.end = end;
        }
        @Override
        public void write(DataOutput out) throws IOException {
            out.writeLong(start);
            out.writeLong(end);
        }

        @Override
        public void readFields(DataInput in) throws IOException {
            this.start = in.readLong();
            this.end = in.readLong();
        }

        @Override
        public long getLength() throws IOException, InterruptedException {
            return this.end-this.start;
        }

        @Override
        public String[] getLocations() throws IOException, InterruptedException {
            return new String[0];
        }
    }
    @Override
    public List<InputSplit> getSplits(JobContext context) throws IOException, InterruptedException {
        //Parse the pseudo-URI host://database.collection from the configuration
        String url = context.getConfiguration().get("input");
        String[] url_s = url.split("://");
        String dbName = url_s[1].split("\\.")[0];
        String collectionName = url_s[1].split("\\.")[1];

        MongoClient client = new MongoClient(url_s[0], 27017);
        MongoDatabase db = client.getDatabase(dbName);
        MongoCollection<Document> collection = db.getCollection(collectionName);
        //Total number of records in the collection
        long count = collection.count();
        client.close();
        //Records per split (deliberately tiny here for demonstration)
        long chunk = 2;
        //Number of splits, rounded up so a trailing remainder still gets a split
        long chunksize = (count + chunk - 1) / chunk;
        //Collect the splits: each covers chunk records, the last one covers the remainder
        List<InputSplit> list = new ArrayList<InputSplit>();
        for (int i = 0; i < chunksize; i++) {
            MongoDBInputSplit mis = null;
            if(i+1 == chunksize){
                mis = new MongoDBInputSplit(i*chunk, count);
                list.add(mis);
            } else {
                mis = new MongoDBInputSplit(i*chunk, i*chunk + chunk);
                list.add(mis);
            }
        }
        return list;
    }

    @Override
    public RecordReader<LongWritable, Document> createRecordReader(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException {
        return new MongoDBRecordReader(split, context);
    }
    public  static class MongoDBRecordReader extends RecordReader<LongWritable, Document>{
        private MongoDBInputSplit split;
        //Cursor over this split's records
        private MongoCursor<Document> dbcursor;
        //Number of records read so far from the current split; starts at 0 for each split
        private int index;
        //Global offset of the current record (split start + index); becomes the Mapper's input key
        private LongWritable key;
        //The document just read; becomes the Mapper's input value
        private Document value;
        //Connection details parsed from the input pseudo-URI
        String ip;
        String dbName;
        String collectionName;

        public MongoDBRecordReader(){}
        public MongoDBRecordReader(InputSplit split,TaskAttemptContext context) throws IOException, InterruptedException{
            super();
            initialize(split, context);

            //Parse the pseudo-URI host://database.collection
            String url = context.getConfiguration().get("input");
            String[] url_s = url.split("://");
            dbName = url_s[1].split("\\.")[0];
            collectionName = url_s[1].split("\\.")[1];
            ip = url_s[0];
        }
        //Initialize: cast the assigned split and create the key/value holders
        @Override
        public void initialize(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException {
            this.split = (MongoDBInputSplit)split;
            this.key =  new LongWritable();
            this.value = new Document();
        }

        //Read the next record and store it in this reader's key and value
        @Override
        public boolean nextKeyValue() throws IOException, InterruptedException {
            //Open the cursor lazily on the first call
            if(this.dbcursor == null){
                MongoClient client = new MongoClient(ip, 27017);
                MongoDatabase db = client.getDatabase(dbName);
                MongoCollection<Document> collection = db.getCollection(collectionName);
                //Restrict the cursor to this split's [start, end) range via skip/limit
                dbcursor = collection.find().skip((int) this.split.start).limit((int) this.split.getLength()).iterator();
            }
            }
            boolean hasNext = this.dbcursor.hasNext();
            if(hasNext){
                //Key: the record's global offset in the collection
                this.key.set(this.split.start + index);
                index++;
                //Value: the document itself
                this.value = this.dbcursor.next();
            }
            return hasNext;
        }

        @Override
        public LongWritable getCurrentKey() throws IOException, InterruptedException {
            return this.key;
        }

        @Override
        public Document getCurrentValue() throws IOException, InterruptedException {
            return this.value;
        }

        @Override
        public float getProgress() throws IOException, InterruptedException {
            //Fraction of this split consumed so far
            long length = this.split.getLength();
            return length == 0 ? 0f : Math.min(1f, index / (float) length);
        }

        @Override
        public void close() throws IOException {
            //Release the cursor once the split has been fully read
            if(dbcursor != null){
                dbcursor.close();
            }
        }
    }
}
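
To see what ranges the split arithmetic in getSplits() produces, here is a tiny dry run of the same rounding logic outside Hadoop; the record count is made up:

SplitDryRun.java (hypothetical):

package MapReduce07;

//Hypothetical dry run of the split arithmetic used in getSplits():
//prints the [start, end) range of every split for a given record count.
public class SplitDryRun {
    public static void main(String[] args) {
        long count = 5;   //pretend the collection holds 5 documents
        long chunk = 2;   //records per split, as in MongoDBInputFormat
        long chunksize = (count + chunk - 1) / chunk;  //rounds up to 3
        for (int i = 0; i < chunksize; i++) {
            long start = i * chunk;
            long end = (i + 1 == chunksize) ? count : start + chunk;
            System.out.println("split " + i + ": [" + start + ", " + end + ")");
        }
    }
}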

MongoDBOutputFormat.java:

package MapReduce07;

import com.mongodb.MongoClient;
import com.mongodb.client.MongoCollection;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.*;
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter;
import org.bson.Document;
import java.io.IOException;
import java.util.Map;

public class MongoDBOutputFormat extends OutputFormat<Document, NullWritable> {

    @Override
    public RecordWriter<Document, NullWritable> getRecordWriter(TaskAttemptContext context) throws IOException, InterruptedException {
        return new MongoDBRecordWriter(context);
    }
    public static class MongoDBRecordWriter extends RecordWriter<Document, NullWritable> {
        public MongoCollection<Document> collection = null;
        private MongoClient client;
        public MongoDBRecordWriter(){}
        public MongoDBRecordWriter(TaskAttemptContext context){
            //Parse the output pseudo-URI host://database.collection
            String uri = context.getConfiguration().get("output");
            String[] datas = uri.split("://");
            String ip = datas[0];
            String dbsName = datas[1].split("\\.")[0];
            String tableName = datas[1].split("\\.")[1];
            client = new MongoClient(ip, 27017);
            collection = client.getDatabase(dbsName).getCollection(tableName);
        }
        @Override
        public void write(Document key, NullWritable value) throws IOException, InterruptedException {
            //Log each entry, then insert the document into the target collection
            for(Map.Entry<String, Object> entry : key.entrySet()){
                System.out.println(entry.getKey() + "=" + entry.getValue());
            }
            collection.insertOne(new Document(key));
        }
        @Override
        public void close(TaskAttemptContext context) throws IOException, InterruptedException {
            //Close the MongoDB connection once the task's output is written
            if(client != null){
                client.close();
            }
        }
    }
    @Override
    public void checkOutputSpecs(JobContext context) throws IOException, InterruptedException {

    }
    @Override
    public OutputCommitter getOutputCommitter(TaskAttemptContext context) throws IOException, InterruptedException {
        //No files to commit; a FileOutputCommitter with a null output path acts as a no-op
        return new FileOutputCommitter(null, context);
    }
}
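
As written, every document the reducer emits costs one insertOne round trip. For larger result sets, writes could instead be buffered and flushed with insertMany. A minimal sketch of such a writer, assuming the same pseudo-URI convention as above (the BatchingMongoDBRecordWriter class and its batch size are illustrative, not part of the original code):

BatchingMongoDBRecordWriter.java (hypothetical):

package MapReduce07;

import com.mongodb.MongoClient;
import com.mongodb.client.MongoCollection;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.bson.Document;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

//Hypothetical variant of MongoDBRecordWriter that batches inserts
//with insertMany instead of issuing one insertOne per document.
public class BatchingMongoDBRecordWriter extends RecordWriter<Document, NullWritable> {
    private final MongoClient client;
    private final MongoCollection<Document> collection;
    private final List<Document> buffer = new ArrayList<Document>();
    private static final int BATCH_SIZE = 100;  //arbitrary choice

    public BatchingMongoDBRecordWriter(TaskAttemptContext context) {
        //Same host://database.collection parsing as MongoDBRecordWriter
        String uri = context.getConfiguration().get("output");
        String[] datas = uri.split("://");
        client = new MongoClient(datas[0], 27017);
        collection = client.getDatabase(datas[1].split("\\.")[0])
                .getCollection(datas[1].split("\\.")[1]);
    }

    @Override
    public void write(Document key, NullWritable value) {
        //Buffer the document; flush once the batch is full
        buffer.add(new Document(key));
        if (buffer.size() >= BATCH_SIZE) {
            flush();
        }
    }

    private void flush() {
        if (!buffer.isEmpty()) {
            collection.insertMany(buffer);
            buffer.clear();
        }
    }

    @Override
    public void close(TaskAttemptContext context) throws IOException {
        flush();          //write any remaining buffered documents
        client.close();
    }
}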
