Xgboost w r: przykład krok po kroku


Boosting to technika uczenia maszynowego, która, jak wykazano, pozwala na tworzenie modeli o dużej dokładności predykcyjnej.

Jednym z najczęstszych sposobów wdrażania wzmocnienia w praktyce jest użycie XGBoost , skrótu od „ekstremalnego wzmocnienia gradientu”.

Ten samouczek zawiera przykład krok po kroku użycia XGBoost w celu dopasowania ulepszonego modelu w języku R.

Krok 1: Załaduj niezbędne pakiety

Najpierw załadujemy niezbędne biblioteki.

 library (xgboost) #for fitting the xgboost model
library (caret) #for general data preparation and model fitting

Krok 2: Załaduj dane

W tym przykładzie dopasujemy ulepszony model regresji do zbioru danych Boston z pakietu MASS .

Ten zbiór danych zawiera 13 zmiennych predykcyjnych, których użyjemy do przewidzenia zmiennej odpowiedzi zwanej mdev , która reprezentuje średnią wartość domów w różnych obwodach spisowych wokół Bostonu.

 #load the data
data = MASS::Boston

#view the structure of the data
str(data) 

'data.frame': 506 obs. of 14 variables:
 $ crim: num 0.00632 0.02731 0.02729 0.03237 0.06905 ...
 $ zn : num 18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
 $ indus: num 2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
 $chas: int 0 0 0 0 0 0 0 0 0 0 ...
 $ nox: num 0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
 $rm: num 6.58 6.42 7.18 7 7.15 ...
 $ age: num 65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
 $ dis: num 4.09 4.97 4.97 6.06 6.06 ...
 $rad: int 1 2 2 3 3 3 5 5 5 5 ...
 $ tax: num 296 242 242 222 222 222 311 311 311 311 ...
 $ptratio: num 15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
 $ black: num 397 397 393 395 397 ...
 $ lstat: num 4.98 9.14 4.03 2.94 5.33 ...
 $ medv: num 24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...

Widzimy, że zbiór danych zawiera łącznie 506 obserwacji i 14 zmiennych.

Krok 3: Przygotuj dane

Następnie użyjemy funkcji createDataPartition() z pakietu caret, aby podzielić oryginalny zbiór danych na zbiór uczący i testowy.

W tym przykładzie zdecydujemy się użyć 80% oryginalnego zbioru danych jako części zbioru szkoleniowego.

Należy pamiętać, że pakiet xgboost również wykorzystuje dane macierzowe, więc użyjemy funkcji data.matrix() do przechowywania naszych zmiennych predykcyjnych.

 #make this example reproducible
set.seed(0)

#split into training (80%) and testing set (20%)
parts = createDataPartition(data$medv, p = .8 , list = F )
train = data[parts, ]
test = data[-parts, ]

#define predictor and response variables in training set
train_x = data. matrix (train[, -13])
train_y = train[,13]

#define predictor and response variables in testing set
test_x = data. matrix (test[, -13])
test_y = test[, 13]

#define final training and testing sets
xgb_train = xgb. DMatrix (data = train_x, label = train_y)
xgb_test = xgb. DMatrix (data = test_x, label = test_y)

Krok 4: Dostosuj model

Następnie dostroimy model XGBoost za pomocą funkcji xgb.train() , która wyświetla RMSE trenowania i testowania (średni błąd kwadratowy) dla każdego cyklu wzmacniania.

Należy pamiętać, że w tym przykładzie zdecydowaliśmy się użyć 70 rund, ale w przypadku znacznie większych zbiorów danych nierzadko używa się setek, a nawet tysięcy rund. Pamiętaj tylko, że im więcej rund, tym dłuższy czas działania.

Należy również pamiętać, że argument max. Degree określa głębokość rozwoju poszczególnych drzew decyzyjnych. Zwykle wybieramy tę liczbę dość niską, np. 2 lub 3, aby hodować mniejsze drzewa. Wykazano, że takie podejście pozwala uzyskać dokładniejsze modele.

 #define watchlist
watchlist = list(train=xgb_train, test=xgb_test)

#fit XGBoost model and display training and testing data at each round
model = xgb.train(data = xgb_train, max.depth = 3 , watchlist=watchlist, nrounds = 70 )

