007、自然语言处理基础:RNN、LSTM与文本分类实战

张开发
2026/4/12 1:57:14 15 分钟阅读

分享文章

007、自然语言处理基础:RNN、LSTM与文本分类实战
007、自然语言处理基础RNN、LSTM与文本分类实战一、从一段诡异的输出说起上周调一个文本分类模型训练集准确率冲到98%测试集死活卡在70%。损失函数曲线像过山车——训练损失一路向下验证损失中途反弹。盯着屏幕半小时突然意识到问题我忘了对文本做长度截断导致batch内样本长度差异极大padding过多模型在训练时过度依赖padding位置的信息。这种问题在NLP任务里太常见了。今天我们就从最基础的RNN聊起看看如何处理序列数据以及如何用LSTM解决长序列问题最后落地到一个真实的文本分类任务上。二、RNN记忆单元与它的局限性RNN的核心思想很简单让网络具备“记忆”。当前时刻的输出不仅取决于当前输入还取决于上一时刻的隐藏状态。代码实现大概长这样classNaiveRNN(nn.Module):def__init__(self,input_size,hidden_size):super().__init__()self.hidden_sizehidden_size# 这里有个细节Wxh和Whh分开写还是合并实际用nn.Linear就行self.i2hnn.Linear(input_sizehidden_size,hidden_size)self.i2onn.Linear(input_sizehidden_size,output_size)defforward(self,x,hidden):combinedtorch.cat((x,hidden),1)# 拼接当前输入和上一时刻隐藏状态hiddentorch.tanh(self.i2h(combined))# 计算新隐藏状态outputself.i2o(combined)# 计算输出returnoutput,hidden看起来很美对吧但实际跑起来问题就来了。梯度在时间维度上反向传播时需要连乘多个雅可比矩阵。当序列较长时梯度要么爆炸梯度数值溢出要么消失梯度趋近于零。我早期做过一个实验用RNN处理超过50个时间步的文本后20个时间步的梯度基本为零——模型根本学不到长距离依赖。三、LSTM三个门与细胞状态LSTM通过引入门控机制和细胞状态来解决长程依赖问题。关键就三个门遗忘门、输入门、输出门。很多人第一次看LSTM结构图会懵其实理解核心就行# 这是LSTM的核心计算步骤不是完整实现deflstm_cell_forward(xt,h_prev,c_prev,parameters):# 遗忘门决定丢掉哪些信息ftsigmoid(Wf [h_prev,xt]bf)# 这个sigmoid输出在0~1之间# 输入门决定更新哪些信息itsigmoid(Wi [h_prev,xt]bi)c_candidatetanh(Wc [h_prev,xt]bc)# 候选细胞状态# 更新细胞状态ctft*c_previt*c_candidate# 关键这里是乘法不是拼接# 输出门otsigmoid(Wo [h_prev,xt]bo)htot*tanh(ct)returnht,ct注意细胞状态ct的更新公式ft * c_prev it * c_candidate。这个设计很妙——细胞状态像一条传送带梯度可以几乎无损地穿过多个时间步。遗忘门的sigmoid输出接近1时信息长期保留接近0时信息被丢弃。四、实战用LSTM做文本分类直接上干货说几个实际编码时容易踩的坑。4.1 数据预处理部分defbuild_vocab(texts,max_vocab_size20000):构建词表这里有个经验值文本分类任务2万词表足够覆盖95%以上词汇word_counterCounter()fortextintexts:# 中文记得先分词英文要转小写tokensjieba.lcut(text)ifis_chineseelsetext.lower().split()word_counter.update(tokens)# 保留高频词留出三个特殊tokenmost_commonword_counter.most_common(max_vocab_size-3)vocab{pad:0,unk:1,cls:2}# cls是分类标记vocab.update({word:i3fori,(word,_)inenumerate(most_common)})returnvocab4.2 模型定义关键点classTextLSTM(nn.Module):def__init__(self,vocab_size,embed_dim128,hidden_dim256,num_classes10):super().__init__()self.embeddingnn.Embedding(vocab_size,embed_dim,padding_idx0)# 注意这里batch_firstTrue这样输入形状就是(batch, seq_len, embed_dim)# 早期PyTorch默认是(seq_len, batch, embed_dim)容易搞混self.lstmnn.LSTM(embed_dim,hidden_dim,num_layers2,# 两层LSTM效果通常比一层好bidirectionalTrue,# 双向能捕捉上下文信息batch_firstTrue,dropout0.5iftrainingelse0)# 训练时dropout# 分类头self.fcnn.Sequential(nn.Dropout(0.5),nn.Linear(hidden_dim*2,128),# 双向所以*2nn.ReLU(),nn.Linear(128,num_classes))defforward(self,x,lengths):# x形状: (batch, seq_len)embeddedself.embedding(x)# (batch, seq_len, embed_dim)# 重点使用pack_padded_sequence处理变长序列packedpack_padded_sequence(embedded,lengths.cpu(),batch_firstTrue,enforce_sortedFalse)packed_out,(hidden,cell)self.lstm(packed)# 取最后时刻的隐藏状态双向LSTM需要拼接前后向hiddentorch.cat([hidden[-2],hidden[-1]],dim1)# (batch, hidden_dim*2)returnself.fc(hidden)这里踩过大坑如果不使用pack_padded_sequence模型会在padding位置进行无意义计算不仅浪费算力还会干扰学习。enforce_sortedFalse是PyTorch 1.7之后才有的之前需要手动排序。4.3 训练技巧# 学习率调整策略schedulertorch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,modemax,factor0.5,patience2,verboseTrue)# 监控验证集准确率而不是损失分类任务这样更直观# 梯度裁剪防止爆炸torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm5.0)# 早停策略ifval_accbest_acc:best_accval_acc patience_counter0torch.save(model.state_dict(),best_model.pth)else:patience_counter1ifpatience_counter5:# 连续5轮没提升就停止print(早停触发)break五、一些血泪经验文本长度分布很重要训练前务必统计文本长度分布95%分位数作为max_length。我做过电商评论分类95%长度在200词内但剩下5%有上千词——直接截断到200效果比用动态RNN好。预训练词向量要不要用中文任务上如果数据量小于10万条用word2vec或GloVe预训练词向量能提升2-5个点。数据量大时从头训练embedding也行。LSTM的dropout用法特殊LSTM有两处dropout层间dropout和输出dropout。nn.LSTM的dropout参数指的是层间dropout输出dropout要自己加在LSTM后面。双向LSTM的最后一层隐藏状态取hidden[-2]和hidden[-1]分别代表前向和后向的最后一个有效时间步。别取hidden[:,-1,:]那是padding后的时间步。调试时先过拟合一个小batch取32条样本训练到100%准确率。如果做不到要么模型容量不够要么代码有bug。工业场景的trick上线时用ONNX导出模型推理速度能提升30%。LSTM在CPU上并行效果不好考虑用CNN或Transformer替代。六、最后说两句NLP任务没有银弹。去年我们团队用BERT横扫所有文本分类任务但今年在某些实时性要求高的场景又换回了LSTM——毕竟LSTM在CPU上的推理速度还是快不少。技术选型时想清楚你的场景是追求准确率还是推理速度数据量有多大标注成本多少新手常犯的错误是“模型崇拜”一上来就怼最复杂的结构。实际项目中先跑通baseline比如TF-IDFLR再加复杂度。我见过有人用BERT做二分类准确率只比LSTM高0.3%但推理慢了20倍——这性价比太低了。保持简单直到简单不够用。共勉。

更多文章