R : Comment utiliser trainControl pour contrôler les paramètres de formation



Pour évaluer dans quelle mesure un modèle est capable de s’adapter à un ensemble de données, nous devons analyser ses performances sur des observations qu’il n’a jamais vues auparavant.

L’un des moyens les plus courants d’y parvenir consiste à utiliser la validation croisée k-fold , qui utilise l’approche suivante :

1. Divisez aléatoirement un ensemble de données en k groupes, ou « plis », de taille à peu près égale.

2. Choisissez l’un des plis comme ensemble de retenue. Ajustez le modèle sur les plis k-1 restants. Calculez le test MSE sur les observations dans le pli qui a été tendu.

3. Répétez ce processus k fois, en utilisant à chaque fois un ensemble différent comme ensemble d’exclusion.

4. Calculez le MSE global du test comme étant la moyenne des k MSE du test.

Le moyen le plus simple d’effectuer une validation croisée k fois dans R consiste à utiliser les fonctions trainControl() et train() de la bibliothèque caret dans R.

La fonction trainControl() est utilisée pour spécifier les paramètres de formation (par exemple le type de validation croisée à utiliser, le nombre de plis à utiliser, etc.) et la fonction train() est utilisée pour adapter réellement le modèle aux données. .

L’exemple suivant montre comment utiliser les fonctions trainControl() et train() dans la pratique.

Exemple : Comment utiliser trainControl() dans R

Supposons que nous ayons l’ensemble de données suivant dans R :

#create data frame
df <- data.frame(y=c(6, 8, 12, 14, 14, 15, 17, 22, 24, 23),
                 x1=c(2, 5, 4, 3, 4, 6, 7, 5, 8, 9),
                 x2=c(14, 12, 12, 13, 7, 8, 7, 4, 6, 5))

#view data frame
df

y	x1	x2
6	2	14
8	5	12
12	4	12
14	3	13
14	4	7
15	6	8
17	7	7
22	5	4
24	8	6
23	9	5

Supposons maintenant que nous utilisions la fonction lm() pour ajuster un modèle de régression linéaire multiple à cet ensemble de données, en utilisant x1 et x2 comme variables prédictives et y comme variable de réponse :

#fit multiple linear regression model to data
fit <- lm(y ~ x1 + x2, data=df)

#view model summary
summary(fit)

Call:
lm(formula = y ~ x1 + x2, data = df)

Residuals:
    Min      1Q  Median      3Q     Max 
-3.6650 -1.9228 -0.3684  1.2783  5.0208 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)  
(Intercept)  21.2672     6.9927   3.041   0.0188 *
x1            0.7803     0.6942   1.124   0.2981  
x2           -1.1253     0.4251  -2.647   0.0331 *
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 3.093 on 7 degrees of freedom
Multiple R-squared:  0.801,	Adjusted R-squared:  0.7441 
F-statistic: 14.09 on 2 and 7 DF,  p-value: 0.003516

En utilisant les coefficients dans la sortie du modèle, nous pouvons écrire le modèle de régression ajusté :

y = 21,2672 + 0,7803*(x 1 ) – 1,1253(x 2 )

Pour avoir une idée des performances de ce modèle sur des observations invisibles, nous pouvons utiliser la validation croisée k-fold.

Le code suivant montre comment utiliser la fonction trainControl() du package caret pour spécifier une validation croisée k-fold ( méthode = « cv » ) qui utilise 5 plis ( number = 5 ).

Nous transmettons ensuite cette fonction trainControl() à la fonction train() pour effectuer réellement la validation croisée k-fold :

library(caret)

#specify the cross-validation method
ctrl <- trainControl(method = "cv", number = 5)

#fit a regression model and use k-fold CV to evaluate performance
model <- train(y ~ x1 + x2, data = df, method = "lm", trControl = ctrl)

#view summary of k-fold CV               
print(model)

Linear Regression 

10 samples
 2 predictor

No pre-processing
Resampling: Cross-Validated (5 fold) 
Summary of sample sizes: 8, 8, 8, 8, 8 
Resampling results:

  RMSE      Rsquared  MAE     
  3.612302  1         3.232153

Tuning parameter 'intercept' was held constant at a value of TRUE

À partir du résultat, nous pouvons voir que le modèle a été ajusté 5 fois en utilisant à chaque fois une taille d’échantillon de 8 observations.

À chaque fois, le modèle a ensuite été utilisé pour prédire les valeurs des 2 observations retenues et les métriques suivantes ont été calculées à chaque fois :

  • RMSE : erreur quadratique moyenne. Celui-ci mesure la différence moyenne entre les prédictions faites par le modèle et les observations réelles. Plus le RMSE est bas, plus un modèle peut prédire avec précision les observations réelles.
  • MAE : L’erreur absolue moyenne. Il s’agit de la différence absolue moyenne entre les prédictions faites par le modèle et les observations réelles. Plus le MAE est bas, plus un modèle peut prédire avec précision les observations réelles.

La moyenne des valeurs RMSE et MAE pour les cinq volets est affichée dans le résultat :

  • RMSE : 3,612302
  • MAE : 3.232153

Ces métriques nous donnent une idée des performances du modèle sur des données inédites.

En pratique, nous ajustons généralement plusieurs modèles différents et comparons ces métriques pour déterminer quel modèle fonctionne le mieux sur des données invisibles.

Par exemple, nous pourrions procéder à l’ajustement d’un modèle de régression polynomiale et y effectuer une validation croisée K-fold pour voir comment les métriques RMSE et MAE se comparent au modèle de régression linéaire multiple.

Remarque n°1 : Dans cet exemple, nous avons choisi d’utiliser k=5 plis, mais vous pouvez choisir le nombre de plis que vous souhaitez. En pratique, nous choisissons généralement entre 5 et 10 plis, car cela s’avère être le nombre optimal de plis qui produit des taux d’erreur de test fiables.

Note #2 : La fonction trainControl() accepte de nombreux arguments potentiels. Vous pouvez trouver la documentation complète de cette fonction ici .

Ressources additionnelles

Les didacticiels suivants fournissent des informations supplémentaires sur la formation de modèles :

Introduction à la validation croisée K-Fold
Introduction à la validation croisée Leave-One-Out
Qu’est-ce que le surapprentissage dans l’apprentissage automatique ?

Ajouter un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *