Рассказы о математике | страница 37



Рассмотрим практический пример - распознавание выше написанной нейронной сетью рукописного текста. Для тренировки сети воспользуемся имеющейся в открытом доступе базой MNIST, содержащей 60000 черно-белых рукописных изображений цифр размером 28x28пкс.

Выглядят изображения примерно так:


Массив изображений каждой цифры хранится в виде чисел float в диапазоне от 0 до 1, длина массива равна 28х28 = 784 значения. Кстати, при выводе массива в консоль форма цифры вполне узнаваема:

(картинка соответствует числу “5”)


Всего таких чисел, как было сказано, 60000, что вполне хватит для тренировки нашей сети.

Воспользуемся уже готовым классом MLP, изменим лишь некоторые параметры:

Число входов будет равно 28х28 = 784, т.к. на вход нейросети будет подаваться массив целиком.

На выходе нейросеть должна давать значения от 0 до 9, однако мы не можем получить значение > 1, так что воспользуемся другим путем: нейросеть будет иметь 10 выходов, где “1” будет соответствовать указанной цифре. Т.е. для числа “0” мы получим [1,0,0,0,0,0,0,0,0,0], для “1” [0,1,0,0,0,0,0,0,0,0], и так далее.

Число скрытых нейронов мы берем наугад, например 48. Это значение мы позже можем скорректировать по результатам тестирования сети.


В итоге, мы имеем такую структуру сети:


Рассмотрим подробнее программу на языке Python.

Тренировка сети


Создаем нейронную сеть и подготавливаем 2 массива тестовых данных (исходные и целевых значения). Массив “digits” загружается из зараннее подготовленного текстового файла базы MNIST. Для каждой цифры содержится исходный массив в 784 цифры, и правильное значение числа, которое и будет использоваться для тренировки.


mlp = MLP(n_in=28*28, n_hidden=48, n_out=10)


print "1. Загрузка данных"


with open("digitsMnist.txt", 'rb') as fp:

digits = cPickle.load(fp)


print "2. Подготовка данных"


batch = 50

inputs = []

targets = []

for p in range(batch):

data = digits[p]["data"]

resulVal = digits[p]["result"]

# Конвертация целевого числа в массив: “3” => [0,0,0,1,0,0,0,0,0,0]

result_flat = [0.0]*10

result_flat[resulVal] = 1.0


inputs.append(data)

targets.append(result_flat)


Собственно тренировка ничем по сути не отличается от предыдущего примера с функцией XOR. Т.к. процесс идет несколько часов, были также добавлены функции сохранения и чтения состояния сети, чтобы процесс можно было прервать и затем продолжить. Прервать тренировку можно в любой момент нажатием Ctrl+C. Для удобства визуализации процесса выводится суммарная ошибка по всем тестируемым цифрам.