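KNN Classification of MNIST Handwritten Digits

Train a k-nearest-neighbors classifier on MNIST, tune `n_neighbors` and `weights` with a cross-validated grid search, and measure accuracy on a held-out test set.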

Get the Data

In [1]:
from sklearn.datasets import fetch_openml

# fetch_mldata was deprecated in scikit-learn 0.20 and removed in 0.22; fetch_openml is the
# supported replacement. 'mnist_784' (version 1) is the MNIST dataset hosted on OpenML.
# as_frame=False returns plain arrays, so mnist['data'] / mnist['target'] can be indexed
# exactly as the following cells do.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
# OpenML delivers the digit labels as strings ('0'..'9'); convert them to integers so labels
# and predictions are numeric, as they were with the old mldata loader.
mnist['target'] = mnist['target'].astype('uint8')
In [2]:
# Pull the feature matrix and the label vector out of the fetched dataset.
data_X = mnist['data']
data_Y = mnist['target']

Prepare the Training and Test Sets

In [3]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for the final evaluation. Stratifying on the labels keeps the
# class proportions equal in both splits, and the fixed random_state makes the split reproducible.
train_X, test_X, train_Y, test_Y = train_test_split(
    data_X,
    data_Y,
    test_size=0.2,
    random_state=42,
    stratify=data_Y,
)

Grid Search to Fine-Tune Hyperparameters

In [ ]:
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Candidate hyperparameters: 2 weighting schemes x 3 neighbour counts = 6 combinations,
# each scored with 5-fold cross-validation (30 cross-validation fits in total).
param_grid = [
    {
        'weights': ["uniform", "distance"],
        'n_neighbors': [3, 4, 5],
    }
]

knn_clf = KNeighborsClassifier()
# n_jobs=-1 runs the fits in parallel on all available cores; verbose=3 logs progress.
grid_search = GridSearchCV(
    estimator=knn_clf,
    param_grid=param_grid,
    cv=5,
    verbose=3,
    n_jobs=-1,
)
grid_search.fit(train_X, train_Y)
Fitting 5 folds for each of 6 candidates, totalling 30 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.

Train Final Model

In [ ]:
# With GridSearchCV's default refit=True, the best hyperparameter combination has already been
# retrained on the whole training set, so best_estimator_ is ready to use. Calling .fit() again
# on the same training data would only repeat that work.
final_model = grid_search.best_estimator_

# Show the chosen hyperparameters and their mean cross-validated accuracy.
grid_search.best_params_, grid_search.best_score_

Evaluate the Final Model on the Test Set

In [ ]:
from sklearn.metrics import accuracy_score

# Predict the held-out test set with the tuned model and report the fraction of samples
# classified correctly.
y_pred = final_model.predict(test_X)
test_accuracy = accuracy_score(y_true=test_Y, y_pred=y_pred)
test_accuracy