(.env) boris@boris-All-Series:~/SVMRBF$ pip install -U scikit-learn scipy
(.env) boris@boris-All-Series:~/SVMRBF$ cat plot_rbf_parameters.py
"""
==================
RBF SVM parameters
==================
Этот пример иллюстрирует влияние параметров gamma и C ядра SVM радиальной базисной функции (RBF).
Интуитивно параметр «гамма» определяет, насколько далеко простирается влияние одного обучающего примера, при этом низкие значения означают «далеко», а высокие значения означают «близко». Параметры «гаммы» можно рассматривать как обратную величину радиуса влияния выборок, выбранных моделью в качестве опорных векторов.Параметр <<C>> компенсирует правильную классификацию обучающих примеров максимизацией запаса решающей функции. Для больших значений <<C>> будет принят меньший запас, если решающая функция лучше правильно классифицирует все обучающие точки. Более низкий «C» будет способствовать большему запасу, следовательно, более простой функции принятия решения за счет точности обучения. Другими словами, <<C>> ведет себя как параметр регуляризации в SVM.
Первый график представляет собой визуализацию решающей функции для множества значений параметров в упрощенной задаче классификации, включающей только 2 входных признака и 2 возможных целевых класса (бинарная классификация). Обратите внимание, что такой график невозможно построить для задач с большим количеством функций или целевых классов.
Второй график представляет собой тепловую карту точности перекрестной проверки классификатора в зависимости от <<C>> и ``gamma``. В этом примере мы исследуем относительно большую сетку для иллюстрации. На практике логарифмическая сетка из
:math:`10^{-3}` до :math:`10^3`. Если лучшие параметры лежат на границах сетки, ее можно расширить в этом направлении при последующем поиске. Обратите внимание, что на графике тепловой карты есть специальная цветная полоса со средним значением, близким к значениям баллов наиболее эффективных моделей, чтобы их можно было легко отличить друг от друга в мгновение ока. Поведение модели очень чувствительно к параметру «гамма». Если «гамма» слишком велика, радиус области влияния опорных векторов включает только сам опорный вектор и никакое количество регуляризация с помощью <<C>> сможет предотвратить переобучение.
Когда «гамма» очень мала, модель слишком ограничена и не может отразить сложность или «форму» данных. Область влияния любого выбранного опорного вектора будет включать в себя всю обучающую выборку. Результирующая модель будет вести себя аналогично линейной модели с набором гиперплоскостей, разделяющих центры высокой плотности любой пары двух классов. Для промежуточных значений мы можем видеть на втором графике, что хорошие модели можно найти на диагонали «C» и «gamma». Гладкие модели (более низкие значения «гаммы») можно сделать более сложными, увеличив важность классификации, каждая точка правильна (большие значения «C»), следовательно, диагональ хороших моделей. Наконец, можно также заметить, что для некоторых промежуточных значений «гаммы» мы получаем одинаково эффективные модели, когда «С» становится очень большим. Это говорит о том, что набор опорных векторов больше не меняется. Радиус ядра RBF сам по себе действует как хороший структурный регуляризатор. Дальнейшее увеличение <<C>> не помогает, вероятно, потому, что больше нет точек обучения с нарушением (внутри поля или неправильно классифицированных), или, по крайней мере, не может быть найдено лучшее решение. При равенстве баллов может иметь смысл использовать меньшую букву «C» значений, так как очень высокие значения «C» обычно увеличивают время подбора. С другой стороны, более низкие значения «C» обычно приводят к большему количеству опорных векторов, что может увеличить время прогнозирования. Таким образом, снижение значения «C» включает в себя компромисс между временем подбора и временем прогнозирования. Мы также должны отметить, что небольшие различия в баллах возникают из-за случайного разделения процедуры перекрестной проверки. Эти ложные вариации можно сгладить, увеличив число итераций CV ``n_splits`` засчет затрат времени вычислений. Увеличение количества значений шагов C_range и ``gamma_range`` увеличит разрешение тепловой карты гиперпараметров.
# Utility class to move the midpoint of a colormap to be around
# the values of interest.
import numpy as np
from matplotlib.colors import Normalize
class MidpointNormalize(Normalize):
def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
self.midpoint = midpoint
Normalize.__init__(self, vmin, vmax, clip)
def __call__(self, value, clip=None):
x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
return np.ma.masked_array(np.interp(value, x, y))
# Load and prepare data set
# dataset for grid search
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
# Dataset for decision function visualization: we only keep the first two
# features in X and sub-sample the dataset to keep only 2 classes and
# make it a binary classification problem.
X_2d = X[:, :2]
X_2d = X_2d[y > 0]
y_2d = y[y > 0]
y_2d -= 1
# It is usually a good idea to scale the data for SVM training.
# We are cheating a bit in this example in scaling all of the data,
# instead of fitting the transformation on the training set and
# just applying it on the test set.
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X = scaler.fit_transform(X)
X_2d = scaler.fit_transform(X_2d)
# Train classifiers
# For an initial search, a logarithmic grid with basis
# 10 is often helpful. Using a basis of 2, a finer
# tuning can be achieved but at a much higher cost.
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import GridSearchCV
C_range = np.logspace(-2, 10, 13)
gamma_range = np.logspace(-9, 3, 13)
param_grid = dict(gamma=gamma_range, C=C_range)
cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
grid.fit(X, y)
print(
"The best parameters are %s with a score of %0.2f"
% (grid.best_params_, grid.best_score_)
)
# Now we need to fit a classifier for all parameters in the 2d version
# (we use a smaller set of parameters here because it takes a while to train)
C_2d_range = [1e-2, 1, 1e2]
gamma_2d_range = [1e-1, 1, 1e1]
classifiers = []
for C in C_2d_range:
for gamma in gamma_2d_range:
clf = SVC(C=C, gamma=gamma)
clf.fit(X_2d, y_2d)
classifiers.append((C, gamma, clf))
# Visualization
# draw visualization of parameter effects
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 6))
xx, yy = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))
for k, (C, gamma, clf) in enumerate(classifiers):
# evaluate decision function in a grid
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# visualize decision function for these parameters
plt.subplot(len(C_2d_range), len(gamma_2d_range), k + 1)
plt.title("gamma=10^%d, C=10^%d" % (np.log10(gamma), np.log10(C)), size="medium")
# visualize parameter's effect on decision function
plt.pcolormesh(xx, yy, -Z, cmap=plt.cm.RdBu)
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y_2d, cmap=plt.cm.RdBu_r, edgecolors="k")
plt.xticks(())
plt.yticks(())
plt.axis("tight")
scores = grid.cv_results_["mean_test_score"].reshape(len(C_range), len(gamma_range))
# Draw heatmap of the validation accuracy as a function of gamma and C
# The score are encoded as colors with the hot colormap which varies from dark
# red to bright yellow. As the most interesting scores are all located in the
# 0.92 to 0.97 range we use a custom normalizer to set the mid-point to 0.92 so
# as to make it easier to visualize the small variations of score values in the
# interesting range while not brutally collapsing all the low score values to
# the same color.
plt.figure(figsize=(8, 6))
plt.subplots_adjust(left=0.2, right=0.95, bottom=0.15, top=0.95)
plt.imshow(
scores,
interpolation="nearest",
cmap=plt.cm.hot,
norm=MidpointNormalize(vmin=0.2, midpoint=0.92),
)
plt.xlabel("gamma")
plt.ylabel("C")
plt.colorbar()
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
plt.title("Validation accuracy")
plt.show()
No comments:
Post a Comment