Xgboost в r: покроковий приклад
Підвищення — це техніка машинного навчання, яка, як було показано, створює моделі з високою точністю прогнозування.
Одним із найпоширеніших способів реалізації посилення на практиці є використання XGBoost , що скорочується від «extreme gradient boosting».
Цей підручник надає покроковий приклад того, як використовувати XGBoost для адаптації розширеної моделі в R.
Крок 1: Завантажте необхідні пакети
Спочатку ми завантажимо необхідні бібліотеки.
library (xgboost) #for fitting the xgboost model library (caret) #for general data preparation and model fitting
Крок 2. Завантажте дані
Для цього прикладу ми підберемо вдосконалену модель регресії до Бостонського набору даних із пакету MASS .
Цей набір даних містить 13 змінних предикторів, які ми використовуватимемо для прогнозування змінної відповіді під назвою mdev , яка представляє середнє значення будинків у різних районах перепису навколо Бостона.
#load the data
data = MASS::Boston
#view the structure of the data
str(data)
'data.frame': 506 obs. of 14 variables:
$ crim: num 0.00632 0.02731 0.02729 0.03237 0.06905 ...
$ zn : num 18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
$ indus: num 2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
$chas: int 0 0 0 0 0 0 0 0 0 0 ...
$ nox: num 0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
$rm: num 6.58 6.42 7.18 7 7.15 ...
$ age: num 65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
$ dis: num 4.09 4.97 4.97 6.06 6.06 ...
$rad: int 1 2 2 3 3 3 5 5 5 5 ...
$ tax: num 296 242 242 222 222 222 311 311 311 311 ...
$ptratio: num 15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
$ black: num 397 397 393 395 397 ...
$ lstat: num 4.98 9.14 4.03 2.94 5.33 ...
$ medv: num 24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...
Ми бачимо, що набір даних містить загалом 506 спостережень і 14 змінних.
Крок 3: Підготуйте дані
Далі ми скористаємося функцією createDataPartition() із пакета каретки, щоб розділити вихідний набір даних на набір для навчання та тестування.
Для цього прикладу ми вирішимо використовувати 80% вихідного набору даних як частину навчального набору.
Зверніть увагу, що пакет xgboost також використовує матричні дані, тому ми будемо використовувати функцію data.matrix() для зберігання наших змінних предиктора.
#make this example reproducible
set.seed(0)
#split into training (80%) and testing set (20%)
parts = createDataPartition(data$medv, p = .8 , list = F )
train = data[parts, ]
test = data[-parts, ]
#define predictor and response variables in training set
train_x = data. matrix (train[, -13])
train_y = train[,13]
#define predictor and response variables in testing set
test_x = data. matrix (test[, -13])
test_y = test[, 13]
#define final training and testing sets
xgb_train = xgb. DMatrix (data = train_x, label = train_y)
xgb_test = xgb. DMatrix (data = test_x, label = test_y)
Крок 4: Налаштуйте модель
Далі ми налаштуємо модель XGBoost за допомогою функції xgb.train() , яка відображає RMSE навчання та тестування (середня квадратична помилка) для кожного циклу підвищення.
Зауважте, що для цього прикладу ми вирішили використовувати 70 раундів, але для набагато більших наборів даних нерідко використовують сотні чи навіть тисячі раундів. Тільки майте на увазі, що чим більше раундів, тим довший час роботи.
Також зауважте, що аргумент max.degree визначає глибину розробки окремих дерев рішень. Зазвичай ми обираємо це число досить низьке, наприклад 2 або 3, щоб виростити менші дерева. Показано, що цей підхід дає більш точні моделі.
#define watchlist
watchlist = list(train=xgb_train, test=xgb_test)
#fit XGBoost model and display training and testing data at each round
model = xgb.train(data = xgb_train, max.depth = 3 , watchlist=watchlist, nrounds = 70 )
[1] train-rmse:10.167523 test-rmse:10.839775
[2] train-rmse:7.521903 test-rmse:8.329679
[3] train-rmse:5.702393 test-rmse:6.691415
[4] train-rmse:4.463687 test-rmse:5.631310
[5] train-rmse:3.666278 test-rmse:4.878750
[6] train-rmse:3.159799 test-rmse:4.485698
[7] train-rmse:2.855133 test-rmse:4.230533
[8] train-rmse:2.603367 test-rmse:4.099881
[9] train-rmse:2.445718 test-rmse:4.084360
[10] train-rmse:2.327318 test-rmse:3.993562
[11] train-rmse:2.267629 test-rmse:3.944454
[12] train-rmse:2.189527 test-rmse:3.930808
[13] train-rmse:2.119130 test-rmse:3.865036
[14] train-rmse:2.086450 test-rmse:3.875088
[15] train-rmse:2.038356 test-rmse:3.881442
[16] train-rmse:2.010995 test-rmse:3.883322
[17] train-rmse:1.949505 test-rmse:3.844382
[18] train-rmse:1.911711 test-rmse:3.809830
[19] train-rmse:1.888488 test-rmse:3.809830
[20] train-rmse:1.832443 test-rmse:3.758502
[21] train-rmse:1.816150 test-rmse:3.770216
[22] train-rmse:1.801369 test-rmse:3.770474
[23] train-rmse:1.788891 test-rmse:3.766608
[24] train-rmse:1.751795 test-rmse:3.749583
[25] train-rmse:1.713306 test-rmse:3.720173
[26] train-rmse:1.672227 test-rmse:3.675086
[27] train-rmse:1.648323 test-rmse:3.675977
[28] train-rmse:1.609927 test-rmse:3.745338
[29] train-rmse:1.594891 test-rmse:3.756049
[30] train-rmse:1.578573 test-rmse:3.760104
[31] train-rmse:1.559810 test-rmse:3.727940
[32] train-rmse:1.547852 test-rmse:3.731702
[33] train-rmse:1.534589 test-rmse:3.729761
[34] train-rmse:1.520566 test-rmse:3.742681
[35] train-rmse:1.495155 test-rmse:3.732993
[36] train-rmse:1.467939 test-rmse:3.738329
[37] train-rmse:1.446343 test-rmse:3.713748
[38] train-rmse:1.435368 test-rmse:3.709469
[39] train-rmse:1.401356 test-rmse:3.710637
[40] train-rmse:1.390318 test-rmse:3.709461
[41] train-rmse:1.372635 test-rmse:3.708049
[42] train-rmse:1.367977 test-rmse:3.707429
[43] train-rmse:1.359531 test-rmse:3.711663
[44] train-rmse:1.335347 test-rmse:3.709101
[45] train-rmse:1.331750 test-rmse:3.712490
[46] train-rmse:1.313087 test-rmse:3.722981
[47] train-rmse:1.284392 test-rmse:3.712840
[48] train-rmse:1.257714 test-rmse:3.697482
[49] train-rmse:1.248218 test-rmse:3.700167
[50] train-rmse:1.243377 test-rmse:3.697914
[51] train-rmse:1.231956 test-rmse:3.695797
[52] train-rmse:1.219341 test-rmse:3.696277
[53] train-rmse:1.207413 test-rmse:3.691465
[54] train-rmse:1.197197 test-rmse:3.692108
[55] train-rmse:1.171748 test-rmse:3.683577
[56] train-rmse:1.156332 test-rmse:3.674458
[57] train-rmse:1.147686 test-rmse:3.686367
[58] train-rmse:1.143572 test-rmse:3.686375
[59] train-rmse:1.129780 test-rmse:3.679791
[60] train-rmse:1.111257 test-rmse:3.679022
[61] train-rmse:1.093541 test-rmse:3.699670
[62] train-rmse:1.083934 test-rmse:3.708187
[63] train-rmse:1.067109 test-rmse:3.712538
[64] train-rmse:1.053887 test-rmse:3.722480
[65] train-rmse:1.042127 test-rmse:3.720720
[66] train-rmse:1.031617 test-rmse:3.721224
[67] train-rmse:1.016274 test-rmse:3.699549
[68] train-rmse:1.008184 test-rmse:3.709522
[69] train-rmse:0.999220 test-rmse:3.708000
[70] train-rmse:0.985907 test-rmse:3.705192
З результату ми бачимо, що мінімальний тестовий RMSE досягається при 56 раундах. За межами цієї точки тестовий RMSE починає збільшуватися, що вказує на те, що ми переналаштовуємо навчальні дані .
Отже, ми налаштуємо нашу остаточну модель XGBoost на використання 56 раундів:
#define final model
final = xgboost(data = xgb_train, max.depth = 3 , nrounds = 56 , verbose = 0 )
Примітка. Аргумент verbose=0 повідомляє R не відображати помилку навчання та тестування для кожного раунду.
Крок 5. Використовуйте модель для прогнозування
Нарешті, ми можемо використати остаточну покращену модель, щоб зробити прогнози щодо середньої вартості будинків у Бостоні в тестовому наборі.
Потім ми розрахуємо такі показники точності для моделі:
- MSE: середня квадратична помилка
- MAE: середня абсолютна похибка
- RMSE: середньоквадратична помилка
mean((test_y - pred_y)^2) #mse
caret::MAE(test_y, pred_y) #mae
caret::RMSE(test_y, pred_y) #rmse
[1] 13.50164
[1] 2.409426
[1] 3.674457
Середня квадратична помилка виявляється 3,674457 . Це являє собою середню різницю між прогнозом, зробленим для середніх значень будинку, і фактичними значеннями будинку, що спостерігаються в тестовому наборі.
Якщо ми хочемо, ми можемо порівняти цю RMSE з іншими моделями, такими як множинна лінійна регресія , гребенева регресія , регресія головної компоненти тощо. щоб побачити, яка модель дає найточніші прогнози.
Ви можете знайти повний код R, використаний у цьому прикладі , тут .