sklearn 快速入门

栏目: Python · 发布时间: 8年前

内容简介:sklearn 快速入门

简介

sklearn自带了一些标准数据集,用于分类问题的 irisdigits 。用于回归问题的 boston房价 数据集。

导入数据集

from sklearn import datasets

自带的数据都放在datasets里面

iris = datasets.load_iris()
digits = datasets.load_digits()

datasets 是dict类型的对象,包含数据和元数据信息。数据放在.data里,标签放在.target里。

type(iris.data)
numpy.ndarray

.data里放的是特征的信息

print "iris.data.dtype: ",iris.data.dtype
print "iris.data.shape: ",iris.data.shape
print "iris.data.ndim: ",iris.data.ndim
print "--------------------------------"
print iris.data[0:5]
iris.data.dtype:  float64
iris.data.shape:  (150, 4)
iris.data.ndim:  2
--------------------------------
[[ 5.1  3.5  1.4  0.2]
 [ 4.9  3.   1.4  0.2]
 [ 4.7  3.2  1.3  0.2]
 [ 4.6  3.1  1.5  0.2]
 [ 5.   3.6  1.4  0.2]]

.target里放的是标签信息

print "iris.target.dtype: ",iris.target.dtype
print "iris.target.shape: ",iris.target.shape
print "iris.target.ndim: ",iris.target.ndim
print "--------------------------------"
print iris.target[0:5]
iris.target.dtype:  int64
iris.target.shape:  (150,)
iris.target.ndim:  1
--------------------------------
[0 0 0 0 0]
type(digits)
sklearn.datasets.base.Bunch
print "digits.data.dtype: ",digits.data.dtype
print "digits.data.shape: ",digits.data.shape
print "digits.data.ndim: ",digits.data.ndim
print "--------------------------------"
print digits.data[0:5]
digits.data.dtype:  float64
digits.data.shape:  (1797, 64)
digits.data.ndim:  2
--------------------------------
[[  0.   0.   5.  13.   9.   1.   0.   0.   0.   0.  13.  15.  10.  15.
    5.   0.   0.   3.  15.   2.   0.  11.   8.   0.   0.   4.  12.   0.
    0.   8.   8.   0.   0.   5.   8.   0.   0.   9.   8.   0.   0.   4.
   11.   0.   1.  12.   7.   0.   0.   2.  14.   5.  10.  12.   0.   0.
    0.   0.   6.  13.  10.   0.   0.   0.]
 [  0.   0.   0.  12.  13.   5.   0.   0.   0.   0.   0.  11.  16.   9.
    0.   0.   0.   0.   3.  15.  16.   6.   0.   0.   0.   7.  15.  16.
   16.   2.   0.   0.   0.   0.   1.  16.  16.   3.   0.   0.   0.   0.
    1.  16.  16.   6.   0.   0.   0.   0.   1.  16.  16.   6.   0.   0.
    0.   0.   0.  11.  16.  10.   0.   0.]
 [  0.   0.   0.   4.  15.  12.   0.   0.   0.   0.   3.  16.  15.  14.
    0.   0.   0.   0.   8.  13.   8.  16.   0.   0.   0.   0.   1.   6.
   15.  11.   0.   0.   0.   1.   8.  13.  15.   1.   0.   0.   0.   9.
   16.  16.   5.   0.   0.   0.   0.   3.  13.  16.  16.  11.   5.   0.
    0.   0.   0.   3.  11.  16.   9.   0.]
 [  0.   0.   7.  15.  13.   1.   0.   0.   0.   8.  13.   6.  15.   4.
    0.   0.   0.   2.   1.  13.  13.   0.   0.   0.   0.   0.   2.  15.
   11.   1.   0.   0.   0.   0.   0.   1.  12.  12.   1.   0.   0.   0.
    0.   0.   1.  10.   8.   0.   0.   0.   8.   4.   5.  14.   9.   0.
    0.   0.   7.  13.  13.   9.   0.   0.]
 [  0.   0.   0.   1.  11.   0.   0.   0.   0.   0.   0.   7.   8.   0.
    0.   0.   0.   0.   1.  13.   6.   2.   2.   0.   0.   0.   7.  15.
    0.   9.   8.   0.   0.   5.  16.  10.   0.  16.   6.   0.   0.   4.
   15.  16.  13.  16.   1.   0.   0.   0.   0.   3.  15.  10.   0.   0.
    0.   0.   0.   2.  16.   4.   0.   0.]]
print "digits.target.dtype: ",digits.target.dtype
print "digits.target.shape: ",digits.target.shape
print "digits.target.ndim: ",digits.target.ndim
print "--------------------------------"
print digits.target[0:5]
digits.target.dtype:  int64
digits.target.shape:  (1797,)
digits.target.ndim:  1
--------------------------------
[0 1 2 3 4]

digits是手写字数据集,可以通过images选择加载8*8的矩阵图片

