Перехресна перевірка k-fold у python (крок за кроком)


Щоб оцінити продуктивність моделі на наборі даних, нам потрібно виміряти, наскільки прогнози, зроблені моделлю, відповідають даним спостереження.

Зазвичай використовуваний метод для цього відомий як k-кратна перехресна перевірка , яка використовує такий підхід:

1. Випадково розділіть набір даних на k груп, або «згорток», приблизно однакового розміру.

2. Виберіть одну зі складок як обмежувальний комплект. Підігніть шаблон до решти k-1 складок. Розрахуйте випробування MSE на основі спостережень у шарі, який був натягнутий.

3. Повторіть цей процес k разів, кожного разу використовуючи інший набір як набір виключень.

4. Обчисліть загальну тестову MSE як середнє k тестових MSE.

Цей підручник надає покроковий приклад того, як виконати k-кратну перехресну перевірку для даної моделі в Python.

Крок 1: Завантажте необхідні бібліотеки

Спочатку ми завантажимо функції та бібліотеки, необхідні для цього прикладу:

 from sklearn. model_selection import train_test_split
from sklearn. model_selection import KFold
from sklearn. model_selection import cross_val_score
from sklearn. linear_model import LinearRegression
from numpy import means
from numpy import absolute
from numpy import sqrt
import pandas as pd

Крок 2: Створіть дані

Далі ми створимо pandas DataFrame, який містить дві змінні предиктора, x1 і x2 , і одну змінну відповіді y.

 df = pd.DataFrame({' y ': [6, 8, 12, 14, 14, 15, 17, 22, 24, 23],
                   ' x1 ': [2, 5, 4, 3, 4, 6, 7, 5, 8, 9],
                   ' x2 ': [14, 12, 12, 13, 7, 8, 7, 4, 6, 5]})

Крок 3: Виконайте перехресну перевірку K-згортки

Далі ми підберемо модель множинної лінійної регресії до набору даних і виконаємо LOOCV, щоб оцінити продуктивність моделі.

 #define predictor and response variables
X = df[[' x1 ', ' x2 ']]
y = df[' y ']

#define cross-validation method to use
cv = KFold ( n_splits = 10 , random_state = 1 , shuffle = True )

#build multiple linear regression model
model = LinearRegression()

#use k-fold CV to evaluate model
scores = cross_val_score(model, X, y, scoring=' neg_mean_absolute_error ',
                         cv=cv, n_jobs=-1)

#view mean absolute error
mean(absolute(scores))

3.6141267491803646

З результату ми бачимо, що середня абсолютна похибка (MAE) становила 3,614 . Тобто середня абсолютна похибка між прогнозом моделі та фактично спостережуваними даними становить 3,614.

Загалом, чим нижче MAE, тим краще модель здатна передбачити фактичні спостереження.

Іншим часто використовуваним показником для оцінки продуктивності моделі є середньоквадратична помилка (RMSE). Наступний код показує, як обчислити цей показник за допомогою LOOCV:

 #define predictor and response variables
X = df[[' x1 ', ' x2 ']]
y = df[' y ']

#define cross-validation method to use
cv = KFold ( n_splits = 5 , random_state = 1 , shuffle = True ) 

#build multiple linear regression model
model = LinearRegression()

#use LOOCV to evaluate model
scores = cross_val_score(model, X, y, scoring=' neg_mean_squared_error ',
                         cv=cv, n_jobs=-1)

#view RMSE
sqrt(mean(absolute(scores)))

4.284373111711816

З результату ми бачимо, що середня квадратична помилка (RMSE) становила 4,284 .

Чим нижче RMSE, тим краще модель здатна передбачити фактичні спостереження.

На практиці ми зазвичай підбираємо кілька різних моделей і порівнюємо RMSE або MAE кожної моделі, щоб вирішити, яка модель дає найнижчий рівень помилок тестування і, отже, є найкращою для використання.

Також зауважте, що в цьому прикладі ми вирішили використовувати k=5 згорток, але ви можете вибрати будь-яку кількість згорток.

На практиці ми зазвичай обираємо від 5 до 10 шарів, оскільки це оптимальна кількість шарів, яка забезпечує надійний рівень помилок тесту.

Ви можете знайти повну документацію для функції KFold() sklearn тут .

Додаткові ресурси

Вступ до K-кратної перехресної перевірки
Повний посібник із лінійної регресії в Python
Перехресна перевірка Leave-One-Out у Python

Додати коментар

Ваша e-mail адреса не оприлюднюватиметься. Обов’язкові поля позначені *