博客
关于我
2.1 TextCNN - 二元情感分类 - 文本分类的卷积神经网络
阅读量:798 次
发布时间:2023-04-17

本文共 4059 字,大约阅读时间需要 13 分钟。

TextCNN: 卷积神经网络在文本分类中的应用

卷积神经网络(CNN)作为一种强大的深度学习模型,已成功应用于文本分类任务中。通过使用多个不同尺寸的卷积核,TextCNN能够有效提取句子中的局部关键信息,类似于多窗口大小的ngram模型,从而更好地捕捉文本的局部相关性。

TextCNN由Yoon Kim于2014年在《Convolutional Neural Networks for Sentence Classification》一文中提出,旨在解决文本分类问题。该模型对单词的表示方式为:使用一个k维向量来表示单词在句子中的位置信息。

与传统的词袋模型不同,TextCNN采用了卷积操作来提取特征。由于卷积操作具有局部感受野的特性,TextCNN在处理文本时只能上下滑动卷积核,而不能左右滑动。这是因为如果左右滑动,单词的重叠处理会导致卷积结果难以有效表示句子结构,影响模型性能。

TextCNN的结构详解

第一层:输入层

输入层是一个7x5的词向量矩阵,每个词的向量维度为5,共7个单词。

第二层:卷积层

卷积层包含6个卷积核,分别对应不同尺寸的filter(2x5、3x5、4x5,每个尺寸各有2个)。每个卷积核会对输入层进行卷积操作,生成对应的特征图(feature map),然后通过激活函数处理。

第三层:池化层

池化层采用1-max pooling机制,从每个特征图中提取最大值,形成6维的特征表示。这种池化方式既保留了局部最大值信息,又减少了维度,提高了计算效率。

第四层:输出层

输出层是一个分类层,使用softamax激活函数进行分类。该层还支持L2正则化,以防止过拟合。

TextCNN的细节介绍

特征(feature)

TextCNN中的特征分为静态和非静态两种。静态特征通常通过预训练词向量(如Word2vec或Glove)提供,非静态特征则在训练过程中进行微调。推荐使用非静态细化(fine-tuning)方式,即在训练前对词向量进行预训练,然后在模型训练中进行调整,以加速收敛。

通道(channel)

类似于图像中的RGB色彩通道,TextCNN中的通道通常采用不同的嵌入方式。实践中常将静态词向量和非静态词向量分开作为不同通道,以增强模型的表达能力。

一维卷积(conv-1d)

TextCNN采用一维卷积(conv-1d),通过不同尺寸的filter获取不同宽度的感受野。这种设计使得模型能够捕捉到不同长度的局部信息。

1-max池化(1-max pooling)

在TextCNN中,池化层通常采用1-max pooling策略。当然,也可以使用动态k-max pooling策略,将k个最大值保留下来,以更好地保留全局信息。

TextCNN的参数设置

  • 序列长度:通常设置为最大句子长度。
  • 类别数量:预测的分类类别数量。
  • 字典大小:词汇数量。
  • 嵌入长度:每个词表示的向量维度,可使用Word2vec、Fasttext、Glove等工具进行预训练。
  • 卷积核大小:对应n元语法的概念。
  • 卷积核个数:与卷积核大小对应的卷积核数量。

TextCNN的实现代码示例

import tensorflow as tfimport numpy as nptf.reset_default_graph()# TextCNN参数设置embedding_size = 2  # 词嵌入维度sequence_length = 3  # 最大句子长度num_classes = 2  # 类别数量(0或1)filter_sizes = [2, 2, 2]  # n-gram窗口大小num_filters = 3  # 卷积核个数# 输入数据sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]labels = [1, 1, 1, 0, 0, 0]  # 1表示好,0表示坏# 创建词典word_list = " ".join(sentences).split()word_list = list(set(word_list))word_dict = {w: i for i, w in enumerate(word_list)}vocab_size = len(word_dict)# 输入向量inputs = []for sen in sentences:    inputs.append(np.asarray([word_dict[n] for n in sen.split()]))# 输出向量(one-hot编码)outputs = []for out in labels:    outputs.append(np.eye(num_classes)[out])# 图像输入占位符X = tf.placeholder(tf.int32, [None, sequence_length])Y = tf.placeholder(tf.int32, [None, num_classes])# 词嵌入层W = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0))embedded_chars = tf.nn.embedding_lookup(W, X)# 添加通道维度embedded_chars = tf.expand_dims(embedded_chars, -1)  # [batch_size, sequence_length, embedding_size, 1]# 卷积层pooled_outputs = []for i, filter_size in enumerate(filter_sizes):    filter_shape = [filter_size, embedding_size, 1, num_filters]    W_conv = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1))    b_conv = tf.Variable(tf.constant(0.1, shape=[num_filters]))    # 卷积操作    conv = tf.nn.conv2d(embedded_chars, W_conv, strides=[1, 1, 1, 1], padding='VALID')    h = tf.nn.relu(tf.nn.bias_add(conv, b_conv))    # 池化操作    pooled = tf.nn.max_pool(h, ksize=[1, sequence_length - filter_size + 1, 1, 1], strides=[1, 1, 1, 1], padding='VALID')    pooled_outputs.append(pooled)# 合并池化输出num_filters_total = num_filters * len(filter_sizes)h_pool = tf.concat(pooled_outputs, num_filters)# 展平成向量h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])# 全连接层W = tf.get_variable('W', shape=[num_filters_total, num_classes], initializer=tf.contrib.layers.xavier_initializer())b = tf.Variable(tf.constant(0.1, shape=[num_classes]))model = tf.nn.xw_plus_b(h_pool_flat, W, b)# 交叉熵损失函数cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=model, labels=Y))# 优化器optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)# 初始化init = tf.global_variables_initializer()# 会话创建sess = tf.Session()sess.run(init)# 训练过程for epoch in range(5000):    _, loss = sess.run([optimizer, cost], feed_dict={X: inputs, Y: outputs})    if (epoch + 1) % 1000 == 0:        print('Epoch: %06d, Loss: %.6f' % (epoch + 1, loss))# 测试test_text = 'she loves you'tests = np.asarray([word_dict[n] for n in test_text.split()])predict = sess.run([tf.argmax(model, 1)], feed_dict={X: tests})result = predict[0][0]if result == 0:    print(test_text, "is Bad Mean...")else:    print(test_text, "is Good Mean!!")

转载地址:http://ulgfk.baihongyu.com/

你可能感兴趣的文章
mysql中kill掉所有锁表的进程
查看>>
mysql中like % %模糊查询
查看>>
MySql中mvcc学习记录
查看>>
mysql中null和空字符串的区别与问题!
查看>>