|-转 【A情感文本分类实战】2024 Pytorch+Bert、Roberta+TextCNN、BiLstm、Lstm等实现IMDB情感文本分类完整项目(项目已开源)
原文较长这里只转载前一部分内容 20240325
https://blog.csdn.net/ccaoshangfei/article/details...
https://github.com/BeiCunNan/sentiment_analysis_Im...
?顶会的代码干净利索,借鉴其完成了以下工程
?本工程采用Pytorch框架,使用上游语言模型+下游网络模型的结构实现IMDB情感分析
?预训练大语言模型可选择Bert、Roberta
?下游网络模型可选择BiLSTM、LSTM、TextCNN、GRU、Attention以及其组合
?语言模型和网络模型扩展性较好,可以此为BaseLine再使用你的数据集,模型
?最终的准确率均在90%以上
?项目已开源,clone下来再配个简单环境就能跑
???有小伙伴询问如何融合使用Attention、LSTM+TextCNN和Lstm+TextCNN+Self-Attention的网络模型,现源码已经重新上传(2023-03),大家可以揣摩一下是如何结合的,如此,对照类似的做法,推广到其他模型上
如果这篇文章对您有帮助,期待大佬们Github上给个⭐️⭐️⭐️
一、Introduction
1.1 网络架构图
该网络主要使用上游预训练模型+下游情感分类模型组成
1.2 快速使用
该项目已开源在Github上,地址为 sentiment_analysis_Imdb
主要环境要求如下(环境不要太老基本没啥问题的)
下载该项目后,配置相对应的环境,在config.py文件中选择所需的语言模型和神经网络模型如下图所示,运行main.py文件即可
1.3 工程结构
logs 每次运行程序后的日志文件集合
config.py 全局配置文件
data.py 数据读取、数据清洗、数据格式转换、制作DataSet和DataLoader
main.py 主函数,负责全流程项目运行,包括语言模型的转换,模型的训练和测试
model.py 神经网络模型的设计和读取
二、Config
看了很多论文源代码中都使用parser容器进行全局变量的配置,因此作者也照葫芦画瓢编写了config.py文件(适配的话一般只改Base部分)
import argparse
import logging
import os
import random
import sys
import time
from datetime import datetime
import torch
def get_config():
parser = argparse.ArgumentParser()
'''Base'''
parser.add_argument('--num_classes', type=int, default=2)
parser.add_argument('--model_name', type=str, default='bert',
choices=['bert', 'roberta'])
parser.add_argument('--method_name', type=str, default='fnn',
choices=['gru', 'rnn', 'bilstm', 'lstm', 'fnn', 'textcnn', 'attention', 'lstm+textcnn',
'lstm_textcnn_attention'])
'''Optimization'''
parser.add_argument('--train_batch_size', type=int, default=4)
parser.add_argument('--test_batch_size', type=int, default=16)
parser.add_argument('--num_epoch', type=int, default=50)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--weight_decay', type=float, default=0.01)
'''Environment'''
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--backend', default=False, action='store_true')
parser.add_argument('--workers', type=int, default=0)
parser.add_argument('--timestamp', type=int, default='{:.0f}{:03}'.format(time.time(), random.randint(0, 999)))
args = parser.parse_args()
args.device = torch.device(args.device)
'''logger'''
args.log_name = '{}_{}_{}.log'.format(args.model_name, args.method_name,
datetime.now().strftime('%Y-%m-%d_%H-%M-%S')[2:])
if not os.path.exists('logs'):
os.mkdir('logs')
logger = logging.getLogger()
logger.setLevel(logging.INFO)...