digits.images[1]
array([[  0.,   0.,   0.,  12.,  13.,   5.,   0.,   0.],
       [  0.,   0.,   0.,  11.,  16.,   9.,   0.,   0.],
       [  0.,   0.,   3.,  15.,  16.,   6.,   0.,   0.],
       [  0.,   7.,  15.,  16.,  16.,   2.,   0.,   0.],
       [  0.,   0.,   1.,  16.,  16.,   3.,   0.,   0.],
       [  0.,   0.,   1.,  16.,  16.,   6.,   0.,   0.],
       [  0.,   0.,   1.,  16.,  16.,   6.,   0.,   0.],
       [  0.,   0.,   0.,  11.,  16.,  10.,   0.,   0.]])

学习和预测

在scikit-learn里面,一个分类模型有两个主要的方法:fit(X,y)和predict(T)

这里我们用svm做例子,看怎么使用。

from sklearn import svm
clf = svm.SVC(gamma=0.001,C=100.)

选择模型的参数在我们这个例子里面,我们使用手工设置参数,此外还可以使用 网格搜索(grid search) 交叉验证(cross validation) 来选择参数.

现在我们的模型就是 clf。它是一个分类器。现在让模型可以进行分类任务,先要让模型学习。这里就是把训练数据集放到fit函数里,这么把digits数据集最后一个记录当作test dataset,前面1796个样本当作training dataset

clf.fit(digits.data[:-1],digits.target[:-1])
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

现在用学习好的模型预测最后一个样本的标签

print "prediction: ", clf.predict(digits.data[-1:])
print "actual: ",digits.target[-1:]
prediction:  [8]
actual:  [8]

保存模型

通过pickle来保存模型

from sklearn import svm
from sklearn import datasets
clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data,iris.target
clf.fit(X,y)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

保存上面的模型

import pickle
s = pickle.dumps(clf)

读取保存的模型

clf2 = pickle.loads(s)
print "prediction: ",clf2.predict(X[0:1])
print "actual: ",y[0:1]
prediction:  [0]
actual:  [0]

此外,可以使用joblib代替pickle(joblib.dump & joblib.load)。joblib对大的数据很有效,但是只能保存的硬盘,而不是一个string对象里。

用joblib保存模型

from sklearn.externals import joblib
joblib.dump(clf,"filename.pkl")
['filename.pkl',
 'filename.pkl_01.npy',
 'filename.pkl_02.npy',
 'filename.pkl_03.npy',
 'filename.pkl_04.npy',
 'filename.pkl_05.npy',
 'filename.pkl_06.npy',
 'filename.pkl_07.npy',
 'filename.pkl_08.npy',
 'filename.pkl_09.npy',
 'filename.pkl_10.npy',
 'filename.pkl_11.npy']

读取joblib保存的模型

clf3 = joblib.load("filename.pkl")
print "prediction: ",clf3.predict(X[0:1])
print "actual: ",y[0:1]
prediction:  [0]
actual:  [0]

注意:

joblib返回一系列的文件名,是因为模型里面的每一个numpy矩阵都保存在独立的文件里,并且要在相同的路径下面,再次读取的时候才能成功。

协议

sklearn 有如下几点规则,保证其能正常工作。

类型转换

除非特别指定,否则都会自动转换到 float64

import numpy as np
from sklearn import random_projection

rng = np.random.RandomState(0)
X = rng.rand(10,2000)
X = np.array(X,dtype='float32')
X.dtype
dtype('float32')
transformer = random_projection.GaussianRandomProjection()
X_new = transformer.fit_transform(X)
X_new.dtype
dtype('float64')

X本来是float32类型,通过fit_transform(X)转换到float64

回归的结果被转换成float64,分类的数据类型不变。

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
clf = SVC()
# 回归
clf.fit(iris.data, iris.target)  
print u"回归结果:",list(clf.predict(iris.data[:3]))
# 分类
clf.fit(iris.data, iris.target_names[iris.target]) 
print u"分类结果:",list(clf.predict(iris.data[:3]))
回归结果: [0, 0, 0]
分类结果: ['setosa', 'setosa', 'setosa']

回归用的是iris.target,分类用的是iris.target_names

重新训练和更新超参数

模型的超参数在模型训练完成以后仍然可以更新,通过 sklearn.pipeline.Pipeline.set_params 方法。多次调用fit会覆盖前面训练的模型。

import numpy as np
from sklearn.svm import SVC

rng = np.random.RandomState(0)
X = rng.rand(100, 10)
y = rng.binomial(1, 0.5, 100)
X_test = rng.rand(5, 10)
clf = SVC()
clf.set_params(kernel="linear").fit(X,y)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='linear',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
clf.predict(X_test)
array([1, 0, 1, 1, 0])
clf.set_params(kernel='rbf').fit(X, y)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
clf.predict(X_test)
array([0, 0, 0, 1, 0])

本文永久更新链接地址 http://www.linuxidc.com/Linux/2017-06/144940.htm


以上所述就是小编给大家介绍的《sklearn 快速入门》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Ajax Design Patterns

Ajax Design Patterns

Michael Mahemoff / O'Reilly Media / 2006-06-29 / USD 44.99

Ajax, or Asynchronous JavaScript and XML, exploded onto the scene in the spring of 2005 and remains the hottest story among web developers. With its rich combination of technologies, Ajax provides a s......一起来看看 《Ajax Design Patterns》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具