内容简介:01—
01
—
背景
在之前的文章中介绍了多分类逻辑回归算法的数据原理,参考文章链接
-
CSDN文章链接:
https://blog.csdn.net/Gamer_gyt/article/details/85209496
-
公众号:
该篇文章介绍一下Spark中多分类算法,主要包括的技术点如下:
-
多分类实现方式
一对一 (One V One)
一对其余(One V Remaining)
多对多 (More V More)
-
Spark中的多分类实现
02
—
多分类实现方式
一对一
假设某个分类中有N个类别,将这N个类别两两配对(继而转化为二分类问题),这样可以得到 N(N-1)/ 2个二分类器,这样训练模型时需要训练 N(N-1)/ 2个模型,预测时将样本输送到这些模型中,最终统计出现次数较多的类别结果作为最终类别。
假设现在有三个类别:类别A,类别B,类别C,类别D。一对一实现多分类如下图所示:
一对多
一对多,即每次把一个类别看做是正类,其余看做负类,此时假设有N个类别,则对应N个分类器,预测时
-
若只有一个分类器将样本预测为正类,则结果为正类
-
若只有一个分类器将样本预测为负类,则结果为负类
-
若预测结果有正类或者负类个数不唯一,则根据概率最大对应的结果作为最终结果
一对多实现多分类如下图所示:
多对多
将多个类别作为正类,将多个类别作为负类。显然正反类构造必须有特殊的设计,不能随意选取,在周志华老师的西瓜书中提到了“纠错输出码技术(EOOC)”
EOOC工作主要分为两步:
-
编码:对N个类别分别做M次划分,每次划分将一部分类别划为正类,一部分划分为负类,从而形成一个二分类分类器,这样一共产生M个训练集,训练出M个分类器
-
解码:M个分类器分别对测试样本进行预测,这些预测标记组成一个编码,将这个预测编码分别与每个类别各自的编码进行比较,返回其中距离最小的类别作为最终预测结果
上图(a)中,C1类别经过5个分类器后得到的编码为[-1,+1,-1,+1,+1],测试示例经过5个分类器后的编码为[-1,-1,+1,-1,+1],两个编码对比,有三个对应位置不一样,所以海明距离为3,同理可求得测试样例与其他类别的海明距离和欧式距离。
上图(b)中,比图(a)多了0类,即停用类。在计算海明距离时,停用类和测试示例的距离为0.5(笔者认为这里的参数可以进行动态的调整),欧式距离就是正常的当做0值操作。
03
—
OneVsRest介绍
OneVsRest将一个给定的二分类算法有效地扩展到多分类问题应用中,也叫做”One-vs-All”算法。OneVsRest是一个Estimator(评估器)。它采用一个基础的Classifier然后对于k个类别分别创建二分类问题。类别i的二分类分类器用来预测类别为i还是不为i,即将i类和其他类别区分开来。最后,通过依次对k个二分类分类器进行评估,取置信最高的分类器的标签作为i类别的标签。
对应多 “分类实现方式” 中的一对多
04
—
Spark中的多分类实现
基于ml包中的LogisticRegression实现
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
object MultiClassLR {
def main(args: Array[String]): Unit = {
val input_data = "data/sample_multiclass_classification_data.txt" // args(0)
val spark = SparkSession.builder.master("local[5]").appName("MulticlassLRWithElasticNetExample")
.getOrCreate()
runBaseLR(spark,input_data)
runBaseOneVsRest(spark,input_data)
spark.stop()
}
def runBaseLR(spark: SparkSession, input_data: String): Unit = {
// 加载训练数据集
val split = spark.read.format("libsvm").load(input_data).randomSplit(Array(1,1))
val train_data = split(0)
val test_data = split(1)
// 创建模型
val lr = new LogisticRegression().setMaxIter(20).setRegParam(0.3).setElasticNetParam(0.8)
// 训练模型
val model = lr.fit(train_data)
// 系数矩阵、截距向量
println(s"coefficientMatrix is: \n ${model.coefficientMatrix}")
println(s"interceptVector is: \n ${model.interceptVector}")
// 测试集计算
val predictions = model.transform(test_data)
val test_count = test_data.count().toInt
predictions.take(test_count).foreach(println)
val evaluator = new MulticlassClassificationEvaluator()//.setLabelCol("label").setPredictionCol("prediction")
val accuracy =evaluator.setMetricName("accuracy").evaluate(predictions);
val weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions);
val weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions);
val f1=evaluator.setMetricName("f1").evaluate(predictions);
println(s"accuracy is $accuracy")
println(s"weightedPrecision is $weightedPrecision")
println(s"weightedRecall is $weightedRecall")
println(s"f1 is $f1")
}
def runBaseOneVsRest(spark: SparkSession, input_data: String): Unit = {
// 加载训练数据集
val split = spark.read.format("libsvm").load(input_data).randomSplit(Array(1,1))
val train_data = split(0)
val test_data = split(1)
// 创建模型
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
// 训练模型
val model = new OneVsRest().setClassifier(lr).fit(train_data)
// 测试集计算
val predictions = model.transform(test_data)
val test_count = test_data.count().toInt
predictions.take(test_count).foreach(println)
val evaluator = new MulticlassClassificationEvaluator()//.setLabelCol("label").setPredictionCol("prediction")
val accuracy =evaluator.setMetricName("accuracy").evaluate(predictions);
val weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions);
val weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions);
val f1=evaluator.setMetricName("f1").evaluate(predictions);
println(s"accuracy is $accuracy")
println(s"weightedPrecision is $weightedPrecision")
println(s"weightedRecall is $weightedRecall")
println(s"f1 is $f1")
}
}
运行输出信息(runBaseLR)
[0.0,(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]),[0.142419333934195,-0.3772619583435227,0.06140018515891296],[0.3973163522554911,0.23628803028122905,0.3663956174632798],0.0] [0.0,(4,[0,1,2,3],[-0.277778,-0.333333,0.322034,0.583333]),[0.21570018327413792,-0.5776524462730686,0.06140018515891296],[0.433024032357093,0.19586792869116354,0.3711080389517435],0.0] [0.0,(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]),[0.21570018327413792,-0.6014274236583748,0.06140018515891296],[0.43502594980545906,0.19215033820182928,0.3728237119927117],0.0] [0.0,(4,[0,1,2,3],[-0.166667,-0.416667,0.38983,0.5]),[0.19127333120195608,-0.5901059157567916,0.06140018515891296],[0.42808566999671693,0.19596657170058288,0.37594775830270016],0.0] ... accuracy is 0.8615384615384616 weightedPrecision is 0.9017369727047146 weightedRecall is 0.8615384615384616 f1 is 0.8554924320962056 ...
运行输出信息(runBaseOneVsRest)
[0.0,(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]),0.0] [0.0,(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]),0.0] [0.0,(4,[0,1,2,3],[-0.111111,-0.166667,0.38983,0.416667]),0.0] ... accuracy is 0.8051948051948052 weightedPrecision is 0.8539693389317449 weightedRecall is 0.8051948051948052 f1 is 0.7797931797931799 ...
以上所述就是小编给大家介绍的《多分类实现方式介绍和在 Spark 上实现多分类逻辑回归》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!
猜你喜欢:- 多分类逻辑回归 (Multinomial Logistic Regression)
- TensorFlow 实现 Mnist 数据集的多分类逻辑回归模型
- centos创建逻辑卷和扩容逻辑卷
- AI「王道」逻辑编程的复兴?清华提出神经逻辑机,已入选ICLR
- 内聚代码提高逻辑可读性,用MCVP接续你的大逻辑
- 逻辑式编程语言极简实现(使用C#) - 1. 逻辑式编程语言介绍
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
Pattern Recognition and Machine Learning
Christopher Bishop / Springer / 2007-10-1 / USD 94.95
The dramatic growth in practical applications for machine learning over the last ten years has been accompanied by many important developments in the underlying algorithms and techniques. For example,......一起来看看 《Pattern Recognition and Machine Learning》 这本书的介绍吧!