上节课,给大家介绍了 TextCNN 的模型结构,这节课就正式进入代码部分。本节课有两个任务,一是导入数据集,二是要统计待分类的文本长度,因为 TextCNN 在卷积之后,要做批量最大池化操作,所以要求文本长度一致,不够的填充PAD,太长的要进行截取。

代码示例

1、添加配置项

# config.py
TRAIN_SAMPLE_PATH = './data/input/train.txt'
DEV_SAMPLE_PATH = './data/input/dev.txt'
TEST_SAMPLE_PATH = './data/input/test.txt'

LABEL_PATH = './data/input/class.txt'

BERT_PAD_ID = 0
TEXT_LEN = 35

2、统计句子长度

from config import *

import matplotlib.pyplot as plt

def count_text_len():
    text_len = []
    with open(TRAIN_SAMPLE_PATH) as f:
        for line in f.readlines():
            text, _ = line.split('\t')
            text_len.append(len(text))
    plt.hist(text_len)
    plt.show()
    print(max(text_len))

if __name__ == '__main__':
    count_text_len()

做完简单的数据预处理之后,下一步就要定义 Dataset 类,加载数据了。但是数据加载,涉及到 Bert 分词,可能有的同学对 Bert 的使用还不熟悉,所以下面两节课,我们对 Bert 的基本使用做一个简单介绍,已经会用的同学可以直接跳过。

本文链接:http://ichenhua.cn/edu/note/503

版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!