内容简介:01—
01
—
背景
在之前的文章中介绍了多分类逻辑回归算法的数据原理,参考文章链接
-
CSDN文章链接:
https://blog.csdn.net/Gamer_gyt/article/details/85209496
-
公众号:
该篇文章介绍一下Spark中多分类算法,主要包括的技术点如下:
-
多分类实现方式
一对一 (One V One)
一对其余(One V Remaining)
多对多 (More V More)
-
Spark中的多分类实现
02
—
多分类实现方式
一对一
假设某个分类中有N个类别,将这N个类别两两配对(继而转化为二分类问题),这样可以得到 N(N-1)/ 2个二分类器,这样训练模型时需要训练 N(N-1)/ 2个模型,预测时将样本输送到这些模型中,最终统计出现次数较多的类别结果作为最终类别。
假设现在有三个类别:类别A,类别B,类别C,类别D。一对一实现多分类如下图所示:
一对多
一对多,即每次把一个类别看做是正类,其余看做负类,此时假设有N个类别,则对应N个分类器,预测时
-
若只有一个分类器将样本预测为正类,则结果为正类
-
若只有一个分类器将样本预测为负类,则结果为负类
-
若预测结果有正类或者负类个数不唯一,则根据概率最大对应的结果作为最终结果
一对多实现多分类如下图所示:
多对多
将多个类别作为正类,将多个类别作为负类。显然正反类构造必须有特殊的设计,不能随意选取,在周志华老师的西瓜书中提到了“纠错输出码技术(EOOC)”
EOOC工作主要分为两步:
-
编码:对N个类别分别做M次划分,每次划分将一部分类别划为正类,一部分划分为负类,从而形成一个二分类分类器,这样一共产生M个训练集,训练出M个分类器
-
解码:M个分类器分别对测试样本进行预测,这些预测标记组成一个编码,将这个预测编码分别与每个类别各自的编码进行比较,返回其中距离最小的类别作为最终预测结果
上图(a)中,C1类别经过5个分类器后得到的编码为[-1,+1,-1,+1,+1],测试示例经过5个分类器后的编码为[-1,-1,+1,-1,+1],两个编码对比,有三个对应位置不一样,所以海明距离为3,同理可求得测试样例与其他类别的海明距离和欧式距离。
上图(b)中,比图(a)多了0类,即停用类。在计算海明距离时,停用类和测试示例的距离为0.5(笔者认为这里的参数可以进行动态的调整),欧式距离就是正常的当做0值操作。
03
—
OneVsRest介绍
OneVsRest将一个给定的二分类算法有效地扩展到多分类问题应用中,也叫做”One-vs-All”算法。OneVsRest是一个Estimator(评估器)。它采用一个基础的Classifier然后对于k个类别分别创建二分类问题。类别i的二分类分类器用来预测类别为i还是不为i,即将i类和其他类别区分开来。最后,通过依次对k个二分类分类器进行评估,取置信最高的分类器的标签作为i类别的标签。
对应多 “分类实现方式” 中的一对多
04
—
Spark中的多分类实现
基于ml包中的LogisticRegression实现
import org.apache.spark.sql.SparkSession import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator object MultiClassLR { def main(args: Array[String]): Unit = { val input_data = "data/sample_multiclass_classification_data.txt" // args(0) val spark = SparkSession.builder.master("local[5]").appName("MulticlassLRWithElasticNetExample") .getOrCreate() runBaseLR(spark,input_data) runBaseOneVsRest(spark,input_data) spark.stop() } def runBaseLR(spark: SparkSession, input_data: String): Unit = { // 加载训练数据集 val split = spark.read.format("libsvm").load(input_data).randomSplit(Array(1,1)) val train_data = split(0) val test_data = split(1) // 创建模型 val lr = new LogisticRegression().setMaxIter(20).setRegParam(0.3).setElasticNetParam(0.8) // 训练模型 val model = lr.fit(train_data) // 系数矩阵、截距向量 println(s"coefficientMatrix is: \n ${model.coefficientMatrix}") println(s"interceptVector is: \n ${model.interceptVector}") // 测试集计算 val predictions = model.transform(test_data) val test_count = test_data.count().toInt predictions.take(test_count).foreach(println) val evaluator = new MulticlassClassificationEvaluator()//.setLabelCol("label").setPredictionCol("prediction") val accuracy =evaluator.setMetricName("accuracy").evaluate(predictions); val weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions); val weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions); val f1=evaluator.setMetricName("f1").evaluate(predictions); println(s"accuracy is $accuracy") println(s"weightedPrecision is $weightedPrecision") println(s"weightedRecall is $weightedRecall") println(s"f1 is $f1") } def runBaseOneVsRest(spark: SparkSession, input_data: String): Unit = { // 加载训练数据集 val split = spark.read.format("libsvm").load(input_data).randomSplit(Array(1,1)) val train_data = split(0) val test_data = split(1) // 创建模型 val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8) // 训练模型 val model = new OneVsRest().setClassifier(lr).fit(train_data) // 测试集计算 val predictions = model.transform(test_data) val test_count = test_data.count().toInt predictions.take(test_count).foreach(println) val evaluator = new MulticlassClassificationEvaluator()//.setLabelCol("label").setPredictionCol("prediction") val accuracy =evaluator.setMetricName("accuracy").evaluate(predictions); val weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions); val weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions); val f1=evaluator.setMetricName("f1").evaluate(predictions); println(s"accuracy is $accuracy") println(s"weightedPrecision is $weightedPrecision") println(s"weightedRecall is $weightedRecall") println(s"f1 is $f1") } }
运行输出信息(runBaseLR)
[0.0,(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]),[0.142419333934195,-0.3772619583435227,0.06140018515891296],[0.3973163522554911,0.23628803028122905,0.3663956174632798],0.0] [0.0,(4,[0,1,2,3],[-0.277778,-0.333333,0.322034,0.583333]),[0.21570018327413792,-0.5776524462730686,0.06140018515891296],[0.433024032357093,0.19586792869116354,0.3711080389517435],0.0] [0.0,(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]),[0.21570018327413792,-0.6014274236583748,0.06140018515891296],[0.43502594980545906,0.19215033820182928,0.3728237119927117],0.0] [0.0,(4,[0,1,2,3],[-0.166667,-0.416667,0.38983,0.5]),[0.19127333120195608,-0.5901059157567916,0.06140018515891296],[0.42808566999671693,0.19596657170058288,0.37594775830270016],0.0] ... accuracy is 0.8615384615384616 weightedPrecision is 0.9017369727047146 weightedRecall is 0.8615384615384616 f1 is 0.8554924320962056 ...
运行输出信息(runBaseOneVsRest)
[0.0,(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]),0.0] [0.0,(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]),0.0] [0.0,(4,[0,1,2,3],[-0.111111,-0.166667,0.38983,0.416667]),0.0] ... accuracy is 0.8051948051948052 weightedPrecision is 0.8539693389317449 weightedRecall is 0.8051948051948052 f1 is 0.7797931797931799 ...
以上所述就是小编给大家介绍的《多分类实现方式介绍和在 Spark 上实现多分类逻辑回归》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!
猜你喜欢:- 多分类逻辑回归 (Multinomial Logistic Regression)
- TensorFlow 实现 Mnist 数据集的多分类逻辑回归模型
- centos创建逻辑卷和扩容逻辑卷
- AI「王道」逻辑编程的复兴?清华提出神经逻辑机,已入选ICLR
- 内聚代码提高逻辑可读性,用MCVP接续你的大逻辑
- 逻辑式编程语言极简实现(使用C#) - 1. 逻辑式编程语言介绍
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。