Les k plus proches voisins

L’algorithme des k plus proches voisins (en anglais : k nearest neighborsknn) appartient à la famille des algorithmes d’apprentissage automatique (machine learning).

Un peu d'histoire ...

L’idée d’apprentissage automatique ne date pas d’hier, puisque le terme de machine learning a été utilisé pour la première fois par l’informaticien américain Arthur Samuel en 1959. Les algorithmes d’apprentissage automatique ont connu un fort regain d’intérêt au début des années 2000 notamment grâce à la grande quantité de données disponibles sur Internet (on parle de « big data« ).

L’algorithme des k plus proches voisins est un algorithme d’apprentissage supervisé : il est nécessaire d’avoir des données labellisées. À partir d’un ensemble de données labellisées, il sera possible de classer (déterminer le label) d’une nouvelle donnée.

De nombreuses sociétés (par exemple les GAFAM) utilisent les données concernant leurs utilisateurs afin de ”nourrir” des algorithmes de machine learning qui permettront à ces sociétés d’en savoir toujours plus sur chaque utilisateur et ainsi de mieux cerner ses ”besoins” en termes de consommation.

Principe de l’algorithme

L’algorithme des k plus proches voisins ne nécessite pas de phase d’apprentissage à proprement parler, il faut juste stocker le jeu de données d’apprentissage.

Soit un ensemble \(E\) contenant \(n\) données labellisées : \(E=\left\{(y_i,\vec x_i)\right\}\) avec \(i\) compris entre \(1\) et \(n\), où \(y_i\) correspond à la classe (le label) de la donnée \(i\) et où le vecteur \(\vec x_i\) de dimension \(p\) représente les variables prédictrices de la donnée \(i\).

\(\vec x_i=(x_{1i}, x_{2i}, …, x_{pi})\)

Soit une donnée \(u\) qui n’appartient pas à \(E\) et qui ne possède pas de label (\(u\) est uniquement caractérisée par un vecteur \(\vec x_u\) de dimension \(p\)).

Soit une fonction \(d\) qui renvoie la distance entre la donnée \(u\) et une donnée quelconque appartenant à \(E\).

Soit un entier \(k\) inférieur ou égal à \(n\).

Voici le principe de l’algorithme des k plus proches voisins :

  • On calcule les distances entre la donnée \(u\) et chaque donnée appartenant à \(E\) à l’aide de la fonction \(d\)
  • On retient les \(k\) données du jeu de données \(E\) les plus proches de \(u\)
  • On attribue à \(u\) la classe qui est la plus fréquente parmi les \(k\) données les plus proches.

Il est possible d’utiliser différents types de distance : euclidienne, Hamming (distance de recouvrement, pour des textes), Manhattan, …

Activité
  • Dans l’exemple donné plus haut, déterminer les classes des \(k\) plus proches voisins de la donnée à classer (\(k\) variant de 1 à 5). Quelle classe peut-on attribuer à la donnée ?

 

Exemple en dimension 2

  • Choisir le nombre de classes,
  • Placer la donnée \(u\),
  • Afficher les \(k\) plus proches voisins,
  • Cliquer sur pour redistribuer les données.

 


Étude d’un exemple

Les données

Nous avons choisi ici de nous baser sur le jeu de données iris de Fisher. En 1936, Edgar Anderson a collecté des données sur 3 espèces d’iris : « iris setosa« , « iris virginica » et « iris versicolor » :

iris setosa

iris virginica

iris versicolor

Ce jeu de données est composé de 150 entrées, pour chaque entrée nous avons :

  • la longueur des sépales (en cm)
  • la largeur des sépales (en cm)
  • la longueur des pétales (en cm)
  • la largeur des pétales (en cm)
  • l’espèce d’iris : Iris setosa, Iris virginica ou Iris versicolor → label du jeu de données

Télécharger les données au format csv : iris.csv.

Activité
  • Si la classe de chaque donnée est définie par le champ « espèce » (species), quelle est la dimension de ce jeu de données ?

 

Bibliothèques Python utilisées

Nous allons utiliser 3 bibliothèques Python :

  • pandas  qui va nous permettre d’importer les données issues du fichier csv (voir Traitement des données en table)
  • matplotlib  qui va nous permettre de visualiser les données (tracer des graphiques)
  • Scikit-learn qui propose une implémentation de l’algorithme des k plus proches voisins.

 

Première visualisation des données

Une fois le fichier csv modifié, il est possible d’écrire un programme permettant de visualiser les données sous forme de graphique (abscisse : ”petal_length” , ordonnée : ”petal_width” ) :

import pandas
import matplotlib.pyplot as plt

# Importation des données
iris = pandas.read_csv("iris.csv")
x = iris.loc[:,"petal_length"]
y = iris.loc[:,"petal_width"]
lab = iris.loc[:,"species"]

# Affichage des données
for e, c in [('setosa','g'), ('virginica', 'r'), ('versicolor', 'b')]:
    plt.scatter(x[lab == e], y[lab == e], color = c, label = e)
plt.legend()
fig = plt.gcf()
fig.canvas.set_window_title('k plus proches voisins')
plt.show()
Explications

L’importation des données avec la bibliothèque pandas  est largement décrite dans l’article « Traitement des données en table ».

Ici x  et y  sont des séries (type pandas.Series), et contiennent les longueur et largeur des pétales, lab  contient les labels, c’est à dire les noms des espèces.

