DL4J模型训练Word2Vec

       目前深度学习越来越受到重视,深度学习框架也层出不穷,例如谷歌的TensorFlow。它基于Python进行开发,对于许多对Python不够了解的程序员来说用起来可能没有那么方便。这里介绍一个基于Java的深度学习框架——DL4J。本博客主要在代码层面介绍如何基于DL4J实现Word2Vec模型的训练,一起来看一下吧~


【代码】

package com.xzw.dl4j;

import java.io.File;
import java.io.IOException;
import java.util.Collection;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.EndingPreProcessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
/**
 * 
 * @author xzw
 *
 */
public class Word2VecTest {

	/** Default input corpus path (one sentence per line). */
	private static final String DEFAULT_INPUT = "C://Users//Machenike//Desktop//zzz//raw_sentences.txt";
	/** Default output path for the serialized word vectors. */
	private static final String DEFAULT_OUTPUT = "C://Users//Machenike//Desktop//zzz//words.txt";

	/**
	 * Trains a Word2Vec model on a line-per-sentence text corpus, saves the
	 * learned word vectors, then prints a small similarity evaluation.
	 *
	 * @param args optional: args[0] = input corpus path, args[1] = output
	 *             vectors path; the original hard-coded defaults are used
	 *             when an argument is absent.
	 * @throws IOException if the corpus cannot be read or the vectors cannot
	 *                     be written
	 */
	@SuppressWarnings("deprecation")
	public static void main(String[] args) throws IOException {
		// Allow paths to be overridden from the command line; fall back to
		// the original hard-coded defaults for backward compatibility.
		String inputPath = args.length > 0 ? args[0] : DEFAULT_INPUT;
		String outputPath = args.length > 1 ? args[1] : DEFAULT_OUTPUT;

		System.out.println("Load data...");
		File file = new File(inputPath);
		SentenceIterator iterator = new LineSentenceIterator(file);
		// Normalize each sentence to lower case before tokenization.
		iterator.setPreProcessor(new SentencePreProcessor() {

			private static final long serialVersionUID = 1L;

			@Override
			public String preProcess(String sentence) {
				return sentence.toLowerCase();
			}
		});

		System.out.println("Tokenize data...");
		// EndingPreProcessor strips common English word endings from tokens.
		final EndingPreProcessor preProcessor = new EndingPreProcessor();
		TokenizerFactory tokenizer = new DefaultTokenizerFactory();
		tokenizer.setTokenPreProcessor(new TokenPreProcess() {

			@Override
			public String preProcess(String token) {
				token = token.toLowerCase();
				String base = preProcessor.preProcess(token);
				// Collapse every digit to the placeholder 'd' so numbers
				// share one token instead of inflating the vocabulary.
				base = base.replaceAll("\\d", "d");
				return base;
			}
		});

		System.out.println("Build model...");
		int batchSize = 1000; // words processed per minibatch
		int iterations = 3;   // training iterations per minibatch
		int layerSize = 150;  // dimensionality of the word vectors
		Word2Vec vec = new Word2Vec.Builder()
			.batchSize(batchSize)
			.minWordFrequency(5)   // ignore words seen fewer than 5 times
			.useAdaGrad(false)
			.layerSize(layerSize)
			.iterations(iterations)
			.learningRate(0.025)
			.minLearningRate(1e-3)
			.negativeSample(10)    // negative-sampling size
			.iterate(iterator)
			.tokenizerFactory(tokenizer)
			.build();

		// Train the model on the corpus.
		System.out.println("Learning...");
		vec.fit();

		// Persist the learned word vectors in text format.
		System.out.println("Save model...");
		WordVectorSerializer.writeWordVectors(vec, outputPath);

		// Quick sanity check: pairwise similarity and nearest neighbours.
		System.out.println("Evaluate model...");
		String word1 = "people";
		String word2 = "money";
		double similarity = vec.similarity(word1, word2);
		System.out.println(String.format("The similarity between %s and %s is %f", word1, word2, similarity));
		String word = "day";
		int ranking = 10;
		Collection<String> similarTop10 = vec.wordsNearest(word, ranking);
		System.out.println(String.format("Similar word to %s is %s", word, similarTop10));
	}

}

【用到的数据集】


【保存的Word2Vec模型】


【运行结果】








猜你喜欢

转载自blog.csdn.net/gdkyxy2013/article/details/80151423