matplotlibで判別境界を描画する

ちょっとPython+matplotlibで判別境界を描画したくなって、昔やってたやり方を思い出した。
他の方法は調べていない。
次のコードは、線形判別分析と2次判別分析とSVMの判別境界を並べて描画したものである。

●コード
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from sklearn.svm import SVC

def draw_contour(clf, x1_range, x2_range, ax=None):
    cmap = ListedColormap(['k'])
    if ax is None: ax = plt.gca()

    margin1 = (x1_range[1] - x1_range[0]) * 0.05  # 5% of original range
    margin2 = (x2_range[1] - x2_range[0]) * 0.05
    xx1, xx2 = np.meshgrid(np.arange(x1_range[0] - margin1, x1_range[1] + margin1, 0.01),
                           np.arange(x2_range[0] - margin2, x2_range[1] + margin2, 0.01))

    Z = clf.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)
    ax.contour(xx1, xx2, Z, cmap=cmap)

def draw_meshgrid(clf, x1_range, x2_range, ax=None):
    cmap = ListedColormap(['lightgreen', 'lightpink'])
    if ax is None: ax = plt.gca()

    margin1 = (x1_range[1] - x1_range[0]) * 0.05
    margin2 = (x2_range[1] - x2_range[0]) * 0.05
    xx1, xx2 = np.meshgrid(np.arange(x1_range[0] - margin1, x1_range[1] + margin1, 0.01),
                           np.arange(x2_range[0] - margin2, x2_range[1] + margin2, 0.01))

    Z = clf.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)
    ax.contourf(xx1, xx2, Z, cmap=cmap)

def draw_sample(X, y, ax=None):
    if ax is None: ax = plt.gca()
    ax.scatter(X[y == 0, 0], X[y == 0, 1], color='b', marker='o')
    ax.scatter(X[y == 1, 0], X[y == 1, 1], color='r', marker='x')

# sample data
np.random.seed(0)
X = np.vstack([
    np.random.randn(10, 2) + [1.0, 1.0],  # 10 points around (1.0, 1.0)
    np.random.randn(10, 2) + [3.0, 2.5]   # 10 points around (3.0, 2.5)
])
y = np.array([0] * 10 + [1] * 10)  # class labels

fig, ax = plt.subplots(2, 3, figsize=(10,6))

# 線形判別分析
clf_lda = LinearDiscriminantAnalysis()
clf_lda.fit(X, y)
draw_contour(clf_lda, (X[:,0].min(), X[:,0].max()), (X[:,1].min(), X[:,1].max()), ax=ax[0,0])
draw_sample(X, y, ax=ax[0,0])
draw_meshgrid(clf_lda, (X[:,0].min(), X[:,0].max()), (X[:,1].min(), X[:,1].max()), ax=ax[1,0])
draw_sample(X, y, ax=ax[1,0])
ax[0, 0].set_title("LDA")

# 2次判別分析
clf_qda = QuadraticDiscriminantAnalysis()
clf_qda.fit(X, y)
draw_contour(clf_qda, (X[:,0].min(), X[:,0].max()), (X[:,1].min(), X[:,1].max()), ax=ax[0,1])
draw_sample(X, y, ax=ax[0,1])
draw_meshgrid(clf_qda, (X[:,0].min(), X[:,0].max()), (X[:,1].min(), X[:,1].max()), ax=ax[1,1])
draw_sample(X, y, ax=ax[1,1])
ax[0, 1].set_title("QDA")

# SVM
clf_svm = SVC()
clf_svm.fit(X, y)
draw_contour(clf_svm, (X[:,0].min(), X[:,0].max()), (X[:,1].min(), X[:,1].max()), ax=ax[0,2])
draw_sample(X, y, ax=ax[0,2])
draw_meshgrid(clf_svm, (X[:,0].min(), X[:,0].max()), (X[:,1].min(), X[:,1].max()), ax=ax[1,2])
draw_sample(X, y, ax=ax[1,2])
ax[0, 2].set_title("SVM")

plt.show()

●結果