Analyse discriminante linéaire en Python (étape par étape)



L’analyse discriminante linéaire est une méthode que vous pouvez utiliser lorsque vous disposez d’un ensemble de variables prédictives et que vous souhaitez classer une variable de réponse en deux classes ou plus.

Ce didacticiel fournit un exemple étape par étape de la façon d’effectuer une analyse discriminante linéaire en Python.

Étape 1 : Charger les bibliothèques nécessaires

Tout d’abord, nous allons charger les fonctions et bibliothèques nécessaires pour cet exemple :

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 
from sklearn import datasets
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

Étape 2 : Charger les données

Pour cet exemple, nous utiliserons l’ensemble de données iris de la bibliothèque sklearn. Le code suivant montre comment charger cet ensemble de données et le convertir en DataFrame pandas pour faciliter son utilisation :

#load iris dataset
iris = datasets.load_iris()

#convert dataset to pandas DataFrame
df = pd.DataFrame(data = np.c_[iris['data'], iris['target']],
                 columns = iris['feature_names'] + ['target'])
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
df.columns = ['s_length', 's_width', 'p_length', 'p_width', 'target', 'species']

#view first six rows of DataFrame
df.head()

   s_length  s_width  p_length  p_width  target species
0       5.1      3.5       1.4      0.2     0.0  setosa
1       4.9      3.0       1.4      0.2     0.0  setosa
2       4.7      3.2       1.3      0.2     0.0  setosa
3       4.6      3.1       1.5      0.2     0.0  setosa
4       5.0      3.6       1.4      0.2     0.0  setosa

#find how many total observations are in dataset
len(df.index)

150

Nous pouvons voir que l’ensemble de données contient 150 observations au total.

Pour cet exemple, nous allons construire un modèle d’analyse discriminante linéaire pour classer à quelles espèces appartient une fleur donnée.

Nous utiliserons les variables prédictives suivantes dans le modèle :

  • Longueur des sépales
  • Largeur des sépales
  • Longueur des pétales
  • Largeur des pétales

Et nous les utiliserons pour prédire la variable de réponse Species , qui prend en charge les trois classes potentielles suivantes :

  • sétosa
  • versicolor
  • virginie

Étape 3 : Ajuster le modèle LDA

Ensuite, nous adapterons le modèle LDA à nos données à l’aide de la fonction LinearDiscriminantAnalsyis de sklearn :

#define predictor and response variables
X = df[['s_length', 's_width', 'p_length', 'p_width']]
y = df['species']

#Fit the LDA model
model = LinearDiscriminantAnalysis()
model.fit(X, y)

Étape 4 : Utiliser le modèle pour faire des prédictions

Une fois que nous avons ajusté le modèle à l’aide de nos données, nous pouvons évaluer les performances du modèle en utilisant une validation croisée stratifiée répétée k-fold.

Pour cet exemple, nous utiliserons 10 plis et 3 répétitions :

#Define method to evaluate model
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)

#evaluate model
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
print(np.mean(scores))   

0.9777777777777779

Nous pouvons voir que le modèle a réalisé une précision moyenne de 97,78 % .

Nous pouvons également utiliser le modèle pour prédire à quelle classe appartient une nouvelle fleur, en fonction des valeurs d’entrée :

#define new observation
new = [5, 3, 1, .4]

#predict which class the new observation belongs to
model.predict([new])

array(['setosa'], dtype='<U10')

On voit que le modèle prédit que cette nouvelle observation appartient à l’espèce appelée setosa .

Étape 5 : Visualisez les résultats

Enfin, nous pouvons créer un tracé LDA pour visualiser les discriminants linéaires du modèle et visualiser dans quelle mesure il sépare les trois espèces différentes dans notre ensemble de données :

#define data to plot
X = iris.data
y = iris.target
model = LinearDiscriminantAnalysis()
data_plot = model.fit(X, y).transform(X)
target_names = iris.target_names

#create LDA plot
plt.figure()
colors = ['red', 'green', 'blue']
lw = 2
for color, i, target_name in zip(colors, [0, 1, 2], target_names):
    plt.scatter(data_plot[y == i, 0], data_plot[y == i, 1], alpha=.8, color=color,
                label=target_name)

#add legend to plot
plt.legend(loc='best', shadow=False, scatterpoints=1)

#display LDA plot
plt.show()

analyse discriminante linéaire en Python

Vous pouvez trouver le code Python complet utilisé dans ce tutoriel ici .

Ajouter un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *