|-转 [NLP实践01]simpletransformers安装和文本分类简单实现
卡在了conda install pytorch>=1.6 cudatoolkit=11.0 -c pytorch,运行代码一直不动
快速安装 simpletransformers
simpletransformers 项目地址:hub.fastgit.org/ThilinaRaja…
simpletransformers 文档地址:
simpletransformers.ai/
快速安装方式:
使用Conda安装
1)新建虚拟环境
conda create -n st python pandas tqdm
conda activate st
2)安装cuda环境
conda install pytorch>=1.6 cudatoolkit=11.0 -c pytorch
3)安装 simpletransformers
pip install simpletransformers
安装 wandb
wandb 用于在web浏览器中追踪和可视化Weights和Biases(wandb)
复制代码pip install wandb
目前支持的任务:
任务模型二元和多类文本分类ClassificationModel对话式人工智能(聊天机器人训练)ConvAIModel语言生成LanguageGenerationModel语言模型训练/微调LanguageModelingModel多标签文本分类MultiLabelClassificationModel多模态分类(文本和图像数据结合)MultiModalClassificationModel命名实体识别NERModel问答QuestionAnsweringModel回归ClassificationModel句子对分类ClassificationModel文本表示生成RepresentationModel
预训练模型去哪里下载?
有关预训练模型,请参阅Hugging Face 文档。
根据文档中给出的model_type,只要在args中正确设置model_name的字典值就是可以加载预训练模型
【实践01】文本分类
数据集
笔者选用CLUE的作为benchmark数据集
选取数据集:IF***TEK' 长文本分类
中文语言理解测评基准(CLUE)
www.cluebenchmarks.com/dataSet_sea…
为更好的服务中文语言理解、任务和产业界,做为通用语言模型测评的补充,通过搜集整理发布中文任务及标准化测评等方式完善基础设施,最终促进中文NLP的发展。
Update: CLUE论文被计算语言学国际会议 COLING2020高分录用
IF***TEK' 长文本分类
下载地址:github.com/CLUEbenchma…
该数据集共有1.7万多条关于app应用描述的长文本标注数据,包含和日常生活相关的各类应用主题,共119个类别:"打车":0,"地图导航":1,"免费WIFI":2,"租车":3,….,"女性":115,"经营":116,"收款":117,"其他":118(分别用0-118表示)。每一条数据有三个属性,从前往后分别是 类别ID,类别名称,文本内容。
数据量:训练集(12,133),验证集(2,599),测试集(2,600)
css复制代码{"label": "110",
"label_des": "社区超市",
"sentence": "朴朴快送超市创立于2016年,专注于打造移动端30分钟即时配送一站式购物平台,商品品类包含水果、蔬菜、肉禽蛋奶、海鲜水产、粮油调味、酒水饮料、休闲食品、日用品、外卖等。朴朴公司希望能以全新的商业模式,更高效快捷的仓储配送模式,致力于成为更快、更好、更多、更省的在线零售平台,带给消费者更好的消费体验,同时推动中国食品安全进程,成为一家让社会尊敬的互联网公司。,朴朴一下,又好又快,1.配送时间提示更加清晰友好2.保障用户隐私的一些优化3.其他提高使用体验的调整4.修复了一些已知bug"}
数据处理
Simple Transformers要求数据必须包含在至少两列的Pandas DataFrames中。 只需为列的文本和标签命名,SimpleTransformers就会处理数据。
第一列包含文本,类型为str。
第二列包含标签,类型为int。
对于多类分类,标签应该是从0开始的整数。
ini复制代码import json
import pandas as pd
def load_clue_iflytek(path,mode=None):
"""适应simpletransformer的加载方式"""
data = []
with open(path, "r", encoding="utf-8") as fp:
if mode == 'train' or mode =='dev':
for idx, line in enumerate(fp):
line = json.loads(line.strip())
label = int(line["label"])
text = line['sentence']
data.append([text, label])
data_df = pd.DataFrame(data, columns=["text", "labels"])
return data_df
elif mode == 'test':
for idx, line in enumerate(fp):
line = json.loads(line.strip())
text = line['sentence']
data.append([text])
data_df = pd.DataFrame(data, columns=["text"])
return data_df
模型搭建和训练
先进行参数配置,Simple Transformers具有dict args, 有关每个args的详细说明,可有参考:simpletransformers.ai/docs/tips-a…
1)参数配置
ini复制代码# 配置config
import argparse
def data_config(parser):
parser.add_argument("--trainset_path", type=str, default="data/Chinese_Spam_Message/train.json",
help="训练集路径")
parser.add_argument("--testset_path", type=str, default="data/Chinese_Spam_Message/test.txt",
help="测试集路径")...