K-voudige kruisvalidatie in r (stap voor stap)
Om de prestaties van een model op een dataset te evalueren, moeten we meten hoe goed de voorspellingen van het model overeenkomen met de waargenomen gegevens.
Een veelgebruikte methode om dit te doen staat bekend als k-fold cross-validatie , waarbij de volgende aanpak wordt gebruikt:
1. Verdeel een dataset willekeurig in k groepen, of ‘vouwen’, van ongeveer gelijke grootte.
2. Kies een van de vouwen als bevestigingsset. Pas de sjabloon aan de resterende k-1-vouwen aan. Bereken de MSE-proef op de waarnemingen in de gespannen lamel.
3. Herhaal dit proces k keer, telkens met een andere set als uitsluitingsset.
4. Bereken de totale test-MSE als het gemiddelde van de k- test-MSE’s.
De eenvoudigste manier om k-voudige kruisvalidatie uit te voeren in R is door de functie trainControl() uit de caret- bibliotheek in R te gebruiken.
Deze tutorial geeft een snel voorbeeld van hoe u deze functie kunt gebruiken om k-voudige kruisvalidatie uit te voeren voor een bepaald model in R.
Voorbeeld: K-voudige kruisvalidatie in R
Stel dat we de volgende dataset in R hebben:
#create data frame df <- data.frame(y=c(6, 8, 12, 14, 14, 15, 17, 22, 24, 23), x1=c(2, 5, 4, 3, 4, 6, 7, 5, 8, 9), x2=c(14, 12, 12, 13, 7, 8, 7, 4, 6, 5)) #view data frame df y x1 x2 6 2 14 8 5 12 12 4 12 14 3 13 14 4 7 15 6 8 17 7 7 22 5 4 24 8 6 23 9 5
De volgende code laat zien hoe u een meervoudig lineair regressiemodel aan deze gegevensset in R kunt aanpassen en k-voudige kruisvalidatie met k = 5 keer kunt uitvoeren om de prestaties van het model te evalueren:
library (caret) #specify the cross-validation method ctrl <- trainControl(method = " cv ", number = 5) #fit a regression model and use k-fold CV to evaluate performance model <- train(y ~ x1 + x2, data = df, method = " lm ", trControl = ctrl) #view summary of k-fold CV print(model) Linear Regression 10 samples 2 predictors No pre-processing Resampling: Cross-Validated (5 fold) Summary of sample sizes: 8, 8, 8, 8, 8 Resampling results: RMSE Rsquared MAE 3.018979 1 2.882348 Tuning parameter 'intercept' was held constant at a value of TRUE
Zo interpreteert u het resultaat:
- Er heeft geen voorbewerking plaatsgevonden. Dat wil zeggen dat we de gegevens op geen enkele manier hebben geschaald voordat we de modellen hebben aangepast.
- De resamplingmethode die we gebruikten om het model te evalueren, was vijfvoudige kruisvalidatie.
- De steekproefomvang voor elke trainingsset was 8.
- RMSE: wortelgemiddelde kwadratische fout. Dit meet het gemiddelde verschil tussen de voorspellingen van het model en de daadwerkelijke waarnemingen. Hoe lager de RMSE, hoe nauwkeuriger een model daadwerkelijke waarnemingen kan voorspellen.
- Rsquared: Dit is een maatstaf voor de correlatie tussen voorspellingen van het model en feitelijke waarnemingen. Hoe hoger het R-kwadraat, hoe nauwkeuriger een model feitelijke waarnemingen kan voorspellen.
- MAE: De gemiddelde absolute fout. Dit is het gemiddelde absolute verschil tussen de voorspellingen van het model en de feitelijke waarnemingen. Hoe lager de MAE, hoe nauwkeuriger een model daadwerkelijke waarnemingen kan voorspellen.
Elk van de drie metingen in het resultaat (RMSE, R-kwadraat en MAE) geeft ons een idee van de prestaties van het model op basis van niet-gepubliceerde gegevens.
In de praktijk passen we doorgaans verschillende modellen toe en vergelijken we de drie meetgegevens die de hier gepresenteerde resultaten opleveren om te beslissen welk model de laagste testfoutenpercentages oplevert en daarom het beste model is om te gebruiken.
We kunnen de volgende code gebruiken om de uiteindelijke pasvorm van het model te onderzoeken:
#view final model
model$finalModel
Call:
lm(formula = .outcome ~ ., data = dat)
Coefficients:
(Intercept) x1 x2
21.2672 0.7803 -1.1253
Het uiteindelijke model blijkt te zijn:
y = 21,2672 + 0,7803*(x 1 ) – 1,12538(x 2 )
We kunnen de volgende code gebruiken om de modelvoorspellingen voor elke vouw weer te geven:
#view predictions for each fold
model$resample
RMSE Rsquared MAE Resample
1 4.808773 1 3.544494 Fold1
2 3.464675 1 3.366812 Fold2
3 6.281255 1 6.280702 Fold3
4 3.759222 1 3.573883 Fold4
5 1.741127 1 1.679767 Fold5
Merk op dat we in dit voorbeeld ervoor kiezen om k=5 vouwen te gebruiken, maar u kunt elk gewenst aantal vouwen kiezen. In de praktijk kiezen we doorgaans tussen 5 en 10 lagen, omdat dit het optimale aantal lagen blijkt te zijn dat betrouwbare testfoutpercentages oplevert.