统计学习方法-k近邻(KNN)笔记

栏目: 数据库 · 发布时间: 5年前

内容简介:k近邻(KNN)算法是一种简单易于实现的监督机器学习方法,可用于解决分类和回归问题(这里取决于KNN算法最后返回值的处理方法)k近邻算法没有进行数据的训练,直接使用实例数据与训练数据进行比较(首先确认输入实例点的k个邻训练实例点,然后利用这k个训练实例点的类的多数来预测输入实例点的类(通常会将k设为奇数,以便有一个决胜局),因此K-近邻算法不具有显式的学习过程k近邻三要素:距离度量、k值的选择和分类决策规则,三者一旦决定,其KNN算法的结果也就是唯一的了:

k近邻(KNN)算法是一种简单易于实现的监督机器学习方法,可用于解决分类和回归问题(这里取决于KNN算法最后返回值的处理方法)

  • 分类问题:返回值为k个训练标签中占大多数的类
  • 回归问题:返回值为k个训练标签对应值的平均值

k近邻算法没有进行数据的训练,直接使用实例数据与训练数据进行比较(首先确认输入实例点的k个邻训练实例点,然后利用这k个训练实例点的类的多数来预测输入实例点的类(通常会将k设为奇数,以便有一个决胜局),因此K-近邻算法不具有显式的学习过程

k近邻三要素:距离度量、k值的选择和分类决策规则,三者一旦决定,其KNN算法的结果也就是唯一的了:

  • 常用的距离度量是欧式距离(或者更一般的L p 距离,当p=1时则是曼哈顿距离)
  • k值的选择反映了偏差和方差的权衡:k值越小,则模型复杂,偏差小和方差大(训练误差小,测试误差大),容易出现过拟合;k值越大,则模型简单,偏差大和方差小(训练误差大,测试误差小),容易出现欠拟合;因此一般通过交叉验证来选取较小的最优k值
  • 常用的分类决策规则是多数表决,对应经验风险最小化(ERM,经验风险最小化模型认为经验风险最小的模型就是最优的模型)

有文章 机器学习之KNN邻近算法 总结了KNN算法的优劣:

  • KNN算法与人类日常思考模式很相近,无需估计参数,无需训练,也不会生成最终的分类器。
    适合对稀有事件进行分类。较小数据量的快速分类常常使用KNN算法。 特别适合于多分类问题(multi-modal,对象具有多个类别标签) ,在这一类问题上,KNN算法要优于很多其他的机器学习算法(如SVM)
  • KNN算法是一种消极算法(没有训练的过程,到了要决策的时候才会利用已有的数据进行决策),进行分类时由于要计算每个样本之间的距离,计算开销很大,运行比较慢。当数据量非常大时,不适合用KNN算法。当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。在极端情况下,如果某一类别的样本数目甚至小于K,那么最终的分类结果势必要收到很大的影响。分类结果完全以来于训练样本,不能给出分类规则(可解释性不强,相比决策树模型而言)。

按照<统计学习方法>KNN的3.1算法,也可以称为brute-force(暴力搜索),步骤如下:

  • 确定训练集和k值
  • 循环计算每个训练集中的点与实例点的距离,寻找出离实例点最近的k个点
  • 计算k个点所在类出现次数,根据多数表决的决策规则将实例点进行分类

Python代码如下( brute_force_knn 函数是常规暴力搜索的KNN算法, knn_fd_cv 函数则是KNN的K折交叉验证):

import numpy as np
import pandas as pd
from collections import Counter
from sklearn.model_selection import KFold

class KNN:
    def brute_force_knn(self, dat_X, label, input_x, k):
        # 计算欧式距离
        dist = ((input_x - dat_X)**2).sum(axis=1)**0.5
        dist_index = np.argsort(dist)
        # 分类决策:多数表决
        label_pre = []
        for i in dist_index[:k]:
            label_pre.append(label[i])
        target = Counter(label_pre).most_common(1)[0][0]
        return target

    def knn_fd_cv(self, dat_X, label, k, fd):
        kf = KFold(n_splits=fd)
        error_list = []
        for train_index, test_index in kf.split(dat_X):
            X_train, X_test = dat_X[train_index], dat_X[test_index]
            y_train, y_test = label[train_index], label[test_index]

            error = 0
            for i in range(len(y_test)):
                t = knn.brute_force_knn(dat_X=X_train, label=y_train, input_x=X_test[i], k=k)
                # 计算测试误差
                if t != y_test[i]:
                    error += 1
            # print(error/len(y_test)*100)
            error_list.append(error/len(y_test)*100)
        return np.mean(error_list)

其实sklearn已有调用KNN算法的 sklearn.neighbors.KNeighborsClassifier ,类似于常用的R包,上面的代码只是为了理解下KNN算法和交叉验证而写着看看,sklearn的使用以后再看~

加权KNN(常用有高斯函数、反函数和减法函数等方式),即给更近的邻位点分配更大的权重(可消除孤立实例点的影响):

  • 对于离散型数据,将k个点对应的标记以权重区别对待,权重相加,那个类的值大就属于哪类
  • 对于数值型数据,将k个点求加权平均(累加值与权重相乘的结果,然后除以所有权重之和)

由于KNN是线性扫描的方式,当训练集很大的时候则会非常耗时;这时可以使用kd-树来找寻找最邻位点,其会将实例点储存到可以进行快速搜索的树形数据结构中,使得每次在局部空间中搜索,从而加快搜索速度,具体实现方式可参考<统计学习方法>以及 kNN里面的两种优化的数据结构:kd-tree和ball-tree,在算法实现原理上有什么区别?

以一个Kaggle上的测试数据集 MNIST 为例(数字识别),用上述 Python 的简单KNN算法计算下不同k选值下的识别错误率,进而选择最优的K值:

# 读入测试数据集
dataset = pd.read_csv("train.csv")
dataset = np.array(dataset)
# 确定实例点和其对应标记
train_dat = dataset[:,1:]
train_label = dataset[:,0]

knn = KNN()
# 计算k取值在2-7下的测试误差,交叉验证为10折
for i in range(2,8):
    r = knn.knn_fd_cv(dat_X=train_dat, label=train_label, k=i, fd=10)
    print("k =",i,"error rate =",r)

最后发现k=4时测试误差最低(2.95%,即97.05%的准确率,效果有点差~相比其他用CNN卷积神经网络算法的99.9%的准确率来说)

本文出自于 http://www.bioinfo-scrounger.com 转载请注明出处


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Building Web Reputation Systems

Building Web Reputation Systems

Randy Farmer、Bryce Glass / Yahoo Press / 2010 / GBP 31.99

What do Amazon's product reviews, eBay's feedback score system, Slashdot's Karma System, and Xbox Live's Achievements have in common? They're all examples of successful reputation systems that enable ......一起来看看 《Building Web Reputation Systems》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

CSS 压缩/解压工具
CSS 压缩/解压工具

在线压缩/解压 CSS 代码

MD5 加密
MD5 加密

MD5 加密工具