CasRel项目 P4 构建Dataset数据集和BERT分词
上节课当中,给大家介绍了这个项目需要用到的数据集,并且做了简单的数据预处理,缓存好了关系分类文件。接下来,我们可以定义Dataset类,来加载数据了。
但这个模型的输入参数和目标值比较复杂,我们拆分成三节课来处理。这节课,先完成文件加载和分词这两块内容。
代码示例
1、添加配置项
# config.py TRAIN_JSON_PATH = './data/input/duie/duie_train.json' TEST_JSON_PATH = './data/input/duie/duie_test.json' DEV_JSON_PATH = './data/input/duie/duie_dev.json' BERT_MODEL_NAME = 'bert-base-chinese'
2、新建文件
# utils.py import torch.utils.data as data import pandas as pd import random from config import * import json from transformers import BertTokenizerFast3、加载关系表
def get_rel(): df = pd.read_csv(REL_PATH, names=['rel', 'id']) return df['rel'].tolist(), dict(df.values)
4、Dataset初始化
class Dataset(data.Dataset): def __init__(self, type='train'): super().__init__() _, self.rel2id = get_rel() # 加载文件 if type == 'train': file_path = TRAIN_JSON_PATH elif type == 'test': file_path = TEST_JSON_PATH elif type == 'dev': file_path = DEV_JSON_PATH with open(file_path) as f: self.lines = f.readlines() # 加载bert self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME) def __len__(self): return len(self.lines) def __getitem__(self, index): line = self.lines[index] info = json.loads(line) tokenized = self.tokenizer(info['text'], return_offsets_mapping=True) info['input_ids'] = tokenized['input_ids'] info['offset_mapping'] = tokenized['offset_mapping'] print(info) exit()
5、尝试加载数据集
if __name__ == '__main__': dataset = Dataset() loader = data.DataLoader(dataset) print(iter(loader).next())
本文链接:http://ichenhua.cn/edu/note/480
版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!