Hoe u kunt kruisvalideren voor modelprestaties in r


In de statistiek bouwen we vaak modellen om twee redenen:

  • Begrijp de relatie tussen een of meer voorspellende variabelen en een responsvariabele.
  • Gebruik een model om toekomstige waarnemingen te voorspellen.

Kruisvalidatie is nuttig om te schatten hoe goed een model toekomstige waarnemingen kan voorspellen.

We kunnen bijvoorbeeld een meervoudig lineair regressiemodel bouwen dat leeftijd en inkomen als voorspellende variabelen en de standaardstatus als responsvariabele gebruikt. In dit geval willen we het model misschien aan een dataset aanpassen en dat model vervolgens gebruiken om, op basis van het inkomen en de leeftijd van een nieuwe aanvrager, de waarschijnlijkheid te voorspellen dat hij of zij zijn lening niet zal kunnen afbetalen.

Om te bepalen of het model een sterk voorspellend vermogen heeft, moeten we het gebruiken om voorspellingen te doen op basis van gegevens die het nog nooit eerder heeft gezien. Hierdoor kunnen we de voorspellingsfout van het model schatten.

Kruisvalidatie gebruiken om de voorspellingsfout te schatten

Kruisvalidatie verwijst naar verschillende manieren waarop we de voorspellingsfout kunnen schatten. De algemene benadering van kruisvalidatie is:

1. Zet een bepaald aantal waarnemingen opzij in de dataset – doorgaans 15-25% van alle waarnemingen.
2. Pas het model aan (of ‘train’) op basis van de waarnemingen die we in de dataset bewaren.
3. Test hoe goed het model voorspellingen kan doen over waarnemingen die we niet hebben gebruikt om het model te trainen.

Het meten van de kwaliteit van een model

Wanneer we het aangepaste model gebruiken om voorspellingen te doen over nieuwe waarnemingen, kunnen we verschillende metrieken gebruiken om de kwaliteit van het model te meten, waaronder:

Meerdere R-kwadraat: Dit meet de sterkte van de lineaire relatie tussen de voorspellende variabelen en de responsvariabele. Een R-kwadraat veelvoud van 1 geeft een perfect lineair verband aan, terwijl een R-kwadraat veelvoud van 0 geen lineair verband aangeeft. Hoe hoger het R-kwadraat veelvoud, hoe waarschijnlijker het is dat de voorspellende variabelen de responsvariabele voorspellen.

Root Mean Square Error (RMSE): meet de gemiddelde voorspellingsfout die door het model wordt gemaakt bij het voorspellen van de waarde van een nieuwe waarneming. Dit is de gemiddelde afstand tussen de werkelijke waarde van een waarneming en de door het model voorspelde waarde. Lagere waarden voor RMSE duiden op een betere modelfit.

Mean Absolute Error (MAE): Dit is het gemiddelde absolute verschil tussen de werkelijke waarde van een waarneming en de door het model voorspelde waarde. Deze statistiek is over het algemeen minder gevoelig voor uitschieters dan RMSE. Lagere waarden voor MAE duiden op een betere modelfit.

Implementatie van vier verschillende kruisvalidatietechnieken in R

Vervolgens leggen we uit hoe u de volgende kruisvalidatietechnieken in R kunt implementeren:

1. Validatieset-aanpak
2. k-voudige kruisvalidatie
3. Laat kruisvalidatie buiten beschouwing
4. Herhaalde k-voudige kruisvalidatie

Om te illustreren hoe deze verschillende technieken kunnen worden gebruikt, zullen we een subset van de ingebouwde R-dataset van mtcars gebruiken:

 #define dataset
data <- mtcars[, c("mpg", "disp", "hp", "drat")]

#view first six rows of new data
head(data)

# mpg disp hp drat
#Mazda RX4 21.0 160 110 3.90
#Mazda RX4 Wag 21.0 160 110 3.90
#Datsun 710 22.8 108 93 3.85
#Hornet 4 Drive 21.4 258 110 3.08
#Hornet Sportabout 18.7 360 175 3.15
#Valiant 18.1 225 105 2.76

