torchtext 库（文本预处理库）

使用参考:https://zhuanlan.zhihu.com/p/31139113

例程:

def get_data_iter(train_csv, test_csv, fix_length, batch_size, word2vec_dir):
    """Build train/test iterators and a text vocabulary from two CSV files.

    Both CSV files are read with the same column layout -- label, title,
    text -- and their first (header) row is skipped. The title column is
    discarded (it is mapped to a ``None`` field).

    Args:
        train_csv: Path to the training CSV file.
        test_csv: Path to the test CSV file.
        fix_length: Fixed token length for every text example
            (passed to the text ``Field`` as ``fix_length``).
        batch_size: Batch size used by both iterators.
        word2vec_dir: Not used by the current implementation; it is only
            relevant to the custom-``Vectors`` alternative noted below.

    Returns:
        A ``(train_iter, test_iter, vocab)`` tuple. ``vocab`` is built from
        the training split only, with GloVe ``6B`` 300-d vectors attached.
    """
    # Text is lower-cased and given a fixed length; batch_first=True is
    # forwarded to the Field unchanged.
    text_field = data.Field(
        sequential=True,
        lower=True,
        fix_length=fix_length,
        batch_first=True,
    )
    # Labels are taken as-is: no vocabulary lookup is performed.
    label_field = data.Field(sequential=False, use_vocab=False)

    # Column order of both CSV files; the title column is skipped via None.
    columns = [("label", label_field), ("title", None), ("text", text_field)]

    # Options shared by both iterators.
    # NOTE(review): device=-1 is the legacy torchtext int convention for CPU;
    # newer torchtext versions warn on int devices -- confirm against the
    # installed version before changing it.
    shared_iter_opts = dict(
        batch_size=batch_size,
        device=-1,
        sort_within_batch=False,
        repeat=False,
    )

    train_set = TabularDataset(
        path=train_csv, format='csv', fields=columns, skip_header=True
    )
    # BucketIterator groups training examples using sort_key (text length).
    train_iter = BucketIterator(
        train_set,
        sort_key=lambda example: len(example.text),
        **shared_iter_opts,
    )

    test_set = TabularDataset(
        path=test_csv, format="csv", fields=columns, skip_header=True
    )
    # sort=False: the test iterator does not reorder examples.
    test_iter = Iterator(test_set, sort=False, **shared_iter_opts)

    # The vocabulary is built from the training split only. To use custom
    # word vectors instead of GloVe, pass `Vectors(name=word2vec_dir)` as
    # the `vectors` argument here.
    text_field.build_vocab(train_set, vectors=GloVe(name='6B', dim=300))
    return train_iter, test_iter, text_field.vocab

转载自：www.cnblogs.com/zf-blog/p/12621007.html