This post walks through scikit-learn's generalized linear model examples: ordinary least squares and ridge regression.
# -*- coding: utf-8 -*-
"""
2018.07.02
by pengxw
"""
# Ordinary least squares example, from:
# http://scikit-learn.org/stable/a ... r-model-plot-ols-py
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score
# Load the diabetes dataset
diabetes = datasets.load_diabetes()
# Use only one feature.
# np.newaxis keeps the sliced feature as a 2-D column vector instead of a flat 1-D array
diabetes_X = diabetes.data[:, np.newaxis, 2]
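# (A quick shape check, not part of the original example, to illustrate the point:
# the plain slice is 1-D, while the np.newaxis slice is the 2-D column that
# scikit-learn estimators expect as X.)
print(diabetes.data[:, 2].shape)   # (442,)
print(diabetes_X.shape)            # (442, 1)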
# Split the data into training/testing sets
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]
# Split the targets into training/testing sets
diabetes_y_train = diabetes.target[:-20]
diabetes_y_test = diabetes.target[-20:]
# Create linear regression object
regr = linear_model.LinearRegression()
# Train the model using the training sets
regr.fit(diabetes_X_train, diabetes_y_train)
# Make predictions using the testing set
diabetes_y_pred = regr.predict(diabetes_X_test)
# The coefficients
print('Coefficients: \n', regr.coef_)
# The mean squared error
print("Mean squared error: %.2f"
% mean_squared_error(diabetes_y_test, diabetes_y_pred))
# Coefficient of determination (R^2): 1 is perfect prediction
print('R^2 score: %.2f' % r2_score(diabetes_y_test, diabetes_y_pred))
# Plot outputs
plt.scatter(diabetes_X_test, diabetes_y_test, color='black')
plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)
# Hide the axis ticks (left commented out here)
#plt.xticks(())
#plt.yticks(())
plt.show()
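The manual slicing above keeps the last 20 samples for testing. As a side note (not part of the original example), the same kind of split can be done with sklearn.model_selection.train_test_split; here is a minimal sketch, with an arbitrary random_state, holding out 20 random samples instead:

from sklearn.model_selection import train_test_split

# Randomly hold out 20 samples instead of taking the last 20
X_train, X_test, y_train, y_test = train_test_split(
    diabetes_X, diabetes.target, test_size=20, random_state=0)
regr2 = linear_model.LinearRegression()
regr2.fit(X_train, y_train)
print('R^2 on the random hold-out: %.2f' % regr2.score(X_test, y_test))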
# Ridge regression example
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model
# x is the 10x10 Hilbert matrix: broadcasting a (10,) row vector against a
# (10, 1) column vector (via np.newaxis) yields the matrix with entries 1/(i + j + 1)
x = 1. / (np.arange(1, 11) + np.arange(0, 10)[:, np.newaxis])
# Target vector: an array of ten ones
y = np.ones(10)
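# (Optional sanity check, not in the original post; assumes SciPy is installed:
# the broadcasting expression above reproduces SciPy's Hilbert matrix.)
from scipy.linalg import hilbert
assert np.allclose(x, hilbert(10))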
# Number of regularization strengths to try
n_alphas = 200
# Regularization strengths, logarithmically spaced from 10^(-10) to 10^2
alphas = np.logspace(-10, 2, n_alphas, base=10)
coefs = []
for a in alphas:
    ridge = linear_model.Ridge(alpha=a, fit_intercept=False)
    ridge.fit(x, y)
    coefs.append(ridge.coef_)
# display results
ax = plt.gca()
ax.plot(alphas, coefs)
ax.set_xscale('log')  # log-scale the x axis
ax.set_xlim(ax.get_xlim()[::-1]) # reverse axis
plt.xlabel('alpha')
plt.ylabel('weights')
plt.title('Ridge coefficients as a function of the regularization')
plt.axis('tight')  # tighten the axis limits so all data points are shown
plt.show()
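The plot shows the coefficients shrinking toward zero as alpha grows. To actually pick an alpha rather than eyeballing the paths, one option is RidgeCV over the same grid; the sketch below is not part of the original post and relies on RidgeCV's default leave-one-out cross-validation:

from sklearn.linear_model import RidgeCV

# Search the same alpha grid and keep the best-scoring value
ridge_cv = RidgeCV(alphas=alphas, fit_intercept=False)
ridge_cv.fit(x, y)
print('Best alpha: %g' % ridge_cv.alpha_)
print('Coefficients:', ridge_cv.coef_)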