Recognizing Handwritten Digits in Python

The goal of this article is to show you how to use a couple of the tools available in Python and scikit-learn.  In particular, you'll write a simple program that learns to recognize handwritten digits.  The idea for this example did not come from me; there are a few flavors of it on the web.  If you don't have Python or scikit-learn you'll need to install them first.  There are various places to download these tools, e.g. https://code.google.com/p/pythonxy/.

This example is simple enough that you could just skip down to the commented source code below and understand it on its own.  However, many of the steps are elaborated on here to give some additional background.

Step 1: Load the data
The scikit-learn package includes a sub-package named sklearn.datasets with various sample data sets, as well as methods to produce artificial datasets with particular statistical properties (see http://scikit-learn.org/stable/datasets/ for more information).  These datasets are useful for trying out algorithms and learning methodologies, which is how they are used here.  The load_digits() function returns a set of 8x8-resolution scans of handwritten digits along with the actual numbers that were written.  A function is included in the source code below to draw the digits from the scans if you are interested.
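
As a quick sanity check, you can look at what load_digits() returns before doing anything else.  The short sketch below (not part of the main program) just prints the array shapes: data holds each scan flattened into 64 pixel values, images holds the same scans as 8x8 arrays, and target holds the actual digits.

 # Peek at the digits dataset before using it
 from sklearn import datasets
 digits = datasets.load_digits()
 print(digits.data.shape)    # (n_samples, 64): each scan flattened into 64 pixel values
 print(digits.images.shape)  # (n_samples, 8, 8): the same scans kept as 8x8 arrays
 print(digits.target[:10])   # the actual numbers the writers intended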

Step 2: Reduce the dimensionality
Each scan of a handwritten digit consists of 64 real values (eight rows by eight columns) in the input.  Intuitively you might expect this representation to be overkill in some ways, e.g. maybe the upper-left pixel is almost never used, or more generally maybe there is a set of N pixels that are highly correlated across all the images.  Keeping all these (somewhat redundant) dimensions can complicate the learning algorithm.  If N dimensions can be well represented by fewer than N dimensions, it's good to go ahead and do that.  The sklearn.decomposition package has functionality to do this for you automatically, based on analyzing the data you give it.  The method used here is called principal component analysis (PCA) and performs the reduction according to a particular criterion; see https://en.wikipedia.org/wiki/Principal_component_analysis.

In this example the number of dimensions was reduced to 20 pretty arbitrarily.  In reality, you'd want to investigate how many dimensions you can remove without "losing too much information."  You can do this to some extent with the scikit-learn PCA methods, as sketched below, but we won't get into it in depth here.
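
For example, after fitting a PCA you can look at its explained_variance_ratio_ attribute, which reports the fraction of the data's variance captured by each retained component; summing it gives a rough sense of how much information survives the reduction.  The sketch below is stand-alone and simply illustrates the idea with the same 20 components used later.

 # Rough check of how much variance 20 components retain
 import numpy as np
 from sklearn.datasets import load_digits
 from sklearn.decomposition import PCA
 x = load_digits().data
 pca = PCA(n_components=20, whiten=True)
 pca.fit(x)
 print(np.sum(pca.explained_variance_ratio_))  # fraction of the total variance kept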

Step 3:  Learn from the (reduced) data
Although neural networks are probably the most popular machine learning method, support vector machines (used here) give better results for many applications.  The theory will not be discussed here, but a very good introduction to machine learning (including video lectures) is available at https://work.caltech.edu/lectures.html.  Many other learning algorithms are available in scikit-learn; see its documentation for more information.
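
One convenient consequence of scikit-learn's design is that its classifiers share the same fit/predict/score interface, so trying a different learner takes almost no code.  The stand-alone sketch below swaps in a k-nearest-neighbours classifier purely as an illustration; it is not part of the main program.

 # Try a different classifier with the same interface (illustration only)
 from sklearn.datasets import load_digits
 from sklearn.model_selection import train_test_split
 from sklearn.neighbors import KNeighborsClassifier
 digits = load_digits()
 x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=.25)
 knn = KNeighborsClassifier()      # a different learner, same fit/score methods
 knn.fit(knn_train := x_train, y_train) if False else knn.fit(x_train, y_train)
 print(knn.score(x_test, y_test))  # fraction of test digits identified correctly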

Step 4:  Evaluate the results
Once the learning algorithm has been applied, you'll typically want to test it with some new data (not the data used in the learning process).  For this reason the available data was split earlier into a training set and a test set.  scikit-learn has various methods for splitting the data; below we simply took 25% of the data as the test set and the remainder as the training set.  After the learning algorithm has used the training set, we want to see how well it does on the test set.  A simple metric is the fraction of test digits that are correctly identified, which is exactly what the score() method returns (see the source code below).  Note that the x-values given to the method must be transformed with the PCA (or whatever the training points were transformed with) to be consistent -- the SVM doesn't know about the transformation function.
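
If you want more detail than a single accuracy number, sklearn.metrics provides a confusion_matrix() function that shows which digits get mistaken for which.  The stand-alone sketch below skips the PCA step just to keep it short; in the full program you would pass the PCA-transformed test points, exactly as described above.

 # Which digits get confused with which? (stand-alone sketch, no PCA step)
 from sklearn.datasets import load_digits
 from sklearn.model_selection import train_test_split
 from sklearn.svm import SVC
 from sklearn.metrics import confusion_matrix
 digits = load_digits()
 x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=.25)
 svc = SVC(kernel='rbf').fit(x_train, y_train)
 # rows = actual digit, columns = predicted digit; off-diagonal entries are mistakes
 print(confusion_matrix(y_test, svc.predict(x_test)))

The complete program follows.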

 # Get the data that will be used below  
 from sklearn import datasets  
 dataset = datasets.load_digits()  
 x = dataset.data  #arrays of pixels (8x8 scans of handwritten digits)  
 y = dataset.target #the actual numbers the person was trying to write  
 # pick 25% of the data for testing, remaining for training  
 from sklearn.model_selection import train_test_split  # named cross_validation in old scikit-learn releases
 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.25)  
 # Reduce the dimensionality with PCA (8x8=64 dimensions down to 20)  
 from sklearn.decomposition import PCA  
 pca = PCA(n_components=20, whiten=True)  
 pca.fit(x_train)  
 tx_train = pca.transform(x_train) # get transformed 20-d x's  
 # Use a support vector machine (classification) to learn from the training data  
 from sklearn import svm  
 svc = svm.SVC(kernel='rbf') # use a radial basis function (RBF) kernel
 svc.fit(tx_train, y_train)  #learn from training data  
 # Check the results  
 tx_test = pca.transform(x_test)  
 print(svc.score(tx_test, y_test)) # output ratio of test points predicted properly  
 # Investigate a digit that was mis-identified, show what it looked like  
 import matplotlib.pyplot as plt  
 def show_digit(x_data, predicted, actual):  
   """ Display the handwritten digit and what it should have been"""  
   digit = x_data.reshape(8, 8)  
   plt.imshow(digit, cmap=plt.cm.gray, interpolation='nearest')  
   plt.axis('off')  
   plt.title('Predicted '+str(predicted)+' actual '+str(actual))  
   plt.show()  
 y_test_pred = svc.predict(tx_test)  
 for i in range(len(x_test)):  
   if y_test_pred[i] != y_test[i]:  
     show_digit(x_test[i], y_test_pred[i], y_test[i])  
     break  
