Odp.: jak używać traincontrol do kontrolowania parametrów treningu


Aby ocenić, jak dobrze model jest w stanie dopasować się do zbioru danych, musimy przeanalizować jego działanie na podstawie obserwacji, których nigdy wcześniej nie widział.

Jednym z najczęstszych sposobów osiągnięcia tego jest użycie k-krotnej walidacji krzyżowej , która wykorzystuje następujące podejście:

1. Losowo podziel zbiór danych na k grup, czyli „fałd”, o mniej więcej równej wielkości.

2. Wybierz jedną z zagięć jako zestaw utwierdzający. Dopasuj szablon do pozostałych zakładek k-1. Oblicz test MSE na podstawie obserwacji w naprężonej warstwie.

3. Powtórz ten proces k razy, za każdym razem używając innego zbioru jako zbioru wykluczającego.

4. Oblicz ogólny test MSE jako średnią k MSE testu.

Najłatwiejszym sposobem przeprowadzenia k-krotnej walidacji krzyżowej w R jest użycie funkcji trainControl() i train() z biblioteki caret w R.

Funkcja trainControl() służy do określania parametrów szkoleniowych (np. rodzaj stosowanej walidacji krzyżowej, liczba powtórzeń itp.), a funkcja train() służy do faktycznego dopasowania modelu do danych. .

Poniższy przykład pokazuje, jak w praktyce używać funkcji trainControl() i train() .

Przykład: Jak używać trainControl() w R

Załóżmy, że mamy następujący zbiór danych w R:

 #create data frame
df <- data.frame(y=c(6, 8, 12, 14, 14, 15, 17, 22, 24, 23),
                 x1=c(2, 5, 4, 3, 4, 6, 7, 5, 8, 9),
                 x2=c(14, 12, 12, 13, 7, 8, 7, 4, 6, 5))

#view data frame
df

y x1 x2
6 2 14
8 5 12
12 4 12
14 3 13
14 4 7
15 6 8
17 7 7
22 5 4
24 8 6
23 9 5

Załóżmy teraz, że używamy funkcji lm() , aby dopasować model regresji liniowej do tego zbioru danych, używając x1 i x2 jako zmiennych predykcyjnych oraz y jako zmiennej odpowiedzi:

 #fit multiple linear regression model to data
fit <- lm(y ~ x1 + x2, data=df)

#view model summary
summary(fit)

Call:
lm(formula = y ~ x1 + x2, data = df)

Residuals:
    Min 1Q Median 3Q Max 
-3.6650 -1.9228 -0.3684 1.2783 5.0208 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)  
(Intercept) 21.2672 6.9927 3.041 0.0188 *
x1 0.7803 0.6942 1.124 0.2981  
x2 -1.1253 0.4251 -2.647 0.0331 *
---
Significant. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 3.093 on 7 degrees of freedom
Multiple R-squared: 0.801, Adjusted R-squared: 0.7441 
F-statistic: 14.09 on 2 and 7 DF, p-value: 0.003516

Wykorzystując współczynniki z wyników modelu, możemy napisać dopasowany model regresji:

y = 21,2672 + 0,7803*(x 1 ) – 1,1253 (x 2 )

Aby zorientować się, jak dobrze ten model radzi sobie z niewidzialnymi obserwacjami , możemy zastosować k-krotną weryfikację krzyżową.

Poniższy kod pokazuje, jak używać funkcji trainControl() pakietu cart do określenia k-krotnej walidacji krzyżowej ( method=”cv” ), która wykorzystuje 5 przypadków ( number=5 ).

Następnie przekazujemy tę funkcję trainControl() do funkcji train() , aby faktycznie przeprowadzić k-krotną weryfikację krzyżową:

 library (caret)

#specify the cross-validation method
ctrl <- trainControl(method = " cv ", number = 5 )

#fit a regression model and use k-fold CV to evaluate performance
model <- train(y ~ x1 + x2, data = df, method = " lm ", trControl = ctrl)

#view summary of k-fold CV               
print (model)

Linear Regression 

10 samples
 2 predictors

No pre-processing
Resampling: Cross-Validated (5 fold) 
Summary of sample sizes: 8, 8, 8, 8, 8 
Resampling results:

  RMSE Rsquared MAE     
  3.612302 1 3.232153

Tuning parameter 'intercept' was held constant at a value of TRUE

Z wyniku widać, że model był dopasowywany 5 razy, za każdym razem na próbie liczącej 8 obserwacji.

Za każdym razem model wykorzystywano następnie do przewidywania wartości 2 zachowanych obserwacji i za każdym razem obliczano następujące metryki:

  • RMSE: średni błąd kwadratowy. Mierzy średnią różnicę między przewidywaniami dokonanymi przez model a rzeczywistymi obserwacjami. Im niższy RMSE, tym dokładniej model może przewidzieć rzeczywiste obserwacje.
  • MAE: Średni błąd bezwzględny. Jest to średnia bezwzględna różnica między przewidywaniami modelu a rzeczywistymi obserwacjami. Im niższy MAE, tym dokładniej model może przewidzieć rzeczywiste obserwacje.

W wyniku wyświetlana jest średnia wartości RMSE i MAE dla pięciu składników:

  • RMSE: 3,612302
  • MAE: 3,232153

Metryki te dają nam wyobrażenie o wydajności modelu na nowych danych.

W praktyce zazwyczaj dopasowujemy kilka różnych modeli i porównujemy te dane, aby określić, który model sprawdza się najlepiej w przypadku niewidocznych danych.

Na przykład moglibyśmy dopasować model regresji wielomianowej i przeprowadzić na nim k-krotną weryfikację krzyżową, aby zobaczyć, jak metryki RMSE i MAE wypadają w porównaniu z modelem wielokrotnej regresji liniowej.

Uwaga nr 1: W tym przykładzie zdecydowaliśmy się użyć k=5 fałd, ale możesz wybrać dowolną liczbę fałd. W praktyce zazwyczaj wybieramy pomiędzy 5 a 10 warstwami, ponieważ okazuje się, że jest to optymalna liczba warstw, która zapewnia wiarygodny poziom błędów testowych.

Uwaga nr 2 : Funkcja trainControl() akceptuje wiele potencjalnych argumentów. Pełną dokumentację tej funkcji można znaleźć tutaj .

Dodatkowe zasoby

Poniższe samouczki zawierają dodatkowe informacje na temat modeli szkoleniowych:

Wprowadzenie do walidacji krzyżowej typu K
Wprowadzenie do walidacji krzyżowej typu Leave-One-Out
Czym jest nadmierne dopasowanie w uczeniu maszynowym?

Dodaj komentarz

Twój adres e-mail nie zostanie opublikowany. Wymagane pola są oznaczone *