上节课当中,给大家介绍了这个项目需要用到的数据集,并且做了简单的数据预处理,缓存好了关系分类文件。接下来,我们可以定义Dataset类,来加载数据了。

但这个模型的输入参数和目标值比较复杂,我们拆分成三节课来处理。这节课,先完成文件加载和分词这两块内容。

代码示例

1、添加配置项

# config.py
TRAIN_JSON_PATH = './data/input/duie/duie_train.json'
TEST_JSON_PATH = './data/input/duie/duie_test.json'
DEV_JSON_PATH = './data/input/duie/duie_dev.json'

BERT_MODEL_NAME = 'bert-base-chinese'

2、新建文件

# utils.py
import torch.utils.data as data
import pandas as pd
import random
from config import *
import json
from transformers import BertTokenizerFast
3、加载关系表
def get_rel():
    df = pd.read_csv(REL_PATH, names=['rel', 'id'])
    return df['rel'].tolist(), dict(df.values)

4、Dataset初始化

class Dataset(data.Dataset):
    def __init__(self, type='train'):
        super().__init__()
        _, self.rel2id = get_rel()
        # 加载文件
        if type == 'train':
            file_path = TRAIN_JSON_PATH
        elif type == 'test':
            file_path = TEST_JSON_PATH
        elif type == 'dev':
            file_path = DEV_JSON_PATH
        with open(file_path) as f:
            self.lines = f.readlines()
        # 加载bert
        self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        line = self.lines[index]
        info = json.loads(line)
        tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
        info['input_ids'] = tokenized['input_ids']
        info['offset_mapping'] = tokenized['offset_mapping']
        print(info)
        exit()

5、尝试加载数据集

if __name__ == '__main__':
    dataset = Dataset()
    loader = data.DataLoader(dataset)
    print(iter(loader).next())

本文链接:http://ichenhua.cn/edu/note/480

版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!