K-voudige kruisvalidatie in python (stap voor stap)


Om de prestaties van een model op een dataset te evalueren, moeten we meten hoe goed de voorspellingen van het model overeenkomen met de waargenomen gegevens.

Een veelgebruikte methode om dit te doen staat bekend als k-fold cross-validatie , waarbij de volgende aanpak wordt gebruikt:

1. Verdeel een dataset willekeurig in k groepen, of ‘vouwen’, van ongeveer gelijke grootte.

2. Kies een van de vouwen als bevestigingsset. Pas de sjabloon aan de resterende k-1-vouwen aan. Bereken de MSE-proef op de waarnemingen in de gespannen lamel.

3. Herhaal dit proces k keer, telkens met een andere set als uitsluitingsset.

4. Bereken de totale test-MSE als het gemiddelde van de k- test-MSE’s.

Deze zelfstudie biedt een stapsgewijs voorbeeld van hoe u k-voudige kruisvalidatie kunt uitvoeren voor een bepaald model in Python.

Stap 1: Laad de benodigde bibliotheken

Eerst laden we de functies en bibliotheken die nodig zijn voor dit voorbeeld:

 from sklearn. model_selection import train_test_split
from sklearn. model_selection import KFold
from sklearn. model_selection import cross_val_score
from sklearn. linear_model import LinearRegression
from numpy import means
from numpy import absolute
from numpy import sqrt
import pandas as pd

Stap 2: Creëer de gegevens

Vervolgens maken we een Panda DataFrame dat twee voorspellende variabelen bevat, x1 en x2 , en een enkele responsvariabele y.

 df = pd.DataFrame({' y ': [6, 8, 12, 14, 14, 15, 17, 22, 24, 23],
                   ' x1 ': [2, 5, 4, 3, 4, 6, 7, 5, 8, 9],
                   ' x2 ': [14, 12, 12, 13, 7, 8, 7, 4, 6, 5]})

Stap 3: Voer K-voudige kruisvalidatie uit

Vervolgens passen we een meervoudig lineair regressiemodel aan de dataset toe en voeren we LOOCV uit om de prestaties van het model te evalueren.

 #define predictor and response variables
X = df[[' x1 ', ' x2 ']]
y = df[' y ']

#define cross-validation method to use
cv = KFold ( n_splits = 10 , random_state = 1 , shuffle = True )

#build multiple linear regression model
model = LinearRegression()

#use k-fold CV to evaluate model
scores = cross_val_score(model, X, y, scoring=' neg_mean_absolute_error ',
                         cv=cv, n_jobs=-1)

#view mean absolute error
mean(absolute(scores))

3.6141267491803646

Uit het resultaat kunnen we zien dat de gemiddelde absolute fout (MAE) 3,614 was. Dat wil zeggen dat de gemiddelde absolute fout tussen de modelvoorspelling en de feitelijk waargenomen gegevens 3,614 bedraagt.

Over het algemeen geldt dat hoe lager de MAE, hoe beter een model feitelijke waarnemingen kan voorspellen.

Een andere veelgebruikte maatstaf om de prestaties van modellen te evalueren is de root mean square error (RMSE). De volgende code laat zien hoe u deze statistiek kunt berekenen met behulp van LOOCV:

 #define predictor and response variables
X = df[[' x1 ', ' x2 ']]
y = df[' y ']

#define cross-validation method to use
cv = KFold ( n_splits = 5 , random_state = 1 , shuffle = True ) 

#build multiple linear regression model
model = LinearRegression()

#use LOOCV to evaluate model
scores = cross_val_score(model, X, y, scoring=' neg_mean_squared_error ',
                         cv=cv, n_jobs=-1)

#view RMSE
sqrt(mean(absolute(scores)))

4.284373111711816

Uit het resultaat kunnen we zien dat de root mean square error (RMSE) 4,284 was.

Hoe lager de RMSE, hoe beter een model daadwerkelijke waarnemingen kan voorspellen.

In de praktijk passen we doorgaans verschillende modellen aan en vergelijken we de RMSE of MAE van elk model om te beslissen welk model de laagste testfoutpercentages oplevert en daarom het beste model is om te gebruiken.

Merk ook op dat we in dit voorbeeld ervoor kiezen om k=5 vouwen te gebruiken, maar u kunt elk gewenst aantal vouwen kiezen.

In de praktijk kiezen we doorgaans tussen 5 en 10 lagen, omdat dit het optimale aantal lagen blijkt te zijn dat betrouwbare testfoutpercentages oplevert.

Je kunt de volledige documentatie voor de KFold()-functie van sklearn hier vinden.

Aanvullende bronnen

Een inleiding tot K-fold kruisvalidatie
Een complete gids voor lineaire regressie in Python
Leave-One-Out kruisvalidatie in Python

Einen Kommentar hinzufügen

Deine E-Mail-Adresse wird nicht veröffentlicht. Erforderliche Felder sind mit * markiert