# K-Nearest Neighbor classification on *MNIST* handwritten image dataset

In this notebook we will use *kNN* to classify images from the *MNIST* (Modified National Institute of Standards and Technology) database of handwritten digits (http://yann.lecun.com/exdb/mnist/ ). The full dataset consists of 60,000 training samples and 10,000 test samples, each a 28 x 28 pixel image; a simplified version was produced containing 5,620 images of 8 x 8 pixels (http://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits ). The *scikit-learn* module contains a subset of these images.

## Loading and understanding the data


In [None]:
from sklearn import datasets
digits = datasets.load_digits()

# This describes the processed dataset containing 5,620 images; our dataset is smaller as will be seen below
print(digits.DESCR)

Extract the features and the target values (class labels)

In [None]:
X = digits.data
y = digits.target

*X* is two-dimensional and has 1797 rows and 64 columns. The 64 columns represent a 'flattened' version of the 8x8 image.

In [None]:
X.shape

*y* has 1797 values

In [None]:
y.shape

## Visualizing the data

The image representation can be accessed by looking at *digits.images*, a **three-dimensional** array of 1797 images, where each image is an 8x8 array.

In [None]:
digits.images.shape

Let's look at the first image, which is an 8x8 array of pixel intensities. Slicing rules for arrays apply, except now we have 3 dimensions. Furthermore, to access a single image at index *i*, note that

```python
digits.images[i]
```
is the same as

```python
digits.images[i,:,:]
```

In [None]:
digits.images[0]

Let's *view* the first image:

In [None]:
import matplotlib.pyplot as plt
plt.imshow(digits.images[0], cmap = plt.cm.gray_r)
plt.axis('off')
plt.title('Number: ' + str(y[0]))
None

Let's generate images for the first 30 numbers. We use the *subplots* function to specify that we want to plot 3 rows of images with 10 images in each row; the *figsize* argument sets the size of the whole figure to 15x6 inches. This function returns a tuple containing two elements: the *figure*, and an array of *axes* objects (one for each subplot). The *ravel* method is used to *flatten* the axes array (more on this below).

In [None]:
# set up the plot
figure, axes = plt.subplots(3,10, figsize = (15,6))

for ax,image,number in zip(axes.ravel(), digits.images, y) :
    ax.axis('off')
    ax.imshow(image, cmap = plt.cm.gray_r)
    ax.set_title('Number: ' + str(number))
    

## Flattening numpy arrays

Sometimes data in multidimensional arrays need to be *flattened* to a one-dimensional array. In the code above, axes is a two-dimensional array with 3 rows and 10 columns. In order to iterate over each subplot in a single loop, we need to flatten the array.

In classification, the feature data for a *single* sample must be stored in a one-dimensional array (even if this is not the actual structure), and the feature data for *all* samples must be a two-dimensional array. But what if the feature data for a sample is an 8x8 image? The solution is to *flatten* the data. The *ravel* method will flatten data by reading the array row by row. The feature data (stored in *X*) contains the flattened version of each image.

In [None]:
image = digits.images[0]
print('original image data =')
print(image)
print()

image_flattened = image.ravel()
print('flattened image = ')
print(image_flattened)
print()

print('feature data = ')
print(X[0])
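
To make "reading the array row by row" concrete, here is a minimal sketch in plain Python. The 3x4 nested list and the helper `flat_index` are hypothetical and only for illustration; for the 8x8 digit images the same rule would place the pixel at row *r*, column *c* in column *r* * 8 + *c* of *X*.

```python
# A toy 3x4 "image" stored as nested lists (hypothetical data, not from the digits dataset)
image = [
    [0, 1, 2, 3],
    [4, 5, 6, 7],
    [8, 9, 10, 11],
]
n_cols = len(image[0])

# Row-major flattening: read the first row left to right, then the second row, and so on
flat = [pixel for row in image for pixel in row]

def flat_index(row, col, n_cols):
    """Position in the flattened list of the pixel at (row, col)."""
    return row * n_cols + col

print(flat)
# The pixel at row 2, column 1 is found at the computed flat position
print(flat[flat_index(2, 1, n_cols)] == image[2][1])
```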

## K-nearest neighbor classification on *MNIST* using training and testing datasets


### Split the data into training and testing sets

In [None]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=99, stratify = y)
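
The *stratify* argument asks for a split in which each class appears in roughly the same proportion in the training and testing sets. The sketch below illustrates that idea in plain Python; `stratified_split` is a hypothetical helper written for this example and is not how *scikit-learn* implements it.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_size, seed):
    """Return (train_indices, test_indices), splitting each class in the same proportion."""
    # Group the sample indices by class label
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    train, test = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Send the same fraction of every class to the test set
        n_test = round(len(indices) * test_size)
        test.extend(indices[:n_test])
        train.extend(indices[n_test:])
    return sorted(train), sorted(test)

# Hypothetical labels: 8 samples of class 0 and 4 samples of class 1
toy_labels = [0] * 8 + [1] * 4
train_idx, test_idx = stratified_split(toy_labels, test_size=0.25, seed=99)
print([toy_labels[i] for i in test_idx])
```

With a 2:1 class ratio in the toy labels, the test set keeps the same 2:1 ratio.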

### Fit the model

In [None]:
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)
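
To see what the classifier does conceptually, here is a minimal from-scratch sketch of kNN prediction: compute the distance from a query point to every training point, keep the *k* closest, and take a majority vote. It assumes Euclidean distance and equal-weight votes, which match the default settings of *KNeighborsClassifier*; the function `knn_predict` and the toy data are hypothetical, and tie-breaking may differ from *scikit-learn*.

```python
import math
from collections import Counter

def knn_predict(train_points, train_labels, query, k=3):
    """Predict the label of `query` by majority vote among its k nearest training points."""
    # Euclidean distance from the query to every training point
    distances = [
        (math.dist(point, query), label)
        for point, label in zip(train_points, train_labels)
    ]
    # Keep the k closest neighbors
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Majority vote over the neighbors' labels
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

# Hypothetical 2-D toy data: two well-separated clusters
train_points = [(0, 0), (0, 1), (1, 0), (5, 5), (5, 6), (6, 5)]
train_labels = ["a", "a", "a", "b", "b", "b"]

print(knn_predict(train_points, train_labels, (0.5, 0.5)))
print(knn_predict(train_points, train_labels, (5.5, 5.5)))
```

For the digits data, each "point" would be a row of 64 pixel intensities instead of a 2-D coordinate.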

### Make predictions in the *test* dataset

In [None]:
y_pred = knn.predict(X_test)

### Evaluate the results by generating a *classification report*, which calculates various performance measures

Our *kNN* classifier correctly identifies digits about 99% of the time, on average!

In [None]:
from sklearn.metrics import classification_report
report = classification_report(y_test, y_pred)
print(report)
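
The report lists precision, recall, and F1 score for each class. As a rough illustration of how these numbers are defined, the sketch below computes them by hand for a small set of hypothetical labels; `precision_recall_f1` is a helper written for this example, not a *scikit-learn* function.

```python
def precision_recall_f1(y_true, y_pred, label):
    """Precision, recall, and F1 for one class, treating `label` as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == label and p == label)  # true positives
    fp = sum(1 for t, p in pairs if t != label and p == label)  # false positives
    fn = sum(1 for t, p in pairs if t == label and p != label)  # false negatives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical labels for a 3-class problem
toy_true = [0, 0, 1, 1, 2, 2, 2]
toy_pred = [0, 1, 1, 1, 2, 0, 2]

for label in sorted(set(toy_true)):
    p, r, f = precision_recall_f1(toy_true, toy_pred, label)
    print(f"class {label}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
```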

## Evaluate the results by looking at the *confusion matrix*

A *confusion matrix* is a table in which each row corresponds to a true class and each column to a predicted class; the entry in row *i*, column *j* counts the observations of class *i* that were classified as class *j*. As the name implies, confusion matrices are useful for identifying areas where the classifier may be "confused" (i.e., where it consistently misclassifies a particular category).

In [None]:
from sklearn.metrics import confusion_matrix
confusion = confusion_matrix(y_true = y_test, y_pred = y_pred)
confusion
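
The counting behind a confusion matrix can be sketched in a few lines of plain Python. The helper `confusion_counts` and the toy labels are hypothetical; the layout (rows are true labels, columns are predicted labels) follows the convention described above.

```python
def confusion_counts(y_true, y_pred, n_classes):
    """Return an n_classes x n_classes table: rows are true labels, columns are predictions."""
    table = [[0] * n_classes for _ in range(n_classes)]
    for true_label, predicted_label in zip(y_true, y_pred):
        table[true_label][predicted_label] += 1
    return table

# Hypothetical labels for a 3-class problem
toy_true = [0, 0, 1, 1, 2, 2, 2]
toy_pred = [0, 1, 1, 1, 2, 0, 2]

for row in confusion_counts(toy_true, toy_pred, n_classes=3):
    print(row)
```

Correct predictions land on the diagonal; every off-diagonal count is a misclassification.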

We can visualize the confusion matrix as a heatmap using the seaborn *heatmap* function.

We do not create a data frame because we do not need to assign row and column names (the default values, 0 - 9, correspond to the class values).

Are there any numbers that tend to be misclassified?

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
s = sns.heatmap(confusion, annot = True, cmap = 'nipy_spectral_r')
s.set_title('Confusion matrix for MNIST dataset')
plt.ylabel('True Value')
plt.xlabel('Predicted Value')
None

### Getting specific performance measures

We can get the accuracy (number correct / number of observations) by using either the *knn.score* method or the *metrics.accuracy_score* function.

In [None]:
# calculate the overall accuracy using knn.score
acc = knn.score(X_test, y_test)
print(f'accuracy from knn.score = {acc:.4}') 

# calculate the overall accuracy using metrics.accuracy_score
from sklearn import metrics
acc = metrics.accuracy_score(y_test, y_pred)
print(f'accuracy from metrics.accuracy_score = {acc:.4}') 

However, overall accuracy may not be a good measure. Why?

Balanced accuracy is the average of the recall obtained on each class, so every class counts equally regardless of how many samples it has.

In [None]:
# calculate the balanced accuracy using metrics.balanced_accuracy_score
acc = metrics.balanced_accuracy_score(y_test, y_pred)
print(f'balanced accuracy = {acc:.4}') 
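
The difference between the two measures shows up most clearly on imbalanced data. The sketch below uses hypothetical labels and two helper functions written for this example (`accuracy` and `balanced_accuracy`, the latter computed as the mean of per-class recall); a classifier that always predicts the majority class scores well on overall accuracy but poorly on balanced accuracy.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def balanced_accuracy(y_true, y_pred):
    """Mean of the recall computed separately for each class."""
    recalls = []
    for label in sorted(set(y_true)):
        # Predictions made for the samples whose true class is `label`
        predictions_for_class = [p for t, p in zip(y_true, y_pred) if t == label]
        correct = sum(p == label for p in predictions_for_class)
        recalls.append(correct / len(predictions_for_class))
    return sum(recalls) / len(recalls)

# Hypothetical imbalanced data: 9 samples of class 0 and 1 sample of class 1,
# with a classifier that always predicts class 0
toy_true = [0] * 9 + [1]
toy_pred = [0] * 10

print(f'accuracy = {accuracy(toy_true, toy_pred):.4}')
print(f'balanced accuracy = {balanced_accuracy(toy_true, toy_pred):.4}')
```

Here overall accuracy is 0.9 while balanced accuracy is 0.5, because the single class-1 sample is never identified.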