We zullen een meervoudig lineair regressiemodel bouwen met disp , hp en drat als voorspellende variabelen en mpg als responsvariabele.

Validatieset-aanpak

De validatiesetbenadering werkt als volgt:

1. Verdeel de gegevens in twee sets: de ene set wordt gebruikt om het model te trainen (dwz de modelparameters te schatten) en de andere set wordt gebruikt om het model te testen. Over het algemeen wordt de trainingsset gegenereerd door willekeurig 70-80% van de gegevens te selecteren, en de resterende 20-30% van de gegevens wordt gebruikt als testset.

2. Maak het model met behulp van de trainingsgegevensset.
3. Gebruik het model om voorspellingen te doen over de testsetgegevens.
4. Meet de modelkwaliteit met behulp van statistieken zoals R-kwadraat, RMSE en MAE.

Voorbeeld:

In het volgende voorbeeld wordt de gegevensset gebruikt die we hierboven hebben gedefinieerd. Eerst verdelen we de gegevens in
een trainingsset en een testset, waarbij 80% van de gegevens als trainingsset wordt gebruikt en de overige 20% van de gegevens als testset. Vervolgens bouwen we het model met behulp van de trainingsset. Vervolgens gebruiken we het model om voorspellingen te doen over de testset. Ten slotte meten we de kwaliteit van het model met behulp van R-kwadraat, RMSE en MAE.

 #load dplyr library used for data manipulation
library(dplyr)

#load caret library used for partitioning data into training and test set
library(caret)

#make this example reproducible
set.seed(0)

#define the dataset
data <- mtcars[, c("mpg", "disp", "hp", "drat")]

#split the dataset into a training set (80%) and test set (20%).
training_obs <- data$mpg %>% createDataPartition(p = 0.8, list = FALSE)

train <- data[training_obs, ]
test <- data[-training_obs, ]

# Build the linear regression model on the training set
model <- lm(mpg ~ ., data = train)

# Use the model to make predictions on the test set
predictions <- model %>% predict(test)

#Examine R-squared, RMSE, and MAE of predictions
data.frame(R_squared = R2(predictions, test$mpg),
           RMSE = RMSE(predictions, test$mpg),
           MAE = MAE(predictions, test$mpg))

#R_squared RMSE MAE
#1 0.9213066 1.876038 1.66614

Bij het vergelijken van verschillende modellen heeft het model met de laagste RMSE op de testset de voorkeur.

Voor- en nadelen van deze aanpak

Het voordeel van de validatiesetbenadering is dat deze eenvoudig en computationeel efficiënt is. Het nadeel is dat het model wordt gebouwd met slechts een deel van de totale gegevens. Als de data die we buiten de trainingsset laten, interessante of waardevolle informatie bevat, zal het model daar geen rekening mee houden.

k-voudige kruisvalidatiebenadering

De k-voudige kruisvalidatiebenadering werkt als volgt:

1. Verdeel de gegevens willekeurig in k “vouwen” of subsets (bijvoorbeeld 5 of 10 subsets).
2. Train het model op alle gegevens en laat slechts één subset weg.
3. Gebruik het model om voorspellingen te doen over de gegevens uit de weggelaten subset.
4. Herhaal dit proces totdat elk van de k-subsets als testset is gebruikt.
5 . Meet de kwaliteit van het model door de k-testfouten te middelen. Dit is bekend
als een kruisvalidatiefout.

Voorbeeld

In dit voorbeeld verdelen we de gegevens eerst in 5 subsets. Vervolgens passen we het model aan met behulp van alle gegevens, op een subset na. Vervolgens gebruiken we het model om voorspellingen te doen over de deelverzameling die is weggelaten en registreren we de testfout (met behulp van R-kwadraat, RMSE en MAE). We herhalen dit proces totdat elke subset als testset is gebruikt. Vervolgens berekenen we eenvoudig het gemiddelde van de 5 testfouten.

 #load dplyr library used for data manipulation