[1] train-rmse:10.167523 test-rmse:10.839775 
[2] train-rmse:7.521903 test-rmse:8.329679 
[3] train-rmse:5.702393 test-rmse:6.691415 
[4] train-rmse:4.463687 test-rmse:5.631310 
[5] train-rmse:3.666278 test-rmse:4.878750 
[6] train-rmse:3.159799 test-rmse:4.485698 
[7] train-rmse:2.855133 test-rmse:4.230533 
[8] train-rmse:2.603367 test-rmse:4.099881 
[9] train-rmse:2.445718 test-rmse:4.084360 
[10] train-rmse:2.327318 test-rmse:3.993562 
[11] train-rmse:2.267629 test-rmse:3.944454 
[12] train-rmse:2.189527 test-rmse:3.930808 
[13] train-rmse:2.119130 test-rmse:3.865036 
[14] train-rmse:2.086450 test-rmse:3.875088 
[15] train-rmse:2.038356 test-rmse:3.881442 
[16] train-rmse:2.010995 test-rmse:3.883322 
[17] train-rmse:1.949505 test-rmse:3.844382 
[18] train-rmse:1.911711 test-rmse:3.809830 
[19] train-rmse:1.888488 test-rmse:3.809830 
[20] train-rmse:1.832443 test-rmse:3.758502 
[21] train-rmse:1.816150 test-rmse:3.770216 
[22] train-rmse:1.801369 test-rmse:3.770474 
[23] train-rmse:1.788891 test-rmse:3.766608 
[24] train-rmse:1.751795 test-rmse:3.749583 
[25] train-rmse:1.713306 test-rmse:3.720173 
[26] train-rmse:1.672227 test-rmse:3.675086 
[27] train-rmse:1.648323 test-rmse:3.675977 
[28] train-rmse:1.609927 test-rmse:3.745338 
[29] train-rmse:1.594891 test-rmse:3.756049 
[30] train-rmse:1.578573 test-rmse:3.760104 
[31] train-rmse:1.559810 test-rmse:3.727940 
[32] train-rmse:1.547852 test-rmse:3.731702 
[33] train-rmse:1.534589 test-rmse:3.729761 
[34] train-rmse:1.520566 test-rmse:3.742681 
[35] train-rmse:1.495155 test-rmse:3.732993 
[36] train-rmse:1.467939 test-rmse:3.738329 
[37] train-rmse:1.446343 test-rmse:3.713748 
[38] train-rmse:1.435368 test-rmse:3.709469 
[39] train-rmse:1.401356 test-rmse:3.710637 
[40] train-rmse:1.390318 test-rmse:3.709461 
[41] train-rmse:1.372635 test-rmse:3.708049 
[42] train-rmse:1.367977 test-rmse:3.707429 
[43] train-rmse:1.359531 test-rmse:3.711663 
[44] train-rmse:1.335347 test-rmse:3.709101 
[45] train-rmse:1.331750 test-rmse:3.712490 
[46] train-rmse:1.313087 test-rmse:3.722981 
[47] train-rmse:1.284392 test-rmse:3.712840 
[48] train-rmse:1.257714 test-rmse:3.697482 
[49] train-rmse:1.248218 test-rmse:3.700167 
[50] train-rmse:1.243377 test-rmse:3.697914 
[51] train-rmse:1.231956 test-rmse:3.695797 
[52] train-rmse:1.219341 test-rmse:3.696277 
[53] train-rmse:1.207413 test-rmse:3.691465 
[54] train-rmse:1.197197 test-rmse:3.692108 
[55] train-rmse:1.171748 test-rmse:3.683577 
[56] train-rmse:1.156332 test-rmse:3.674458 
[57] train-rmse:1.147686 test-rmse:3.686367 
[58] train-rmse:1.143572 test-rmse:3.686375 
[59] train-rmse:1.129780 test-rmse:3.679791 
[60] train-rmse:1.111257 test-rmse:3.679022 
[61] train-rmse:1.093541 test-rmse:3.699670 
[62] train-rmse:1.083934 test-rmse:3.708187 
[63] train-rmse:1.067109 test-rmse:3.712538 
[64] train-rmse:1.053887 test-rmse:3.722480 
[65] train-rmse:1.042127 test-rmse:3.720720 
[66] train-rmse:1.031617 test-rmse:3.721224 
[67] train-rmse:1.016274 test-rmse:3.699549 
[68] train-rmse:1.008184 test-rmse:3.709522 
[69] train-rmse:0.999220 test-rmse:3.708000 
[70] train-rmse:0.985907 test-rmse:3.705192 

Z wyniku widzimy, że minimalny test RMSE osiąga się po 56 rundach. Powyżej tego punktu test RMSE zaczyna rosnąć, co wskazuje, że nadmiernie dopasowujemy dane uczące .

Zatem ustawimy nasz ostateczny model XGBoost na 56 rund:

 #define final model
final = xgboost(data = xgb_train, max.depth = 3 , nrounds = 56 , verbose = 0 )

Uwaga: Argument verbose=0 mówi R, aby nie wyświetlał błędu uczenia i testowania dla każdej rundy.

Krok 5: Użyj modelu do przewidywania

Na koniec możemy użyć ostatecznie ulepszonego modelu do przewidywania średniej wartości domów w Bostonie w zestawie testowym.

Następnie obliczymy następujące metryki dokładności modelu:

  • MSE: błąd średniokwadratowy
  • MAE: średni błąd bezwzględny
  • RMSE: średni błąd kwadratowy
 mean((test_y - pred_y)^2) #mse
caret::MAE(test_y, pred_y) #mae
caret::RMSE(test_y, pred_y) #rmse

[1] 13.50164
[1] 2.409426
[1] 3.674457

Średni błąd kwadratowy wynosi 3,674457 . Stanowi to średnią różnicę między prognozą wykonaną dla mediany wartości domów a rzeczywistymi wartościami domów zaobserwowanymi w zestawie testowym.

Jeśli chcemy, możemy porównać ten RMSE z innymi modelami, takimi jak wielokrotna regresja liniowa , regresja grzbietowa , regresja głównych składowych itp. aby zobaczyć, który model daje najdokładniejsze przewidywania.

Pełny kod R użyty w tym przykładzie znajdziesz tutaj .

Dodaj komentarz

Twój adres e-mail nie zostanie opublikowany. Wymagane pola są oznaczone *