Рассказы о математике | страница 37
Рассмотрим практический пример - распознавание выше написанной нейронной сетью рукописного текста. Для тренировки сети воспользуемся имеющейся в открытом доступе базой MNIST, содержащей 60000 черно-белых рукописных изображений цифр размером 28x28пкс.
Выглядят изображения примерно так:
Массив изображений каждой цифры хранится в виде чисел float в диапазоне от 0 до 1, длина массива равна 28х28 = 784 значения. Кстати, при выводе массива в консоль форма цифры вполне узнаваема:
(картинка соответствует числу “5”)
Всего таких чисел, как было сказано, 60000, что вполне хватит для тренировки нашей сети.
Воспользуемся уже готовым классом MLP, изменим лишь некоторые параметры:
- Число входов будет равно 28х28 = 784, т.к. на вход нейросети будет подаваться массив целиком.
- На выходе нейросеть должна давать значения от 0 до 9, однако мы не можем получить значение > 1, так что воспользуемся другим путем: нейросеть будет иметь 10 выходов, где “1” будет соответствовать указанной цифре. Т.е. для числа “0” мы получим [1,0,0,0,0,0,0,0,0,0], для “1” [0,1,0,0,0,0,0,0,0,0], и так далее.
- Число скрытых нейронов мы берем наугад, например 48. Это значение мы позже можем скорректировать по результатам тестирования сети.
В итоге, мы имеем такую структуру сети:
Рассмотрим подробнее программу на языке Python.
Тренировка сети
Создаем нейронную сеть и подготавливаем 2 массива тестовых данных (исходные и целевых значения). Массив “digits” загружается из зараннее подготовленного текстового файла базы MNIST. Для каждой цифры содержится исходный массив в 784 цифры, и правильное значение числа, которое и будет использоваться для тренировки.
mlp = MLP(n_in=28*28, n_hidden=48, n_out=10)
print "1. Загрузка данных"
with open("digitsMnist.txt", 'rb') as fp:
digits = cPickle.load(fp)
print "2. Подготовка данных"
batch = 50
inputs = []
targets = []
for p in range(batch):
data = digits[p]["data"]
resulVal = digits[p]["result"]
# Конвертация целевого числа в массив: “3” => [0,0,0,1,0,0,0,0,0,0]
result_flat = [0.0]*10
result_flat[resulVal] = 1.0
inputs.append(data)
targets.append(result_flat)
Собственно тренировка ничем по сути не отличается от предыдущего примера с функцией XOR. Т.к. процесс идет несколько часов, были также добавлены функции сохранения и чтения состояния сети, чтобы процесс можно было прервать и затем продолжить. Прервать тренировку можно в любой момент нажатием Ctrl+C. Для удобства визуализации процесса выводится суммарная ошибка по всем тестируемым цифрам.