Walidacja krzyżowa k-fold w r (krok po kroku)
Aby ocenić wydajność modelu na zbiorze danych, musimy zmierzyć, jak dobrze przewidywania dokonane przez model odpowiadają obserwowanym danym.
Powszechnie stosowaną metodą jest k-krotna walidacja krzyżowa , która wykorzystuje następujące podejście:
1. Losowo podziel zbiór danych na k grup, czyli „fałd”, o mniej więcej równej wielkości.
2. Wybierz jedną z zagięć jako zestaw utwierdzający. Dopasuj szablon do pozostałych zakładek k-1. Oblicz test MSE na podstawie obserwacji w naprężonej warstwie.
3. Powtórz ten proces k razy, za każdym razem używając innego zbioru jako zbioru wykluczającego.
4. Oblicz ogólny test MSE jako średnią k MSE testu.
Najłatwiejszym sposobem przeprowadzenia k-krotnej walidacji krzyżowej w R jest użycie funkcji trainControl() z biblioteki caret w R.
W tym samouczku przedstawiono krótki przykład użycia tej funkcji do przeprowadzenia k-krotnej walidacji krzyżowej dla danego modelu w języku R.
Przykład: weryfikacja krzyżowa typu K w R
Załóżmy, że mamy następujący zbiór danych w R:
#create data frame df <- data.frame(y=c(6, 8, 12, 14, 14, 15, 17, 22, 24, 23), x1=c(2, 5, 4, 3, 4, 6, 7, 5, 8, 9), x2=c(14, 12, 12, 13, 7, 8, 7, 4, 6, 5)) #view data frame df y x1 x2 6 2 14 8 5 12 12 4 12 14 3 13 14 4 7 15 6 8 17 7 7 22 5 4 24 8 6 23 9 5
Poniższy kod pokazuje, jak dopasować model regresji liniowej do tego zbioru danych w R i przeprowadzić k-krotną weryfikację krzyżową z k = 5 razy, aby ocenić wydajność modelu:
library (caret) #specify the cross-validation method ctrl <- trainControl(method = " cv ", number = 5) #fit a regression model and use k-fold CV to evaluate performance model <- train(y ~ x1 + x2, data = df, method = " lm ", trControl = ctrl) #view summary of k-fold CV print(model) Linear Regression 10 samples 2 predictors No pre-processing Resampling: Cross-Validated (5 fold) Summary of sample sizes: 8, 8, 8, 8, 8 Resampling results: RMSE Rsquared MAE 3.018979 1 2.882348 Tuning parameter 'intercept' was held constant at a value of TRUE
Oto jak zinterpretować wynik:
- Nie przeprowadzono żadnego wstępnego przetwarzania. Oznacza to, że przed dopasowaniem modeli nie skalowaliśmy danych w żaden sposób.
- Metodą ponownego próbkowania, którą zastosowaliśmy do oceny modelu, była 5-krotna walidacja krzyżowa.
- Wielkość próby dla każdego zestawu treningowego wynosiła 8.
- RMSE: średni błąd kwadratowy. Mierzy średnią różnicę między przewidywaniami dokonanymi przez model a rzeczywistymi obserwacjami. Im niższy RMSE, tym dokładniej model może przewidzieć rzeczywiste obserwacje.
- Rkwadrat: Jest to miara korelacji między przewidywaniami modelu a rzeczywistymi obserwacjami. Im wyższy współczynnik R-kwadrat, tym dokładniej model może przewidzieć rzeczywiste obserwacje.
- MAE: Średni błąd bezwzględny. Jest to średnia bezwzględna różnica między przewidywaniami modelu a rzeczywistymi obserwacjami. Im niższy MAE, tym dokładniej model może przewidzieć rzeczywiste obserwacje.
Każdy z trzech pomiarów podanych w wyniku (RMSE, R-kwadrat i MAE) daje nam wyobrażenie o działaniu modelu na niepublikowanych danych.
W praktyce zazwyczaj dopasowujemy kilka różnych modeli i porównujemy trzy metryki dostarczone przez przedstawione tutaj wyniki, aby zdecydować, który model daje najniższy poziom błędów testowych i dlatego jest najlepszym modelem do użycia.
Możemy użyć następującego kodu, aby sprawdzić ostateczne dopasowanie modelu:
#view final model
model$finalModel
Call:
lm(formula = .outcome ~ ., data = dat)
Coefficients:
(Intercept) x1 x2
21.2672 0.7803 -1.1253
Ostateczny model wygląda następująco:
y = 21,2672 + 0,7803*(x 1 ) – 1,12538 (x 2 )
Możemy użyć następującego kodu, aby wyświetlić przewidywania modelu wykonane dla każdego zagięcia:
#view predictions for each fold
model$resample
RMSE Rsquared MAE Resample
1 4.808773 1 3.544494 Fold1
2 3.464675 1 3.366812 Fold2
3 6.281255 1 6.280702 Fold3
4 3.759222 1 3.573883 Fold4
5 1.741127 1 1.679767 Fold5
Zauważ, że w tym przykładzie użyliśmy k=5 fałd, ale możesz wybrać dowolną liczbę fałd. W praktyce zazwyczaj wybieramy pomiędzy 5 a 10 warstwami, ponieważ okazuje się, że jest to optymalna liczba warstw, która zapewnia wiarygodny poziom błędów testowych.