# Load MNIST (70,000 28x28 grayscale digit images, flattened to 784 features).
# BUG FIX: fetch_mldata was deprecated in scikit-learn 0.20 and removed in 0.22
# (mldata.org no longer exists); fetch_openml is the supported replacement.
# fetch_openml returns the target as strings ('0'..'9'), so cast to integers
# up front — otherwise `train_Y == 5` below would always be False.
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
data_X, data_Y = mnist['data'], mnist['target'].astype('uint8')
from sklearn.model_selection import train_test_split
# Stratify so each digit class keeps the same proportion in train and test.
train_X, test_X, train_Y, test_Y = train_test_split(
    data_X, data_Y, test_size=0.2, random_state=42, stratify=data_Y)
# Binary labels for the "is this digit a 5?" detector.
train_Y_5 = (train_Y == 5)
test_Y_5 = (test_Y == 5)
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.linear_model import SGDClassifier

# Out-of-fold decision scores for the binary "is it a 5?" task: with
# method="decision_function" every training sample is scored by a model
# that never saw it during fitting, giving an unbiased score per sample.
scores = cross_val_predict(SGDClassifier(random_state=42), train_X, train_Y_5,
                           cv=3, method="decision_function")
# BUG FIX: the original bound both threshold arrays to the same name
# `thresholds`, so the PR thresholds were silently clobbered by the ROC
# ones on the very next line; keep the two arrays under distinct names.
sgd_precisions, sgd_recalls, sgd_pr_thresholds = precision_recall_curve(train_Y_5, scores)
sgd_fpr, sgd_tpr, sgd_roc_thresholds = roc_curve(train_Y_5, scores)
from sklearn.ensemble import RandomForestClassifier
# Random forests expose no decision_function; use the positive-class
# probability (column 1 of predict_proba) as the ranking score instead.
probs_forest = cross_val_predict(RandomForestClassifier(random_state=42),
                                 train_X, train_Y_5, cv=3, method='predict_proba')
# BUG FIX: the PR thresholds were assigned to the bare name `thresholds`
# (clobbering the SGD value of the same name), inconsistent with the
# `thresholds_forest` used for the ROC curve; name them distinctly.
rft_precisions, rft_recalls, rft_pr_thresholds = precision_recall_curve(train_Y_5, probs_forest[:, 1])
rft_fpr, rft_tpr, thresholds_forest = roc_curve(train_Y_5, probs_forest[:, 1])
# Precision-recall comparison: a curve hugging the top-right corner is better.
# BUG FIX: `plt` was used without ever being imported (this file was likely a
# notebook where the import lived in another cell); import it here so the
# script runs standalone.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
line_1, = ax.plot(sgd_recalls, sgd_precisions, 'r-')
line_2, = ax.plot(rft_recalls, rft_precisions, 'b-')
ax.set_xlabel('Recall', size=18)
ax.set_ylabel('Precision', size=18)
legend = ax.legend((line_1, line_2), ('SGD', 'RFT'),
                   loc='lower right', shadow=True, facecolor='0.9')
# ROC comparison: a curve hugging the top-left corner is better; the dashed
# diagonal is the chance (random-classifier) baseline.
# BUG FIX: `plt` was used without ever being imported; import it here so this
# block runs standalone even outside the original notebook.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
line_1, = ax.plot(sgd_fpr, sgd_tpr, 'r-')
line_2, = ax.plot(rft_fpr, rft_tpr, 'b-')
line_3, = ax.plot([0, 1], [0, 1], 'k--')
ax.set_xlabel('False Positive Rate', size=18)
ax.set_ylabel('True Positive Rate', size=18)
legend = ax.legend((line_1, line_2, line_3), ('SGD', 'RFT', '--'),
                   loc='lower right', shadow=True, facecolor='0.9')
# Compare the two models by ROC AUC; a higher value means better ranking
# of positives over negatives.
sgd_auc = roc_auc_score(train_Y_5, scores)                  # SGD model
print(sgd_auc)
forest_auc = roc_auc_score(train_Y_5, probs_forest[:, 1])   # random forest
print(forest_auc)
# Train the winning model (random forest, by AUC) on the full training set
# and evaluate it exactly once on the held-out test set.
final_model = RandomForestClassifier(random_state=42)
final_model.fit(train_X, train_Y_5)
test_pred = final_model.predict(test_X)
# BUG FIX: the metrics were computed as bare expressions, which only echo in
# an interactive notebook — run as a script, the results were silently
# discarded. Capture and print them.
test_precision = precision_score(test_Y_5, test_pred)
test_recall = recall_score(test_Y_5, test_pred)
test_f1 = f1_score(test_Y_5, test_pred)
print(f"precision: {test_precision:.4f}")
print(f"recall:    {test_recall:.4f}")
print(f"f1:        {test_f1:.4f}")