Leave-one-out-kreuzvalidierung in python (mit beispielen)


Um die Leistung eines Modells anhand eines Datensatzes zu bewerten, müssen wir messen, wie gut die vom Modell gemachten Vorhersagen mit den beobachteten Daten übereinstimmen.

Eine hierfür häufig verwendete Methode ist die sogenannte Leave-One-Out Cross-Validation (LOOCV) , die den folgenden Ansatz verwendet:

1. Teilen Sie einen Datensatz in einen Trainingssatz und einen Testsatz auf und verwenden Sie dabei alle Beobachtungen bis auf eine als Teil des Trainingssatzes.

2. Erstellen Sie ein Modell, das nur Daten aus dem Trainingssatz verwendet.

3. Verwenden Sie das Modell, um den Antwortwert der vom Modell ausgeschlossenen Beobachtung vorherzusagen und den mittleren quadratischen Fehler (MSE) zu berechnen.

4. Wiederholen Sie diesen Vorgang n -mal. Berechnen Sie den Test-MSE als Durchschnitt aller Test-MSE.

Dieses Tutorial bietet ein schrittweises Beispiel für die Ausführung von LOOCV für ein bestimmtes Modell in Python.

Schritt 1: Laden Sie die erforderlichen Bibliotheken

Zuerst laden wir die für dieses Beispiel benötigten Funktionen und Bibliotheken:

 from sklearn. model_selection import train_test_split
from sklearn. model_selection import LeaveOneOut
from sklearn. model_selection import cross_val_score
from sklearn. linear_model import LinearRegression
from numpy import means
from numpy import absolute
from numpy import sqrt
import pandas as pd

Schritt 2: Erstellen Sie die Daten

Als Nächstes erstellen wir einen Pandas-DataFrame, der zwei Prädiktorvariablen x1 und x2 sowie eine einzelne Antwortvariable y enthält.

 df = pd.DataFrame({' y ': [6, 8, 12, 14, 14, 15, 17, 22, 24, 23],
                   ' x1 ': [2, 5, 4, 3, 4, 6, 7, 5, 8, 9],
                   ' x2 ': [14, 12, 12, 13, 7, 8, 7, 4, 6, 5]})

Schritt 3: Führen Sie eine Leave-One-Out-Kreuzvalidierung durch

Als Nächstes passen wir ein multiples lineares Regressionsmodell an den Datensatz an und führen LOOCV durch, um die Leistung des Modells zu bewerten.

 #define predictor and response variables
X = df[[' x1 ', ' x2 ']]
y = df[' y ']

#define cross-validation method to use
cv = LeaveOneOut()

#build multiple linear regression model
model = LinearRegression()

#use LOOCV to evaluate model
scores = cross_val_score(model, X, y, scoring=' neg_mean_absolute_error ',
                         cv=cv, n_jobs=-1)

#view mean absolute error
mean(absolute(scores))

3.1461548083469726

Aus dem Ergebnis können wir ersehen, dass der mittlere absolute Fehler (MAE) 3,146 betrug. Das heißt, der durchschnittliche absolute Fehler zwischen der Modellvorhersage und den tatsächlich beobachteten Daten beträgt 3,146.

Im Allgemeinen gilt: Je niedriger der MAE, desto besser kann ein Modell tatsächliche Beobachtungen vorhersagen.

Eine weitere häufig verwendete Metrik zur Bewertung der Modellleistung ist der quadratische Mittelfehler (Root Mean Square Error, RMSE). Der folgende Code zeigt, wie diese Metrik mithilfe von LOOCV berechnet wird:

 #define predictor and response variables
X = df[[' x1 ', ' x2 ']]
y = df[' y ']

#define cross-validation method to use
cv = LeaveOneOut()

#build multiple linear regression model
model = LinearRegression()

#use LOOCV to evaluate model
scores = cross_val_score(model, X, y, scoring=' neg_mean_squared_error ',
                         cv=cv, n_jobs=-1)

#view RMSE
sqrt(mean(absolute(scores)))

3.619456476385567

Aus dem Ergebnis können wir ersehen, dass der quadratische Mittelfehler (RMSE) 3,619 betrug. Je niedriger der RMSE, desto besser kann ein Modell tatsächliche Beobachtungen vorhersagen.

In der Praxis passen wir in der Regel mehrere unterschiedliche Modelle an und vergleichen den RMSE oder MAE jedes Modells, um zu entscheiden, welches Modell die niedrigsten Testfehlerraten erzeugt und daher das beste zu verwendende Modell ist.

Zusätzliche Ressourcen

Eine kurze Einführung in die Leave-One-Out-Kreuzvalidierung (LOOCV)
Eine vollständige Anleitung zur linearen Regression in Python

Einen Kommentar hinzufügen

Deine E-Mail-Adresse wird nicht veröffentlicht. Erforderliche Felder sind mit * markiert