Les expressions de type x[lab == e]  permettent de filtrer les séries selon le nom e  de chaque espèce.

 

L’affichage des données est réalisé avec le module pyplot de la bibliothèque matplotlib :

la fonction plt.scatter(X, Y)  permet de tracer une série de points de coordonnées x[i], y[i]  (pour tout indice i  des listes X  et Y). Une figure  est automatiquement créée : la ligne fig = plt.gcf()  permet de la récupérer dans une variable fig , pour accéder à ses attributs. La ligne fig.canvas.set_window_title()  permet de changer le titre de la fenêtre dans laquelle la figure va apparaitre.

La commande plt.show()  affiche la figure dans une fenêtre avec barre d’outils et zone graphique.

On peut remarquer que les points (longueur, largeur) sont regroupés en « nuages » correspondant chacun à une espèce différente.

Le nuage « setosa » est isolé tandis que les deux « nuages » « virginica » et « versicolor » ont un peu tendance à se mélanger !

 

Identification de la classe d’une donnée

Supposons que nous ayons trouvé un iris et que, n’étant pas spécialiste, nous souhaitions en déterminer l’espèce.

On mesure la longueur et la largeur des pétales de cet iris, et on place le point (donnée à labéliser) sur la figure :

  • longueur : 5,29 cm
  • largeur : 1,27 cm

Premier cas de figure : le point apparait proche d’un « nuage »

L’identification de l’espèce ne pose pas de problème.

Deuxième cas de figure : le point apparait à proximité de plusieurs « nuages »

Dans ce genre de cas, il peut être intéressant d’utiliser l’algorithme des « k plus proches voisins » :

  • on calcule la distance entre le nouveau point et chaque point issu du jeu de données
  • on sélectionne uniquement les k distances les plus petites (les k plus proches voisins)
  • parmi les k plus proches voisins, on détermine quelle est l’espèce majoritaire.

On associe à notre « iris mystère » cette « espèce majoritaire parmi les k plus proches voisins »

Activité
  • Déterminer graphiquement les 3 plus proches voisins de l’iris « mystère » selon le critère de la taille des pétales ?
    (imprimer la figure ou bien faire un copier-coller dans un logiciel de dessin).

 

Utilisation de la bibliothèque scikit-learn

Voici un programme simple permettant la recherche des plus proches voisins à l’aide d’un algorithme optimisé :

import pandas
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier

# Importation des données
iris = pandas.read_csv("iris.csv")
x = iris.loc[:,"petal_length"]
y = iris.loc[:,"petal_width"]
lab = iris.loc[:,"species"]

# Affichage des données
for e, c in [('setosa','g'), ('virginica', 'r'), ('versicolor', 'b')]:
    plt.scatter(x[lab == e], y[lab == e], color = c, label = e)
plt.legend()
fig = plt.gcf()
fig.canvas.set_window_title('k plus proches voisins')


# Caractéristiques de l'Iris mystère
longueur = 2.5
largeur = 0.75
plt.scatter(longueur, largeur, color='k', marker="x")


# Algorithme des k plus proches voisins
k = 3
d = list(zip(x,y)) # Regroupement en couples de coordonnées (x, y)
model = KNeighborsClassifier(n_neighbors = k)
model.fit(d, lab)
prediction = model.predict([[longueur,largeur]])


# Affichage des résultats
plt.text(3,0.1, "Résultat : " + prediction[0], fontsize=12)


plt.show()
Explications

La première ligne d = list(zip(x,y))  permet de passer des 2 listes x  et y  en une unique liste de tuples.

KNeighborsClassifier  est une méthode issue de la bibliothèque scikit-learn  (from sklearn.neighbors import KNeighborsClassifier ), cette méthode prend ici en paramètre le nombre de « plus proches voisins » (model = KNeighborsClassifier(n_neighbors=k) )

model.fit(d, lab)  permet d’associer les tuples présents dans la liste d  avec les labels (« setosa » , « virginica »  ou « versicolor ») de la liste lab .

La ligne prediction= model.predict([[longueur,largeur]])  permet d’effectuer une prédiction pour un couple [longueur, largeur]. La variable prediction  contient alors le label trouvé par l’algorithme knn. Attention, prediction  est une liste Python qui contient un seul élément (le label), il est donc nécessaire d’écrire prediction[0]  afin d’obtenir le label.

Activité
  • Appliquer cet algorithme pour retrouver les résultats obtenus graphiquement.

 


Réalisation d’un algorithme avec Python

L’algorithme des k plus proches voisins s’appuie sur un calcul de distance.

Écrire une fonction distance(p1, p2)  qui calcule la distance euclidienne entre deux points p1  et p2  donnés sous forme de tuple (x, y) .

Réaliser l’algorithme de recherche des k plus proches voisins sous la forme d’une fonction knn(P, X, Y, labels, k) , avec :

  • P  : tuple des coordonnées du point à classer
  • X, Y, labels  : séries des données labellisées
  • k  : rang du dernier plus proche voisin

La fonction doit retourner une liste de k  tuples (point, classe, distance) .

 

sources : https://pixees.fr/informatiquelycee/n_site/nsi_prem_knn.html
https://cache.media.eduscol.education.fr/file/NSI/76/6/RA_Lycee_G_NSI_algo_knn_1170766.pdf

Vous aimerez aussi...

Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *