TextCNN项目 P3 自定义Dataset和BertTokenizer分词
上节课,我们做了一个简单的数据预处理,通过观察直方图,定义好了文本的长度参数。同时,如果对 Bert 不熟悉的同学,还需要看一下前面补充的内容。
现在,假设大家已经看了、并且掌握了前面 Huggingface 的内容,我们接着往下讲自定义 Dataset 和 Bert 分词的内容。
代码示例
1、新建文件
# utils.py from torch.utils import data from config import * import torch from transformers import BertTokenizer from transformers import logging logging.set_verbosity_error()
2、自定义Dataset类
class Dataset(data.Dataset): def __init__(self, type='train'): super().__init__() if type == 'train': sample_path = TRAIN_SAMPLE_PATH elif type == 'dev': sample_path = DEV_SAMPLE_PATH elif type == 'test': sample_path = TEST_SAMPLE_PATH self.lines = open(sample_path).readlines() self.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL) def __len__(self): return len(self.lines) def __getitem__(self, index): text, label = self.lines[index].split('\t') tokened = self.tokenizer(text) input_ids = tokened['input_ids'] mask = tokened['attention_mask'] if len(input_ids) < TEXT_LEN: pad_len = (TEXT_LEN - len(input_ids)) input_ids += [BERT_PAD_ID] * pad_len mask += [0] * pad_len target = int(label) return torch.tensor(input_ids[:TEXT_LEN]), torch.tensor(mask[:TEXT_LEN]), torch.tensor(target)
3、调用测试
if __name__ == '__main__': dataset = Dataset() loader = data.DataLoader(dataset, batch_size=2) print(iter(loader).next())
本文链接:http://www.ichenhua.cn/edu/note/504
版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!