pytorch之tensor矩阵输出省略问题

import torch
from transformers import BertConfig, BertModel, BertTokenizer


def pad_to_longest(tokens, segments, input_masks, pad_token_id=0):
    """Right-pad a batch of token ids, segment ids and attention masks.

    Every sequence is padded to the length of the longest one in ``tokens``
    so the three lists can be stacked into rectangular tensors.

    Args:
        tokens: list of token-id lists, one per sentence.
        segments: list of segment-id (token_type_ids) lists, parallel to
            ``tokens``.
        input_masks: list of attention-mask lists, parallel to ``tokens``.
        pad_token_id: id appended to each ``tokens`` entry. Segments and masks
            are always padded with 0, so padded positions are not attended to.

    Returns:
        A ``(tokens, segments, input_masks)`` tuple of new, padded lists; the
        inputs are not modified. An empty batch yields three empty lists.
    """
    max_len = max((len(ids) for ids in tokens), default=0)
    padded_tokens, padded_segments, padded_masks = [], [], []
    for ids, seg, mask in zip(tokens, segments, input_masks):
        n_pad = max_len - len(ids)
        padded_tokens.append(list(ids) + [pad_token_id] * n_pad)
        padded_segments.append(list(seg) + [0] * n_pad)
        padded_masks.append(list(mask) + [0] * n_pad)
    return padded_tokens, padded_segments, padded_masks


if __name__ == '__main__':
    # Load tokenizer, config and weights from a local bert-base-uncased directory.
    tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased')
    model_config = BertConfig.from_pretrained('./bert-base-uncased')
    model = BertModel.from_pretrained('./bert-base-uncased', config=model_config)
    model.eval()  # inference only: keep dropout disabled

    texts = ["[CLS] Who was Jim Henson [SEP]",
             "[CLS] Jim Henson was a puppeteer [SEP]"]

    # Per sentence: token ids, segment ids (single sentence -> all 0) and
    # attention mask (1 for every real token).
    tokens, segments, input_masks = [], [], []
    for text in texts:
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
        tokens.append(indexed_tokens)
        segments.append([0] * len(indexed_tokens))
        input_masks.append([1] * len(indexed_tokens))

    tokens, segments, input_masks = pad_to_longest(
        tokens, segments, input_masks, pad_token_id=tokenizer.pad_token_id)

    tokens_tensor = torch.tensor(tokens)
    segments_tensors = torch.tensor(segments)
    input_masks_tensors = torch.tensor(input_masks)

    # BertModel.forward takes (input_ids, attention_mask, token_type_ids, ...)
    # positionally, so pass the mask and segment ids by keyword: passing them
    # positionally in (segments, masks) order would swap their roles.
    with torch.no_grad():  # no backward pass needed; also omits grad_fn from str()
        output = model(tokens_tensor,
                       attention_mask=input_masks_tensors,
                       token_type_ids=segments_tensors)
    sequence_output = output[0]
    pooled_output = output[1]  # pooled representation of the [CLS] token

    # str(tensor) abbreviates tensors with more than `threshold` elements
    # ("..."). The "full" profile sets threshold to infinity so every value is
    # written out, regardless of how large the dimensions are.
    torch.set_printoptions(profile="full")

    # Append mode: each run adds another record to test.txt.
    with open("test.txt", 'a', encoding='utf-8') as f:
        f.write("sequence_output:")
        f.write(str(sequence_output))
        f.write('\n')
        f.write("pooled_output:")
        f.write(str(pooled_output))
        f.write('\n')  # end the record so the next run's output starts on a new line

关键函数 torch.set_printoptions()

 torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None)

precision 是浮点数输出的精度(小数点后保留的位数),默认是 4 位;
threshold 是触发缩略输出的阈值:当 tensor 中元素的总个数大于该值时,会进行缩略输出(中间部分用 ... 省略),默认是 1000;若要完整输出,可将其设为 float('inf');
edgeitems 是缩略输出时,每个维度开头和结尾各显示的元素个数,默认是 3;长度不超过 2×edgeitems 的维度会被完整显示;
linewidth 是每一行输出的最大字符数,默认是 80;
profile 是预设的打印方案,可取 'default'、'short' 或 'full',用于恢复或覆盖上述默认设置;其中 'full' 会把 threshold 设为无穷大,从而完整输出整个 tensor。

发布了41 篇原创文章 · 获赞 44 · 访问量 7653

猜你喜欢

转载自blog.csdn.net/tailonh/article/details/105185650
今日推荐