{ "cells": [ { "cell_type": "markdown", "id": "7c5778f2-c097-4f8c-8f54-298e86be7b0d", "metadata": {}, "source": [ "# Cross Validation and Grid Search" ] }, { "cell_type": "markdown", "id": "89733788-abd8-4355-9e7b-8a5aac66705a", "metadata": {}, "source": [ "Here are examples of cross validation (fitting the same model on different training subsets to get statistics on its performance) and grid search (an exhaustive search over a grid of parameter values)." ] }, { "cell_type": "code", "execution_count": 185, "id": "0ae0180f-97d9-4a4b-802c-fda6b7ea4196", "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_openml" ] }, { "cell_type": "code", "execution_count": 186, "id": "83c3e695-8210-473a-b177-dfd2be1adbf0", "metadata": {}, "outputs": [], "source": [ "# Try this -- it may get blocked by LCPS?\n", "\n", "# uncomment next line\n", "# mnist = fetch_openml('mnist_784', as_frame=False)" ] }, { "cell_type": "code", "execution_count": 188, "id": "e01b5321-3631-4e1b-8518-7b091a517ecd", "metadata": {}, "outputs": [], "source": [ "# if above doesn't work, do this.\n", "# first download and gunzip the file from github \"mnist.gzip\"\n", "\n", "# import pickle\n", "# infile = open(\"mnist.pk\", \"rb\")\n", "# mnist = pickle.load(infile)" ] }, { "cell_type": "code", "execution_count": 189, "id": "7e1d470b-f7fb-4beb-9a7c-61df427953c8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((70000, 784), (70000,))" ] }, "execution_count": 189, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X,y = mnist.data, mnist.target\n", "X.shape, y.shape" ] }, { "cell_type": "code", "execution_count": 190, "id": "263d21b8-df82-4b4b-83a9-50e725db4adf", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "id": "6d802781-a83c-40a3-8381-e56b5863a1c0", "metadata": {}, "source": [ "Note that we use only a subset of the data: 10% for training and 1% for testing." 
] }, { "cell_type": "code", "execution_count": 191, "id": "7c4cfa37-15b6-4278-a1e0-847551a86afb", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X,y,test_size = 0.01, train_size=0.1)" ] }, { "cell_type": "code", "execution_count": 192, "id": "d8465584-d318-4fe5-8df8-a5c713b0aa40", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((7000, 784), (700, 784))" ] }, "execution_count": 192, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape, X_test.shape" ] }, { "cell_type": "markdown", "id": "18a0c32a-571b-46d6-81c4-894123e08b0c", "metadata": {}, "source": [ "Let's make a binary classifier first: \"5\" or \"not 5\"." ] }, { "cell_type": "code", "execution_count": 193, "id": "2486752f-9094-41ca-ad87-3aca8f598dfc", "metadata": {}, "outputs": [], "source": [ "y_train_5 = (y_train == '5')\n", "y_test_5 = (y_test == '5')" ] }, { "cell_type": "markdown", "id": "f83d5a07-b844-4ceb-9339-d1f013bdb351", "metadata": {}, "source": [ "SGDClassifier fits a linear classifier by stochastic gradient descent (SGD), minimizing a loss function. We use it here just because it's different." ] }, { "cell_type": "code", "execution_count": 194, "id": "ae7c8455-eb83-48bc-a5fa-ae473ae433dc", "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import SGDClassifier\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 195, "id": "81fd1fb9-2d04-4d57-a1d0-4c81b93924ef", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
SGDClassifier(random_state=42)
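A minimal sketch of cross-validating a "5" vs "not 5" SGD classifier like the one above. As an assumption for the sake of a self-contained example, it uses the small `load_digits` dataset bundled with scikit-learn as a stand-in for MNIST (its labels are integers, not the strings OpenML returns), and the choice of `cv=3` is illustrative rather than taken from the notebook.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score

# Assumption: load_digits (8x8 images) stands in for the MNIST subset.
X, y = load_digits(return_X_y=True)

# Binary target: True for the digit 5, False otherwise.
# Labels are integers here; the OpenML MNIST targets are strings like '5'.
y_5 = (y == 5)

sgd_clf = SGDClassifier(random_state=42)

# Fit and score the same model on 3 different train/validation splits.
scores = cross_val_score(sgd_clf, X, y_5, cv=3, scoring="accuracy")

print("fold accuracies:", scores)
print("mean:", scores.mean(), "std:", scores.std())
```

The spread of the fold scores gives a rough idea of how much the accuracy depends on which samples ended up in the training set.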
KNeighborsClassifier(n_neighbors=7)
GridSearchCV(cv=3,
             estimator=Pipeline(steps=[('kneighborsclassifier',
                                        KNeighborsClassifier())]),
             n_jobs=1,
             param_grid={'kneighborsclassifier__n_neighbors': [1, 2, 3, 4, 5,
                                                               10, 15, 20, 25,
                                                               30]},
             scoring='accuracy')
Pipeline(steps=[('kneighborsclassifier', KNeighborsClassifier(n_neighbors=1))])
KNeighborsClassifier(n_neighbors=1)
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_kneighborsclassifier__n_neighbors | params | split0_test_score | split1_test_score | split2_test_score | mean_test_score | std_test_score | rank_test_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.008094 | 0.000404 | 0.331692 | 0.013044 | 1 | {'kneighborsclassifier__n_neighbors': 1} | 0.928877 | 0.931847 | 0.940420 | 0.933715 | 0.004894 | 1 |
| 2 | 0.007854 | 0.000286 | 0.441001 | 0.090881 | 3 | {'kneighborsclassifier__n_neighbors': 3} | 0.931877 | 0.923275 | 0.933991 | 0.929714 | 0.004634 | 2 |
| 4 | 0.007764 | 0.000423 | 0.346311 | 0.010852 | 5 | {'kneighborsclassifier__n_neighbors': 5} | 0.931877 | 0.919417 | 0.927990 | 0.926428 | 0.005205 | 3 |
| 3 | 0.009060 | 0.001003 | 0.409770 | 0.098917 | 4 | {'kneighborsclassifier__n_neighbors': 4} | 0.931020 | 0.917703 | 0.928418 | 0.925714 | 0.005763 | 4 |
| 5 | 0.007390 | 0.000223 | 0.327227 | 0.004433 | 10 | {'kneighborsclassifier__n_neighbors': 10} | 0.923308 | 0.911273 | 0.924989 | 0.919857 | 0.006108 | 5 |
| 1 | 0.008793 | 0.001301 | 0.333124 | 0.016238 | 2 | {'kneighborsclassifier__n_neighbors': 2} | 0.912596 | 0.915988 | 0.922846 | 0.917144 | 0.004263 | 6 |
| 6 | 0.007868 | 0.000141 | 0.379707 | 0.056234 | 15 | {'kneighborsclassifier__n_neighbors': 15} | 0.917738 | 0.898843 | 0.913416 | 0.909999 | 0.008084 | 7 |
| 7 | 0.007594 | 0.000229 | 0.359337 | 0.039604 | 20 | {'kneighborsclassifier__n_neighbors': 20} | 0.912596 | 0.899700 | 0.906987 | 0.906428 | 0.005280 | 8 |
| 8 | 0.007759 | 0.000049 | 0.327107 | 0.002920 | 25 | {'kneighborsclassifier__n_neighbors': 25} | 0.908740 | 0.890699 | 0.903129 | 0.900856 | 0.007539 | 9 |
| 9 | 0.007628 | 0.000327 | 0.327724 | 0.004928 | 30 | {'kneighborsclassifier__n_neighbors': 30} | 0.898029 | 0.883412 | 0.896700 | 0.892714 | 0.006600 | 10 |
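A table like the one above can be produced by a grid search over `n_neighbors` in a one-step pipeline. The following is a sketch, not the notebook's exact code: it assumes the bundled `load_digits` dataset as a stand-in for MNIST, an illustrative 80/20 split, and `make_pipeline` to get the `kneighborsclassifier` step name seen in the output.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline

# Assumption: load_digits stands in for the MNIST subset; the split is illustrative.
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# make_pipeline names the step after the lowercased class name,
# so the parameter is addressed as 'kneighborsclassifier__n_neighbors'.
pipe = make_pipeline(KNeighborsClassifier())
param_grid = {
    "kneighborsclassifier__n_neighbors": [1, 2, 3, 4, 5, 10, 15, 20, 25, 30]
}

# Every candidate value is scored with 3-fold cross validation.
grid = GridSearchCV(pipe, param_grid, cv=3, scoring="accuracy", n_jobs=1)
grid.fit(X_train, y_train)

# List the candidates from best to worst mean validation accuracy.
results = grid.cv_results_
for i in np.argsort(results["rank_test_score"]):
    print(
        results["params"][i],
        round(results["mean_test_score"][i], 4),
        round(results["std_test_score"][i], 4),
    )

print("best parameters:", grid.best_params_)
# By default GridSearchCV refits the best pipeline on all of X_train,
# so it can be scored directly on the held-out test set.
print("test accuracy:", grid.score(X_test, y_test))
```

The held-out test score is the one to report, since the cross-validation scores were already used to choose `n_neighbors`.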