分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > 代码编程

芝麻HTTP:记scikit-learn贝叶斯文本分类的坑

发布时间:2023-09-06 01:40责任编辑:胡小海关键词:暂无标签

基本步骤:

1、训练素材分类:

我是参考官方的目录结构:

每个目录中放对应的文本,一个txt文件一篇对应的文章:就像下面这样

需要注意的是所有素材比例请保持在相同的比例(根据训练结果酌情调整、不可比例过于悬殊、容易造成过拟合(通俗点就是大部分文章都给你分到素材最多的那个类别去了))

废话不多说直接上代码吧(测试代码的丑得一逼;将就着看看吧)

需要一个小工具: pip install chinese-tokenizer

这是训练器:

import reimport jiebaimport jsonfrom io import BytesIOfrom chinese_tokenizer.tokenizer import Tokenizerfrom sklearn.datasets import load_filesfrom sklearn.feature_extraction.text import CountVectorizer, TfidfTransformerfrom sklearn.model_selection import train_test_splitfrom sklearn.naive_bayes import MultinomialNBfrom sklearn.externals import joblibjie_ba_tokenizer = Tokenizer().jie_ba_tokenizer# 加载数据集training_data = load_files(‘./data‘, encoding=‘utf-8‘)# x_train txt内容 y_train 是类别(正 负 中 )x_train, _, y_train, _ = train_test_split(training_data.data, training_data.target)print(‘开始建模.....‘)with open(‘training_data.target‘, ‘w‘, encoding=‘utf-8‘) as f: ???f.write(json.dumps(training_data.target_names))# tokenizer参数是用来对文本进行分词的函数(就是上面我们结巴分词)count_vect = CountVectorizer(tokenizer=jieba_tokenizer)tfidf_transformer = TfidfTransformer()X_train_counts = count_vect.fit_transform(x_train)X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)print(‘正在训练分类器.....‘)# 多项式贝叶斯分类器训练clf = MultinomialNB().fit(X_train_tfidf, y_train)# 保存分类器(好在其它程序中使用)joblib.dump(clf, ‘model.pkl‘)# 保存矢量化(坑在这儿!!需要使用和训练器相同的 矢量器 不然会报错!!!!!! 提示 ValueError dimension mismatch··)joblib.dump(count_vect, ‘count_vect‘)print("分类器的相关信息:")print(clf)

下面是是使用训练好的分类器分类文章:

需要分类的文章放在predict_data目录中:照样是一篇文章一个txt文件

# -*- coding: utf-8 -*-# @Time ???: 2017/8/23 18:02# @Author ?: 哎哟卧槽# @Site ???: # @File ???: 贝叶斯分类器.py# @Software: PyCharm import reimport jiebaimport jsonfrom sklearn.datasets import load_filesfrom sklearn.feature_extraction.text import CountVectorizer, TfidfTransformerfrom sklearn.externals import joblib ?# 加载分类器clf = joblib.load(‘model.pkl‘) count_vect = joblib.load(‘count_vect‘)testing_data = load_files(‘./predict_data‘, encoding=‘utf-8‘)target_names = json.loads(open(‘training_data.target‘, ‘r‘, encoding=‘utf-8‘).read())# ????# 字符串处理tfidf_transformer = TfidfTransformer() X_new_counts = count_vect.transform(testing_data.data)X_new_tfidf = tfidf_transformer.fit_transform(X_new_counts)# 进行预测predicted = clf.predict(X_new_tfidf)for title, category in zip(testing_data.filenames, predicted): ???print(‘%r => %s‘ % (title, target_names[category]))

这个样子将训练好的分类器在新的程序中使用时候 就不报错: ValueError dimension mismatch··

芝麻HTTP:记scikit-learn贝叶斯文本分类的坑

原文地址:https://www.cnblogs.com/zhimaruanjian/p/8390666.html

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved