Recognizing Digits with k-NN
The MNIST dataset has been one of the most widely used datasets in ML classes for years. It consists of 70,000 digitized scans of handwritten digits (60,000 for training and 10,000 for testing) along with their classifications: 0-9. scikit-learn includes a much smaller dataset of the same kind in its datasets package: 1,797 images of handwritten digits, each 8x8 pixels. In this notebook you will recognize digits using a k-NN classifier.
import numpy as np
from sklearn.datasets import load_digits
from matplotlib import pyplot as plt
The method to return the dataset is load_digits(). Call it in the next cell.
# call load_digits() and set the return value to a new variable `digits`
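One possible way to fill in the cell above is sketched here. It assumes only the load_digits import from the top of the notebook; the call to keys() is an extra step added to show which fields are available.

```python
from sklearn.datasets import load_digits

# load_digits() returns a Bunch: a dictionary-like object with attribute access
digits = load_digits()

# List the fields it exposes (for example data, target, images, DESCR)
print(digits.keys())
```

Printing digits.DESCR gives a longer description of where the data came from.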
Look at the data stored in digits and load the appropriate values into the variables X and y. $X$ should be an array of vectors representing the flattened 8x8 image data, and $y$ should be an array of labels 0-9.
### Store X and y here
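A minimal sketch of this step, assuming the data and target fields of the object returned by load_digits() are the ones wanted here. The final check is an addition for illustration: it confirms that flattening the 8x8 images by hand gives the same rows as the data field.

```python
from sklearn.datasets import load_digits

digits = load_digits()

X = digits.data    # shape (n_samples, 64): each row is a flattened 8x8 image
y = digits.target  # shape (n_samples,): integer labels 0-9

# Sanity check: reshaping the image array by hand reproduces X
assert (digits.images.reshape(len(digits.images), -1) == X).all()
```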
The following cells let you explore the dataset.
images = digits.images
k = 100 # change this
plt.imshow(images[k], cmap='gray')  # 'gray' is accepted by every matplotlib version
print(f"data = {X[k]}")
print(f"label = {y[k]}")
Determine the size and shape of your data. How many X and y observations are there? What are their datatypes? Is there any missing data?
## your code here
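The questions above could be answered along these lines. This is a sketch rather than the only approach; the variable n_missing is a name introduced here for illustration.

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()
X, y = digits.data, digits.target

# Size, shape and datatype of each array
print("X shape:", X.shape, "dtype:", X.dtype)
print("y shape:", y.shape, "dtype:", y.dtype)

# np.isnan works on float arrays; y is an integer array, so it cannot hold NaN
n_missing = int(np.isnan(X).sum())
print("missing values in X:", n_missing)

# Range of pixel intensities across the whole dataset
print("pixel range:", X.min(), "to", X.max())
```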
Is the dataset balanced? Determine the frequency of each label.
## your code here
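One way to count the labels is np.unique with return_counts=True, sketched below. A bar chart of the counts would work equally well.

```python
import numpy as np
from sklearn.datasets import load_digits

y = load_digits().target

# labels holds the distinct digits, counts how often each one occurs
labels, counts = np.unique(y, return_counts=True)
for label, count in zip(labels, counts):
    print(f"digit {label}: {count} samples ({count / len(y):.1%})")
```

If the counts are all close to one another, the dataset is roughly balanced and plain accuracy is easier to interpret.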
Training k-NN
We load the dataset and some useful library functions below. It is your job to make a train/test split and determine the accuracy of a k-NN classifier using scikit-learn's built-in operations. You should:
- Correctly process the data and train a k-NN for k=1.
- Print the accuracy on the test set
- Plot a confusion matrix for the test set
- Analyze the accuracy as $k$ increases. For each $k$ (over a range you determine useful), train a $k$-NN on 10 train/test splits and store the average accuracy. Plot this accuracy against k.
- Which $k$ do you think is best?
- Is "accuracy" a valid metric in this case? Does it obscure any critical shortcoming of our algorithm?
Refer to old notebooks for the functions you need. The basic outline for this is mostly the same for all scikit-learn algorithms.
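The first three items in the list above (train for k=1, print the test accuracy, plot a confusion matrix) might be sketched as follows. The 25% test size, the stratified split and the fixed random_state are assumptions made for this example, not requirements of the exercise.

```python
from matplotlib import pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

digits = load_digits()
X, y = digits.data, digits.target

# Assumed settings: 25% held out, stratified by label, fixed seed for repeatability
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

# Fit a 1-nearest-neighbour classifier and predict the held-out labels
knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)

acc = accuracy_score(y_test, y_pred)
print(f"test accuracy (k=1): {acc:.3f}")

# Rows are true labels, columns are predicted labels
cm = confusion_matrix(y_test, y_pred, labels=knn.classes_)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=knn.classes_).plot()
plt.show()
```

The off-diagonal cells of the matrix show which pairs of digits are confused with each other, which a single accuracy number does not reveal.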
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# complete here
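For the analysis of accuracy as $k$ increases, a sketch along the following lines could serve as a starting point. The range of k values, the test size and the use of seeds 0 through 9 for the 10 splits are all assumptions chosen for illustration.

```python
import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

digits = load_digits()
X, y = digits.data, digits.target

k_values = list(range(1, 30, 2))  # assumed range; adjust as needed
n_splits = 10
mean_accuracy = []

for k in k_values:
    scores = []
    for seed in range(n_splits):
        # Reuse the same 10 seeds for every k so each k sees identical splits
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.25, stratify=y, random_state=seed
        )
        knn = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)
        scores.append(accuracy_score(y_test, knn.predict(X_test)))
    mean_accuracy.append(np.mean(scores))

plt.plot(k_values, mean_accuracy, marker="o")
plt.xlabel("k")
plt.ylabel("mean test accuracy over 10 splits")
plt.show()

best_k = k_values[int(np.argmax(mean_accuracy))]
print("k with highest mean accuracy:", best_k)
```

Note that picking k by test accuracy makes the reported accuracy for that k somewhat optimistic; a separate validation split or cross-validation on the training data would avoid this.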
Extension
Find the full MNIST dataset online (60,000 training images plus 10,000 test images) and try it with your preferred $k$-NN. This dataset exists in many places, but some copies are cleaner and easier to use than others, so be prepared to shop around. Print out
- the accuracy
- the time spent training
- the time spent testing
- a confusion matrix on the test set
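The four outputs listed above could be collected with a small helper like the one sketched here. The function timed_knn is a hypothetical name introduced for this example, and the small digits dataset stands in for the full data so the cell runs offline. One possible source for the full dataset is fetch_openml from sklearn.datasets, shown in a comment; it downloads the data on first use, and its labels may come back as strings.

```python
import time
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix


def timed_knn(X_train, X_test, y_train, y_test, k):
    """Hypothetical helper: fit and score a k-NN, timing each phase."""
    knn = KNeighborsClassifier(n_neighbors=k)

    start = time.perf_counter()
    knn.fit(X_train, y_train)
    fit_seconds = time.perf_counter() - start

    start = time.perf_counter()
    y_pred = knn.predict(X_test)
    predict_seconds = time.perf_counter() - start

    return {
        "accuracy": accuracy_score(y_test, y_pred),
        "fit_seconds": fit_seconds,
        "predict_seconds": predict_seconds,
        "confusion_matrix": confusion_matrix(y_test, y_pred),
    }


# Stand-in data so the sketch runs offline. For the full dataset, one option is
#   X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

results = timed_knn(X_train, X_test, y_train, y_test, k=3)  # k=3 is an arbitrary choice
print(f"accuracy:     {results['accuracy']:.3f}")
print(f"fit time:     {results['fit_seconds']:.4f} s")
print(f"predict time: {results['predict_seconds']:.4f} s")
print(results["confusion_matrix"])
```

Fitting a k-NN mostly amounts to storing the training data (and possibly building a search index), so on a large dataset most of the time is typically spent in prediction rather than training.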