10. Keras Hands-On Exercise
Building a handwritten digit recognizer
First, import Keras and the related modules:
Code [part 1]:
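import numpy as np
import pandas as pd
import sys, os
# np_utils ships with older Keras releases; newer releases expose
# to_categorical directly from keras.utils
from keras.utils import np_utils
from keras.datasets import mnist
np.random.seed(10)  # Fix NumPy's random seed for reproducibility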
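11. Keras Hands-On Exercise
Building a handwritten digit recognizer
Load and inspect the MNIST data
The output shows that the training data has 60,000 samples and the testing data has 10,000 samples.
Code [part 2]: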
(X_train_image, y_train_label), (X_test_image, y_test_label) = mnist.load_data()
print("\t[Info] train data={:7,}".format(len(X_train_image)))
print("\t[Info] test data={:7,}".format(len(X_test_image)))
print("\t[Info] Shape of train data=%s" % (str(X_train_image.shape)))
print("\t[Info] Shape of train label=%s" % (str(y_train_label.shape)))
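When inspecting the data, it can also help to check how many samples each digit has. The following is a minimal sketch, not part of the original slides: label_counts is a hypothetical helper, and toy_labels is a small stand-in array used so the example runs without downloading MNIST. In practice y_train_label would be passed instead.

```python
import numpy as np

def label_counts(labels, num_classes=10):
    """Count how many samples carry each class label (0..num_classes-1)."""
    return np.bincount(labels, minlength=num_classes)

# Toy stand-in for y_train_label (hypothetical values, not the real dataset)
toy_labels = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=np.uint8)

counts = label_counts(toy_labels)
for digit, count in enumerate(counts):
    # int() converts the NumPy scalar so the "," format spec applies cleanly
    print("\t[Info] digit {}: {:7,} samples".format(digit, int(count)))
```

A roughly even count per digit suggests that plain accuracy is a reasonable metric for the model evaluated later.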
12. Keras Hands-On Exercise
Building a handwritten digit recognizer
Define the plot_image function to display a digit image
Code [part 3]:
def isDisplayAvl():
    # True when the DISPLAY environment variable is set, i.e. plots can be shown
    return 'DISPLAY' in os.environ.keys()

import matplotlib.pyplot as plt

def plot_image(image):
    fig = plt.gcf()
    fig.set_size_inches(2, 2)
    plt.imshow(image, cmap='binary')
    plt.show()
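13. Create the plot_images_labels_predict() function
Later on we want a convenient way to view the digit images together with their labels and predictions, so we define the following function:
Code [part 4]: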
def plot_images_labels_predict(images, labels, prediction, idx, num=10):
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    if num > 25:
        num = 25  # The 5x5 grid holds at most 25 images
    for i in range(0, num):
        ax = plt.subplot(5, 5, 1 + i)
        ax.imshow(images[idx], cmap='binary')
        # Show the label only, or label and prediction when predictions are given
        if len(prediction) > 0:
            title = "l={},p={}".format(str(labels[idx]), str(prediction[idx]))
        else:
            title = "l={}".format(str(labels[idx]))
        ax.set_title(title, fontsize=10)
        ax.set_xticks([]); ax.set_yticks([])
        idx += 1
    plt.show()
19. Keras Hands-On Exercise
Building a handwritten digit recognizer
Evaluate the model's accuracy and make predictions
Code [part 10]:
print("\t[Info] Making prediction to x_Test_norm")
# predict_classes exists on Sequential models in older Keras releases;
# in newer releases use np.argmax(model.predict(x_Test_norm), axis=-1)
prediction = model.predict_classes(x_Test_norm)  # Predict and save the result to prediction
print()
print("\t[Info] Show 10 prediction result (From 240):")
print("%s\n" % (prediction[240:250]))
if isDisplayAvl():
    plot_images_labels_predict(X_test_image, y_test_label, prediction, idx=240)
print("\t[Info] Error analysis:")
for i in range(len(prediction)):
    if prediction[i] != y_test_label[i]:
        print("\tAt index %d: label %d was wrongly predicted as %d!" % (i, y_test_label[i], prediction[i]))
print()
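The error-analysis loop above can also be written with vectorized NumPy operations. The following is a minimal sketch under stated assumptions, not the code from the original slides: find_errors and confusion_matrix are hypothetical helpers, and labels and preds are toy stand-ins for y_test_label and prediction so the example runs without a trained model.

```python
import numpy as np

# Toy stand-ins for y_test_label and prediction (hypothetical values)
labels = np.array([7, 2, 1, 0, 4, 1, 4, 9])
preds = np.array([7, 2, 1, 0, 9, 1, 4, 4])

def find_errors(labels, preds):
    """Return the indices where the predicted class differs from the true label."""
    return np.nonzero(labels != preds)[0]

def confusion_matrix(labels, preds, num_classes=10):
    """Count (true label, predicted label) pairs; rows are true labels."""
    matrix = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when an index pair repeats
    np.add.at(matrix, (labels, preds), 1)
    return matrix

error_idx = find_errors(labels, preds)
accuracy = 1.0 - len(error_idx) / len(labels)
cm = confusion_matrix(labels, preds)

for i in error_idx:
    print("\tAt index %d: label %d was wrongly predicted as %d!" % (i, labels[i], preds[i]))
print("\t[Info] accuracy={:.2f}".format(accuracy))
```

Large off-diagonal entries in cm point to the digit pairs the model confuses most often, which is harder to see from the line-by-line printout.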