library(dplyr)

#load caret library used for partitioning data into training and test set
library(caret)

#make this example reproducible
set.seed(0)

#define the dataset
data <- mtcars[, c("mpg", "disp", "hp", "drat")]

#define the number of subsets (or "folds") to use
train_control <- trainControl(method = "cv", number = 5)

#train the model
model <- train(mpg ~ ., data = data, method = "lm", trControl = train_control)

#Summarize the results
print(model)

#Linear Regression 
#
#32 samples
#3 predictor
#
#No pre-processing
#Resampling: Cross-Validated (5 fold) 
#Summary of sample sizes: 26, 25, 26, 25, 26 
#Resampling results:
#
# RMSE Rsquared MAE     
#3.095501 0.7661981 2.467427
#
#Tuning parameter 'intercept' was held constant at a value of TRUE

Voor- en nadelen van deze aanpak

Het voordeel van de k-voudige kruisvalidatiebenadering ten opzichte van de validatiesetbenadering is dat het model verschillende keren wordt opgebouwd met elke keer verschillende stukjes gegevens, zodat we niet het risico lopen belangrijke gegevens weg te laten bij het bouwen van het model.

Het subjectieve deel van deze benadering is het kiezen van de waarde die voor k moet worden gebruikt, dat wil zeggen het aantal subsets waarin de gegevens moeten worden verdeeld. Over het algemeen leiden lagere k-waarden tot een grotere bias maar een lagere variabiliteit, terwijl hogere k-waarden leiden tot een lagere bias maar een grotere variabiliteit.

In de praktijk wordt k doorgaans gelijk aan 5 of 10 gekozen, omdat dit aantal deelverzamelingen de neiging heeft tegelijkertijd te veel vertekening en te veel variabiliteit te vermijden.

Leave One Out Cross-Validation (LOOCV)-aanpak

De LOOCV-aanpak werkt als volgt:

1. Bouw het model met behulp van op één na alle observaties in de dataset.
2. Gebruik het model om de waarde van de ontbrekende waarneming te voorspellen. Noteer de fout bij het testen van deze voorspelling.
3. Herhaal dit proces voor elke waarneming in de dataset.
4. Meet de kwaliteit van het model door alle voorspellingsfouten te middelen.

Voorbeeld

In het volgende voorbeeld ziet u hoe u LOOCV kunt uitvoeren voor dezelfde gegevensset als in de voorgaande voorbeelden:

 #load dplyr library used for data manipulation
library(dplyr)

#load caret library used for partitioning data into training and test set
library(caret)

#make this example reproducible
set.seed(0)

#define the dataset
data <- mtcars[, c("mpg", "disp", "hp", "drat")]

#specify that we want to use LOOCV
train_control <- trainControl( method = "LOOCV" )

#train the model
model <- train(mpg ~ ., data = data, method = "lm", trControl = train_control)

#summarize the results
print(model)

#Linear Regression 
#
#32 samples
#3 predictor
#
#No pre-processing
#Resampling: Leave-One-Out Cross-Validation 
#Summary of sample sizes: 31, 31, 31, 31, 31, 31, ... 
#Resampling results:
#
# RMSE Rsquared MAE     
#3.168763 0.7170704 2.503544
#
#Tuning parameter 'intercept' was held constant at a value of TRUE

Voor- en nadelen van deze aanpak

Het voordeel van LOOCV is dat we alle datapunten gebruiken, waardoor mogelijke vertekeningen over het algemeen worden verminderd. Omdat we het model echter gebruiken om de waarde van elke waarneming te voorspellen, zou dit kunnen leiden tot een grotere variabiliteit in de voorspellingsfout.

Een ander nadeel van deze aanpak is dat deze in een zo groot aantal modellen moet passen dat deze inefficiënt en rekenintensief kan worden.

Herhaalde k-voudige kruisvalidatiebenadering

