{ "cells": [ { "cell_type": "markdown", "id": "7c5778f2-c097-4f8c-8f54-298e86be7b0d", "metadata": {}, "source": [ "# Cross Validation and Grid Search" ] }, { "cell_type": "markdown", "id": "89733788-abd8-4355-9e7b-8a5aac66705a", "metadata": {}, "source": [ "Here are examples of cross-validation (training the same model on different subsets of the training set to get statistics on its performance) and grid search (an exhaustive search over a set of candidate parameter values)." ] }, { "cell_type": "code", "execution_count": 185, "id": "0ae0180f-97d9-4a4b-802c-fda6b7ea4196", "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_openml" ] }, { "cell_type": "code", "execution_count": 186, "id": "83c3e695-8210-473a-b177-dfd2be1adbf0", "metadata": {}, "outputs": [], "source": [ "# Try this first; it may get blocked by LCPS.\n", "\n", "# Uncomment the next line:\n", "# mnist = fetch_openml('mnist_784', as_frame=False)" ] }, { "cell_type": "code", "execution_count": 188, "id": "e01b5321-3631-4e1b-8518-7b091a517ecd", "metadata": {}, "outputs": [], "source": [ "# If the above doesn't work, do this instead.\n", "# First download and gunzip the file \"mnist.gzip\" from GitHub.\n", "\n", "# import pickle\n", "# with open(\"mnist.pk\", \"rb\") as infile:\n", "#     mnist = pickle.load(infile)" ] }, { "cell_type": "code", "execution_count": 189, "id": "7e1d470b-f7fb-4beb-9a7c-61df427953c8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((70000, 784), (70000,))" ] }, "execution_count": 189, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, y = mnist.data, mnist.target\n", "X.shape, y.shape" ] }, { "cell_type": "code", "execution_count": 190, "id": "263d21b8-df82-4b4b-83a9-50e725db4adf", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "id": "6d802781-a83c-40a3-8381-e56b5863a1c0", "metadata": {}, "source": [ "Note that we use only a subset of the data: 10% for training and 1% for testing." 
] }, { "cell_type": "code", "execution_count": 191, "id": "7c4cfa37-15b6-4278-a1e0-847551a86afb", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.01, train_size=0.1)" ] }, { "cell_type": "code", "execution_count": 192, "id": "d8465584-d318-4fe5-8df8-a5c713b0aa40", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((7000, 784), (700, 784))" ] }, "execution_count": 192, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape, X_test.shape" ] }, { "cell_type": "markdown", "id": "18a0c32a-571b-46d6-81c4-894123e08b0c", "metadata": {}, "source": [ "Let's make a binary classifier first: \"5\" or \"not 5\"." ] }, { "cell_type": "code", "execution_count": 193, "id": "2486752f-9094-41ca-ad87-3aca8f598dfc", "metadata": {}, "outputs": [], "source": [ "y_train_5 = (y_train == '5')\n", "y_test_5 = (y_test == '5')" ] }, { "cell_type": "markdown", "id": "f83d5a07-b844-4ceb-9339-d1f013bdb351", "metadata": {}, "source": [ "`SGDClassifier` is a linear classifier trained with stochastic gradient descent (SGD), which minimizes a loss function by updating the weights one training sample at a time. We use it here just because it's different." ] }, { "cell_type": "code", "execution_count": 194, "id": "ae7c8455-eb83-48bc-a5fa-ae473ae433dc", "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import SGDClassifier\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 195, "id": "81fd1fb9-2d04-4d57-a1d0-4c81b93924ef", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
SGDClassifier(random_state=42)
" ], "text/plain": [ "SGDClassifier(random_state=42)" ] }, "execution_count": 195, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sgd_classifier = SGDClassifier(random_state=42)\n", "sgd_classifier.fit(X_train, y_train_5)" ] }, { "cell_type": "code", "execution_count": 196, "id": "ea8a3a6b-1df4-4edd-977a-4de04df7acde", "metadata": {}, "outputs": [], "source": [ "y_pred = sgd_classifier.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 197, "id": "823743c5-c7bd-4b8c-93aa-c9832d284f88", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9471428571428572" ] }, "execution_count": 197, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_score(y_test_5, y_pred)" ] }, { "cell_type": "markdown", "id": "dbe8dc37-2a54-4fab-81e2-2dd1bd0d3ca1", "metadata": {}, "source": [ "Now we try cross-validation. With `cv=5`, the training set is split into 5 folds; the model is trained 5 times, each time on 4 of the folds, and validated on the remaining fold. This is called k-fold cross-validation (here k = 5; you can change it with the `cv` argument)." ] }, { "cell_type": "code", "execution_count": 198, "id": "0810a078-9676-4702-9db9-0b9713c52b8d", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_score" ] }, { "cell_type": "code", "execution_count": 199, "id": "a62f04f3-72f6-4208-b5f2-1866c01e523e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.91714286, 0.95142857, 0.95642857, 0.95785714, 0.95214286])" ] }, "execution_count": 199, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_score(sgd_classifier, X_train, y_train_5, cv=5, scoring=\"accuracy\")" ] }, { "cell_type": "markdown", "id": "ec4a4112-74b1-4588-9dc2-65c9a777d5f2", "metadata": {}, "source": [ "As a second example, we also try a k-nearest neighbors (KNN) classifier." ] }, { "cell_type": "code", "execution_count": 200, "id": "9d350183-8461-4dfd-8f96-907376c7af8e", "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier" ] }, { "cell_type": "code", "execution_count": 201, "id": "5cacd61f-14e2-4927-afbf-56d44cdcc696", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
KNeighborsClassifier(n_neighbors=7)
" ], "text/plain": [ "KNeighborsClassifier(n_neighbors=7)" ] }, "execution_count": 201, "metadata": {}, "output_type": "execute_result" } ], "source": [ "knn_classifier = KNeighborsClassifier(n_neighbors=7)\n", "knn_classifier.fit(X_train, y_train_5)" ] }, { "cell_type": "code", "execution_count": 202, "id": "775924bf-08f2-4f19-907e-9cb10fae4cdc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.98714286, 0.98785714, 0.98071429, 0.98428571, 0.98928571])" ] }, "execution_count": 202, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_score(knn_classifier, X_train, y_train_5, cv=5, scoring=\"accuracy\")" ] }, { "cell_type": "markdown", "id": "d8af1519-10c9-4074-9038-7f1772015173", "metadata": {}, "source": [ "KNN handles multiclass problems directly, so let's give it the original 10-class training labels." ] }, { "cell_type": "code", "execution_count": 203, "id": "793acf05-05b4-410e-8b2e-cdf065ca0702", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.92785714, 0.94214286, 0.915 , 0.92857143, 0.92857143])" ] }, "execution_count": 203, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_score(knn_classifier, X_train, y_train, cv=5, scoring=\"accuracy\")" ] }, { "cell_type": "markdown", "id": "61ac6cb6-f31f-4106-a5b7-e61970714ca4", "metadata": {}, "source": [ "Let's look at some statistics. We'll run a 20-fold cross-validation and use a pandas Series as a quick way to get summary statistics." 
] }, { "cell_type": "code", "execution_count": 204, "id": "d177efb8-8f70-41cb-b8e1-4c6cb5bf7788", "metadata": {}, "outputs": [], "source": [ "knn_results = cross_val_score(knn_classifier, X_train, y_train, cv=20, scoring=\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 205, "id": "00469e09-d012-40d0-9dca-0c7326ee651d", "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 206, "id": "d3db562a-f726-4973-912e-dc522dbcb88b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "count 20.000000\n", "mean 0.932714\n", "std 0.010466\n", "min 0.914286\n", "25% 0.927857\n", "50% 0.931429\n", "75% 0.937143\n", "max 0.954286\n", "dtype: float64" ] }, "execution_count": 206, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.Series(knn_results).describe()" ] }, { "cell_type": "markdown", "id": "1ba595e6-87a8-4aa5-8d4a-5ebee6f7b3a1", "metadata": {}, "source": [ "## Grid Search" ] }, { "cell_type": "markdown", "id": "2cd477ee-3852-46e8-949c-2bcefeb6ea90", "metadata": {}, "source": [ "A grid search takes a list of candidate values for each parameter and runs a cross-validation for every combination of values. It is a slow, exhaustive way to find the best settings for your model." 
] }, { "cell_type": "code", "execution_count": 209, "id": "4e3869b5-cb32-45c1-a0bb-ea18413cc06b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GridSearchCV(cv=3,\n", " estimator=Pipeline(steps=[('kneighborsclassifier',\n", " KNeighborsClassifier())]),\n", " n_jobs=1,\n", " param_grid={'kneighborsclassifier__n_neighbors': [1, 2, 3, 4, 5,\n", " 10, 15, 20, 25,\n", " 30]},\n", " scoring='accuracy')\n" ] } ], "source": [ "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.model_selection import GridSearchCV\n", "\n", "# Create a pipeline with a single KNN step; make_pipeline names the step 'kneighborsclassifier'\n", "pipeline = make_pipeline(KNeighborsClassifier())\n", "\n", "# Define the parameter grid; keys use the '<step name>__<parameter name>' format\n", "param_grid = {\n", "    'kneighborsclassifier__n_neighbors': [1, 2, 3, 4, 5, 10, 15, 20, 25, 30]\n", "}\n", "\n", "# Configure the grid search (3-fold cross-validation per candidate); it runs when .fit() is called\n", "grid_search = GridSearchCV(pipeline, param_grid, cv=3, scoring='accuracy', n_jobs=1)\n", "\n", "# Print the grid_search to confirm it was set up without errors\n", "print(grid_search)" ] }, { "cell_type": "code", "execution_count": 210, "id": "607d643e-edb2-4cc3-9c74-854339c00f6e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
GridSearchCV(cv=3,\n",
       "             estimator=Pipeline(steps=[('kneighborsclassifier',\n",
       "                                        KNeighborsClassifier())]),\n",
       "             n_jobs=1,\n",
       "             param_grid={'kneighborsclassifier__n_neighbors': [1, 2, 3, 4, 5,\n",
       "                                                               10, 15, 20, 25,\n",
       "                                                               30]},\n",
       "             scoring='accuracy')
" ], "text/plain": [ "GridSearchCV(cv=3,\n", " estimator=Pipeline(steps=[('kneighborsclassifier',\n", " KNeighborsClassifier())]),\n", " n_jobs=1,\n", " param_grid={'kneighborsclassifier__n_neighbors': [1, 2, 3, 4, 5,\n", " 10, 15, 20, 25,\n", " 30]},\n", " scoring='accuracy')" ] }, "execution_count": 210, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "id": "700ae62a-9306-454b-a27f-49528d4106b6", "metadata": {}, "source": [ "The fitted grid search keeps the cross-validation results for every candidate in `cv_results_`. We analyze those first; the best model found is shown at the end." ] }, { "cell_type": "code", "execution_count": 213, "id": "d98582db-a04b-455b-8610-23d079f62262", "metadata": {}, "outputs": [], "source": [ "knn_cv_results = pd.DataFrame(grid_search.cv_results_)" ] }, { "cell_type": "code", "execution_count": 214, "id": "b00c6369-395a-42ea-ae3c-1075b859ffd6", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "knn_cv_results.plot(x='param_kneighborsclassifier__n_neighbors', y='mean_test_score',\n", " kind='bar',ylim=(0.9,0.96));" ] }, { "cell_type": "code", "execution_count": 215, "id": "5df79199-2787-48ce-bbe8-23d249062148", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>mean_fit_time</th><th>std_fit_time</th><th>mean_score_time</th><th>std_score_time</th><th>param_kneighborsclassifier__n_neighbors</th><th>params</th><th>split0_test_score</th><th>split1_test_score</th><th>split2_test_score</th><th>mean_test_score</th><th>std_test_score</th><th>rank_test_score</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>0.008094</td><td>0.000404</td><td>0.331692</td><td>0.013044</td><td>1</td><td>{'kneighborsclassifier__n_neighbors': 1}</td><td>0.928877</td><td>0.931847</td><td>0.940420</td><td>0.933715</td><td>0.004894</td><td>1</td></tr>\n",
"    <tr><th>2</th><td>0.007854</td><td>0.000286</td><td>0.441001</td><td>0.090881</td><td>3</td><td>{'kneighborsclassifier__n_neighbors': 3}</td><td>0.931877</td><td>0.923275</td><td>0.933991</td><td>0.929714</td><td>0.004634</td><td>2</td></tr>\n",
"    <tr><th>4</th><td>0.007764</td><td>0.000423</td><td>0.346311</td><td>0.010852</td><td>5</td><td>{'kneighborsclassifier__n_neighbors': 5}</td><td>0.931877</td><td>0.919417</td><td>0.927990</td><td>0.926428</td><td>0.005205</td><td>3</td></tr>\n",
"    <tr><th>3</th><td>0.009060</td><td>0.001003</td><td>0.409770</td><td>0.098917</td><td>4</td><td>{'kneighborsclassifier__n_neighbors': 4}</td><td>0.931020</td><td>0.917703</td><td>0.928418</td><td>0.925714</td><td>0.005763</td><td>4</td></tr>\n",
"    <tr><th>5</th><td>0.007390</td><td>0.000223</td><td>0.327227</td><td>0.004433</td><td>10</td><td>{'kneighborsclassifier__n_neighbors': 10}</td><td>0.923308</td><td>0.911273</td><td>0.924989</td><td>0.919857</td><td>0.006108</td><td>5</td></tr>\n",
"    <tr><th>1</th><td>0.008793</td><td>0.001301</td><td>0.333124</td><td>0.016238</td><td>2</td><td>{'kneighborsclassifier__n_neighbors': 2}</td><td>0.912596</td><td>0.915988</td><td>0.922846</td><td>0.917144</td><td>0.004263</td><td>6</td></tr>\n",
"    <tr><th>6</th><td>0.007868</td><td>0.000141</td><td>0.379707</td><td>0.056234</td><td>15</td><td>{'kneighborsclassifier__n_neighbors': 15}</td><td>0.917738</td><td>0.898843</td><td>0.913416</td><td>0.909999</td><td>0.008084</td><td>7</td></tr>\n",
"    <tr><th>7</th><td>0.007594</td><td>0.000229</td><td>0.359337</td><td>0.039604</td><td>20</td><td>{'kneighborsclassifier__n_neighbors': 20}</td><td>0.912596</td><td>0.899700</td><td>0.906987</td><td>0.906428</td><td>0.005280</td><td>8</td></tr>\n",
"    <tr><th>8</th><td>0.007759</td><td>0.000049</td><td>0.327107</td><td>0.002920</td><td>25</td><td>{'kneighborsclassifier__n_neighbors': 25}</td><td>0.908740</td><td>0.890699</td><td>0.903129</td><td>0.900856</td><td>0.007539</td><td>9</td></tr>\n",
"    <tr><th>9</th><td>0.007628</td><td>0.000327</td><td>0.327724</td><td>0.004928</td><td>30</td><td>{'kneighborsclassifier__n_neighbors': 30}</td><td>0.898029</td><td>0.883412</td><td>0.896700</td><td>0.892714</td><td>0.006600</td><td>10</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " mean_fit_time std_fit_time mean_score_time std_score_time \\\n", "0 0.008094 0.000404 0.331692 0.013044 \n", "2 0.007854 0.000286 0.441001 0.090881 \n", "4 0.007764 0.000423 0.346311 0.010852 \n", "3 0.009060 0.001003 0.409770 0.098917 \n", "5 0.007390 0.000223 0.327227 0.004433 \n", "1 0.008793 0.001301 0.333124 0.016238 \n", "6 0.007868 0.000141 0.379707 0.056234 \n", "7 0.007594 0.000229 0.359337 0.039604 \n", "8 0.007759 0.000049 0.327107 0.002920 \n", "9 0.007628 0.000327 0.327724 0.004928 \n", "\n", " param_kneighborsclassifier__n_neighbors \\\n", "0 1 \n", "2 3 \n", "4 5 \n", "3 4 \n", "5 10 \n", "1 2 \n", "6 15 \n", "7 20 \n", "8 25 \n", "9 30 \n", "\n", " params split0_test_score \\\n", "0 {'kneighborsclassifier__n_neighbors': 1} 0.928877 \n", "2 {'kneighborsclassifier__n_neighbors': 3} 0.931877 \n", "4 {'kneighborsclassifier__n_neighbors': 5} 0.931877 \n", "3 {'kneighborsclassifier__n_neighbors': 4} 0.931020 \n", "5 {'kneighborsclassifier__n_neighbors': 10} 0.923308 \n", "1 {'kneighborsclassifier__n_neighbors': 2} 0.912596 \n", "6 {'kneighborsclassifier__n_neighbors': 15} 0.917738 \n", "7 {'kneighborsclassifier__n_neighbors': 20} 0.912596 \n", "8 {'kneighborsclassifier__n_neighbors': 25} 0.908740 \n", "9 {'kneighborsclassifier__n_neighbors': 30} 0.898029 \n", "\n", " split1_test_score split2_test_score mean_test_score std_test_score \\\n", "0 0.931847 0.940420 0.933715 0.004894 \n", "2 0.923275 0.933991 0.929714 0.004634 \n", "4 0.919417 0.927990 0.926428 0.005205 \n", "3 0.917703 0.928418 0.925714 0.005763 \n", "5 0.911273 0.924989 0.919857 0.006108 \n", "1 0.915988 0.922846 0.917144 0.004263 \n", "6 0.898843 0.913416 0.909999 0.008084 \n", "7 0.899700 0.906987 0.906428 0.005280 \n", "8 0.890699 0.903129 0.900856 0.007539 \n", "9 0.883412 0.896700 0.892714 0.006600 \n", "\n", " rank_test_score \n", "0 1 \n", "2 2 \n", "4 3 \n", "3 4 \n", "5 5 \n", "1 6 \n", "6 7 \n", "7 8 \n", "8 9 \n", "9 10 " ] }, "execution_count": 215, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "knn_cv_results.sort_values(by=\"mean_test_score\", ascending = False)" ] }, { "cell_type": "code", "execution_count": 216, "id": "cb38df2c-36ae-449f-ac07-aca19bc46a87", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best Parameters: {'kneighborsclassifier__n_neighbors': 1}\n" ] } ], "source": [ "best_params = grid_search.best_params_\n", "print(\"Best Parameters:\", best_params)" ] }, { "cell_type": "code", "execution_count": 217, "id": "f3fe816c-bd6c-45d8-a9cb-e8c9f3a9765e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best Model: Pipeline(steps=[('kneighborsclassifier', KNeighborsClassifier(n_neighbors=1))])\n" ] } ], "source": [ "best_model = grid_search.best_estimator_\n", "print(\"Best Model:\", best_model)" ] }, { "cell_type": "code", "execution_count": 218, "id": "fcd55cd8-8b7c-4aec-9c02-8d7a0a9ece2c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy with Best Model: 0.9328571428571428\n" ] } ], "source": [ "# Evaluate the best model\n", "y_pred = best_model.predict(X_test)\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Test accuracy with Best Model:\", accuracy)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }