Сеть Keras никогда не может классифицировать последний класс

Я работаю над своим проектом Deep Learning Language Detection, который представляет собой сеть с этими слоями для распознавания с 16 языков программирования:

введите описание изображения здесь

И это код для создания сети:

# Setting up the model
graph_in = Input(shape=(sequence_length, number_of_quantised_characters))
convs = []
for i in range(0, len(filter_sizes)):
    conv = Conv1D(filters=num_filters,
                  kernel_size=filter_sizes[i],
                  padding='valid',
                  activation='relu',
                  strides=1)(graph_in)
    pool = MaxPooling1D(pool_size=pooling_sizes[i])(conv)
    flatten = Flatten()(pool)
    convs.append(flatten)

if len(filter_sizes)>1:
    out = Concatenate()(convs)
else:
    out = convs[0]

graph = Model(inputs=graph_in, outputs=out)

# main sequential model
model = Sequential()


model.add(Dropout(dropout_prob[0], input_shape=(sequence_length, number_of_quantised_characters)))
model.add(graph)
model.add(Dense(hidden_dims))
model.add(Dropout(dropout_prob[1]))
model.add(Dense(number_of_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])

Таким образом, мой последний языковой класс - это SQL и на этапе тестирования он никогда не может предсказать SQL правильно, и он начисляет 0% на них. Я думал, что это связано с низким качеством образцов SQL (и действительно, они были бедными), поэтому я удалил этот класс и начал обучение по 15 классам. К моему удивлению, теперь F # файлы имели 0% -е обнаружение, а F # был последним классом после удаления SQL (т.е. С одним горячим вектором, где последняя позиция равна 1, а остальное - 0). Теперь, если сеть, которая была обучена 16, использовала против 15, она достигнет очень высокого уровня успеха 98,5%.

Код, который я использую, довольно прост и доступен в основном defs.py и data_helper.py

Вот результат обучения сети с 16 классами, протестированными против 16 классов:

Final result: 14827/16016 (0.925761738262)
xml:        995/1001 (0.994005994006)
fsharp:     974/1001 (0.973026973027)
clojure:        993/1001 (0.992007992008)
java:       996/1001 (0.995004995005)
scala:      990/1001 (0.989010989011)
python:     983/1001 (0.982017982018)
sql:        0/1001 (0.0)
js:     991/1001 (0.99000999001)
cpp:        988/1001 (0.987012987013)
css:        987/1001 (0.986013986014)
csharp:     994/1001 (0.993006993007)
go:     989/1001 (0.988011988012)
php:        998/1001 (0.997002997003)
ruby:       995/1001 (0.994005994006)
powershell:     992/1001 (0.991008991009)
bash:       962/1001 (0.961038961039)

И это результат той же сети (тренировался против 16), которая проходила против 15 классов:

Final result: 14827/15015 (0.987479187479)
xml:        995/1001 (0.994005994006)
fsharp:     974/1001 (0.973026973027)
clojure:        993/1001 (0.992007992008)
java:       996/1001 (0.995004995005)
scala:      990/1001 (0.989010989011)
python:     983/1001 (0.982017982018)
js:     991/1001 (0.99000999001)
cpp:        988/1001 (0.987012987013)
css:        987/1001 (0.986013986014)
csharp:     994/1001 (0.993006993007)
go:     989/1001 (0.988011988012)
php:        998/1001 (0.997002997003)
ruby:       995/1001 (0.994005994006)
powershell:     992/1001 (0.991008991009)
bash:       962/1001 (0.961038961039)

Кто-нибудь еще видел это? Как я могу обойти это?

Ответ 1

TL; DR: Проблема заключается в том, что ваши данные не перетасовываются, прежде чем разделяться на обучающие и валидационные наборы. Поэтому во время обучения все образцы, принадлежащие классу "sql" , находятся в наборе проверки. Ваша модель не научится прогнозировать последний класс, если в этом классе не были предоставлены образцы.


В get_input_and_labels() сначала загружаются файлы для класса 0, а затем класс 1 и т.д. Поскольку вы устанавливаете n_max_files = 2000, это означает, что

  • Первый 2000 (или так, зависит от того, сколько файлов у вас есть) в Y будет иметь класс 0 ( "go" )
  • Следующие 2000 записей будут иметь класс 1 ( "csharp" )
  • ...
  • и, наконец, последние 2000 записей будут иметь последний класс ( "sql" ).

К сожалению, Keras не перетасовывает данные, прежде чем разбивать их на тренировки и проверки. Поскольку в вашем коде validation_split установлено значение 0,1, то последние 3000 выборок (которые содержат все образцы "sql" ) будут в наборе проверки.

Если вы установите validation_split на более высокое значение (например, 0.2), вы увидите больше классов, скорректировавших 0%:

Final result: 12426/16016 (0.7758491508491508)
go:             926/1001 (0.9250749250749251)
csharp:         966/1001 (0.965034965034965)
java:           973/1001 (0.972027972027972)
js:             929/1001 (0.9280719280719281)
cpp:            986/1001 (0.985014985014985)
ruby:           942/1001 (0.9410589410589411)
powershell:             981/1001 (0.98001998001998)
bash:           882/1001 (0.8811188811188811)
php:            977/1001 (0.9760239760239761)
css:            988/1001 (0.987012987012987)
xml:            994/1001 (0.993006993006993)
python:         986/1001 (0.985014985014985)
scala:          896/1001 (0.8951048951048951)
clojure:                0/1001 (0.0)
fsharp:         0/1001 (0.0)
sql:            0/1001 (0.0)

Проблема может быть решена, если вы перетасовываете данные после загрузки. Кажется, что у вас уже есть строки, перетасовывающие данные:

# Shuffle data
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices].argmax(axis=1)

Однако, когда вы подходите к модели, вы передавали исходные x и Y в fit() вместо x_shuffled и y_shuffled. Если вы измените строку на:

model.fit(x_shuffled, y_shuffled, batch_size=batch_size,
          epochs=num_epochs, validation_split=val_split, verbose=1)

Результаты тестирования станут более разумными:

Final result: 15248/16016 (0.952047952047952)
go:             865/1001 (0.8641358641358642)
csharp:         986/1001 (0.985014985014985)
java:           977/1001 (0.9760239760239761)
js:             953/1001 (0.952047952047952)
cpp:            974/1001 (0.973026973026973)
ruby:           985/1001 (0.984015984015984)
powershell:             974/1001 (0.973026973026973)
bash:           942/1001 (0.9410589410589411)
php:            979/1001 (0.978021978021978)
css:            965/1001 (0.964035964035964)
xml:            988/1001 (0.987012987012987)
python:         857/1001 (0.8561438561438561)
scala:          955/1001 (0.954045954045954)
clojure:                985/1001 (0.984015984015984)
fsharp:         950/1001 (0.949050949050949)
sql:            913/1001 (0.9120879120879121)