Binary Classes Classification

1. Get the Data

In [7]:
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
data_X, data_Y = mnist['data'], mnist['target']
/anaconda2/envs/Python3_R/lib/python3.7/site-packages/sklearn/utils/deprecation.py:85: DeprecationWarning: Function fetch_mldata is deprecated; fetch_mldata was deprecated in version 0.20 and will be removed in version 0.22. Please use fetch_openml.
  warnings.warn(msg, category=DeprecationWarning)
/anaconda2/envs/Python3_R/lib/python3.7/site-packages/sklearn/utils/deprecation.py:85: DeprecationWarning: Function mldata_filename is deprecated; mldata_filename was deprecated in version 0.20 and will be removed in version 0.22. Please use fetch_openml.
  warnings.warn(msg, category=DeprecationWarning)

2. Prepare Training Set and Test Set

In [8]:
from sklearn.model_selection import train_test_split
train_X, test_X, train_Y, test_Y = train_test_split(data_X, data_Y, test_size=0.2, random_state=42, stratify = data_Y)
In [9]:
train_Y_5 = (train_Y == 5)
test_Y_5 = (test_Y == 5)

3. Select and Train a Model

In [39]:
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_score, recall_score, f1_score

SGD Model

In [17]:
from sklearn.linear_model import SGDClassifier
scores = cross_val_predict(SGDClassifier(random_state=42), train_X, train_Y_5, cv = 3, method="decision_function")
In [25]:
sgd_precisions, sgd_recalls, thresholds = precision_recall_curve(train_Y_5, scores)
sgd_fpr, sgd_tpr, thresholds = roc_curve(train_Y_5, scores)

Random Forest Model

In [27]:
from sklearn.ensemble import RandomForestClassifier
probs_forest = cross_val_predict(RandomForestClassifier(random_state=42), train_X, train_Y_5, cv=3, method='predict_proba')
/anaconda2/envs/Python3_R/lib/python3.7/site-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
  "10 in version 0.20 to 100 in 0.22.", FutureWarning)
/anaconda2/envs/Python3_R/lib/python3.7/site-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
  "10 in version 0.20 to 100 in 0.22.", FutureWarning)
/anaconda2/envs/Python3_R/lib/python3.7/site-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
  "10 in version 0.20 to 100 in 0.22.", FutureWarning)
In [31]:
rft_precisions, rft_recalls, thresholds = precision_recall_curve(train_Y_5, probs_forest[:, 1])
rft_fpr, rft_tpr, thresholds_forest = roc_curve(train_Y_5, probs_forest[:, 1])

Precision v.s. Recall

In [55]:
# curve near to the top-right corner is better
fig, ax = plt.subplots();

line_1, = ax.plot(sgd_recalls, sgd_precisions, 'r-')
line_2, = ax.plot(rft_recalls, rft_precisions, 'b-')

ax.set_xlabel('Recall', size=18);
ax.set_ylabel('Precision', size = 18);
legend = ax.legend((line_1, line_2), ('SGD', 'RFT'), loc='lower right', shadow=True, facecolor='0.9');

ROC curve

In [37]:
# curve near to the top-left corner is better
fig, ax = plt.subplots();

line_1, = ax.plot(sgd_fpr, sgd_tpr, 'r-')
line_2, = ax.plot(rft_fpr, rft_tpr, 'b-')
line_3, = ax.plot([0, 1], [0, 1], 'k--')

ax.set_xlabel('False Positive Rate', size=18);
ax.set_ylabel('True Positive Rate', size = 18);
legend = ax.legend((line_1, line_2, line_3), ('SGD', 'RFT', '--'), loc='lower right', shadow=True, facecolor='0.9');

Evaluate Models

In [44]:
# Compare ROC AUC value, the higher value represent better performance
# SGD model
print(roc_auc_score(train_Y_5, scores))
# Random Forest Tree Model
print(roc_auc_score(train_Y_5, probs_forest[:, 1]))
0.9689078366451512
0.9923809022977892

4. Train Final Model

In [47]:
final_model = RandomForestClassifier(random_state=42)
final_model.fit(train_X, train_Y_5)
/anaconda2/envs/Python3_R/lib/python3.7/site-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
  "10 in version 0.20 to 100 in 0.22.", FutureWarning)
Out[47]:
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
                       max_depth=None, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=10,
                       n_jobs=None, oob_score=False, random_state=42, verbose=0,
                       warm_start=False)

5. Evaluate System on the Test Set

In [49]:
test_pred = final_model.predict(test_X)
In [54]:
# precision
precision_score(test_Y_5, test_pred)
# recall
recall_score(test_Y_5, test_pred)
# f1
f1_score(test_Y_5, test_pred)
Out[54]:
0.9143587558585429

Reference

  • Hands-on Machine Learning with Scikit-Learn & TensorFlow