We kunnen herhaalde k-voudige kruisvalidatie uitvoeren door simpelweg meerdere keren k-voudige kruisvalidatie uit te voeren. De uiteindelijke fout is de gemiddelde fout van het aantal herhalingen.

In het volgende voorbeeld wordt een vijfvoudige kruisvalidatie uitgevoerd, vier keer herhaald:

 #load dplyr library used for data manipulation
library(dplyr)

#load caret library used for partitioning data into training and test set
library(caret)

#make this example reproducible
set.seed(0)

#define the dataset
data <- mtcars[, c("mpg", "disp", "hp", "drat")]

#define the number of subsets to use and number of times to repeat k-fold CV
train_control <- trainControl(method = "repeatedcv", number = 5, repeats = 4 )

#train the model
model <- train(mpg ~ ., data = data, method = "lm", trControl = train_control)

#summarize the results
print(model)

#Linear Regression 
#
#32 samples
#3 predictor
#
#No pre-processing
#Resampling: Cross-Validated (5 fold, repeated 4 times) 
#Summary of sample sizes: 26, 25, 26, 25, 26, 25, ... 
#Resampling results:
#
# RMSE Rsquared MAE     
#3.176339 0.7909337 2.559131
#
#Tuning parameter 'intercept' was held constant at a value of TRUE

Voor- en nadelen van deze aanpak

Het voordeel van de herhaalde k-voudige kruisvalidatiebenadering is dat voor elke herhaling de gegevens worden opgesplitst in enigszins verschillende subsets, wat een nog onbevooroordeelde schatting van de voorspellingsfout van het model zou moeten opleveren. Het nadeel van deze aanpak is dat deze rekenintensief kan zijn, omdat we het modelaanpassingsproces verschillende keren moeten herhalen.

Hoe u het aantal vouwen bij kruisvalidatie kiest

Het meest subjectieve onderdeel van kruisvalidatie is het bepalen hoeveel vouwen (dwz subsets) er gebruikt moeten worden. Over het algemeen geldt dat hoe kleiner het aantal vouwen is, des te vertekender de foutschattingen zijn, maar des te minder variabel deze zullen zijn. Omgekeerd geldt: hoe hoger het aantal vouwen, hoe minder vertekend de foutschattingen, maar hoe variabeler ze zullen zijn.

Ook is het belangrijk om rekening te houden met de rekentijd. Voor elke vouw moet je een nieuw patroon trainen, en hoewel dit een langzaam proces is, kan het lang duren als je een groot aantal vouwen kiest.

In de praktijk wordt kruisvalidatie meestal uitgevoerd met 5 of 10 vouwen, omdat dit een goed evenwicht biedt tussen variabiliteit en vertekening, terwijl het ook computationeel efficiënt is.

Hoe u een model kiest na het uitvoeren van kruisvalidatie

Kruisvalidatie wordt gebruikt om de voorspellingsfout van een model te evalueren. Dit kan ons helpen kiezen tussen twee of meer verschillende modellen door te benadrukken welk model de laagste voorspellingsfout heeft (gebaseerd op RMSE, R-kwadraat, enz.).

Nadat we kruisvalidatie hebben gebruikt om het beste model te selecteren, gebruiken we alle beschikbare gegevens om bij het gekozen model te passen. We gebruiken niet de daadwerkelijke modelinstanties die we hebben getraind tijdens de kruisvalidatie voor ons uiteindelijke model.

We kunnen bijvoorbeeld 5-voudige kruisvalidatie gebruiken om te bepalen welk model beter kan worden gebruikt tussen twee verschillende regressiemodellen. Zodra we echter hebben vastgesteld welk model het beste kan worden gebruikt, gebruiken we alle gegevens om in het uiteindelijke model te passen. Met andere woorden: we vergeten geen enkele vouw bij het bouwen van het uiteindelijke model.

Einen Kommentar hinzufügen

Deine E-Mail-Adresse wird nicht veröffentlicht. Erforderliche Felder sind mit * markiert