Lasso-regressie in python (stap voor stap)


Lasso-regressie is een methode die we kunnen gebruiken om een regressiemodel te fitten wanneer multicollineariteit in de gegevens aanwezig is.

In een notendop probeert regressie met de kleinste kwadraten coëfficiëntschattingen te vinden die de resterende som van de kwadraten (RSS) minimaliseren:

RSS = Σ(y i – ŷ i )2

Goud:

  • Σ : Een Grieks symbool dat som betekent
  • y i : de werkelijke responswaarde voor de i-de waarneming
  • ŷ i : De voorspelde responswaarde op basis van het meervoudige lineaire regressiemodel

Omgekeerd probeert lasso-regressie het volgende te minimaliseren:

RSS + λΣ|β j |

waarbij j van 1 naar p voorspellende variabelen gaat en λ ≥ 0.

Deze tweede term in de vergelijking staat bekend als de opnameboete . Bij lasso-regressie selecteren we een waarde voor λ die de laagst mogelijke MSE-test (mean square error) oplevert.

Deze tutorial biedt een stapsgewijs voorbeeld van hoe u lasso-regressie in Python uitvoert.

Stap 1: Importeer de benodigde pakketten

Eerst zullen we de benodigde pakketten importeren om lasso-regressie in Python uit te voeren:

 import pandas as pd
from numpy import arange
from sklearn. linear_model import LassoCV
from sklearn. model_selection import RepeatedKFold

Stap 2: Gegevens laden

Voor dit voorbeeld gebruiken we een dataset met de naam mtcars , die informatie bevat over 33 verschillende auto’s. We zullen hp gebruiken als de responsvariabele en de volgende variabelen als voorspellers:

  • mpg
  • gewicht
  • shit
  • qsec

De volgende code laat zien hoe u deze gegevensset laadt en weergeeft:

 #define URL where data is located
url = "https://raw.githubusercontent.com/Statorials/Python-Guides/main/mtcars.csv"

#read in data
data_full = pd. read_csv (url)

#select subset of data
data = data_full[["mpg", "wt", "drat", "qsec", "hp"]]

#view first six rows of data
data[0:6]

	mpg wt drat qsec hp
0 21.0 2.620 3.90 16.46 110
1 21.0 2.875 3.90 17.02 110
2 22.8 2.320 3.85 18.61 93
3 21.4 3.215 3.08 19.44 110
4 18.7 3,440 3.15 17.02 175
5 18.1 3.460 2.76 20.22 105

Stap 3: Pas het Lasso-regressiemodel toe

Vervolgens zullen we de functie LassoCV() van sklearn gebruiken om het lasso-regressiemodel aan te passen en de functie RepeatedKFold() gebruiken om k-voudige kruisvalidatie uit te voeren om de optimale alfawaarde te vinden die voor de strafterm moet worden gebruikt.

Opmerking: in Python wordt de term ‚alpha‘ gebruikt in plaats van ‚lambda‘.

Voor dit voorbeeld kiezen we k = 10 vouwen en herhalen we het kruisvalidatieproces drie keer.

Houd er ook rekening mee dat LassoCV() standaard alleen alfawaarden 0,1, 1 en 10 test. We kunnen echter ons eigen alfabereik instellen van 0 tot 1 in stappen van 0,01:

 #define predictor and response variables
X = data[["mpg", "wt", "drat", "qsec"]]
y = data["hp"]

#define cross-validation method to evaluate model
cv = RepeatedKFold(n_splits= 10 , n_repeats= 3 , random_state= 1 )

#define model
model = LassoCV(alphas= arange (0, 1, 0.01), cv=cv, n_jobs= -1 )

#fit model
model. fit (x,y)

#display lambda that produced the lowest test MSE
print( model.alpha_ )

0.99

De lambdawaarde die de MSE van de test minimaliseert blijkt 0,99 te zijn.

Stap 4: Gebruik het model om voorspellingen te doen

Ten slotte kunnen we het laatste lasso-regressiemodel gebruiken om voorspellingen te doen over nieuwe waarnemingen. De volgende code laat bijvoorbeeld zien hoe u een nieuwe auto definieert met de volgende kenmerken:

  • mpg: 24
  • gewicht: 2,5
  • prijs: 3,5
  • qsec: 18,5

De volgende code laat zien hoe u het gepaste lasso-regressiemodel kunt gebruiken om de pk- waarde van deze nieuwe waarneming te voorspellen:

 #define new observation
new = [24, 2.5, 3.5, 18.5]

#predict hp value using lasso regression model
model. predict ([new])

array([105.63442071])

Op basis van de ingevoerde waarden voorspelt het model dat deze auto een pk- waarde zal hebben van 105,63442071 .

Je kunt de volledige Python-code die in dit voorbeeld wordt gebruikt hier vinden.

Einen Kommentar hinzufügen

Deine E-Mail-Adresse wird nicht veröffentlicht. Erforderliche Felder sind mit * markiert