Я работаю над своим проектом Deep Learning Language Detection, который представляет собой сеть с этими слоями для распознавания с 16 языков программирования:
И это код для создания сети:
# Setting up the model
graph_in = Input(shape=(sequence_length, number_of_quantised_characters))
convs = []
for i in range(0, len(filter_sizes)):
conv = Conv1D(filters=num_filters,
kernel_size=filter_sizes[i],
padding='valid',
activation='relu',
strides=1)(graph_in)
pool = MaxPooling1D(pool_size=pooling_sizes[i])(conv)
flatten = Flatten()(pool)
convs.append(flatten)
if len(filter_sizes)>1:
out = Concatenate()(convs)
else:
out = convs[0]
graph = Model(inputs=graph_in, outputs=out)
# main sequential model
model = Sequential()
model.add(Dropout(dropout_prob[0], input_shape=(sequence_length, number_of_quantised_characters)))
model.add(graph)
model.add(Dense(hidden_dims))
model.add(Dropout(dropout_prob[1]))
model.add(Dense(number_of_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])
Таким образом, мой последний языковой класс - это SQL и на этапе тестирования он никогда не может предсказать SQL правильно, и он начисляет 0% на них. Я думал, что это связано с низким качеством образцов SQL (и действительно, они были бедными), поэтому я удалил этот класс и начал обучение по 15 классам. К моему удивлению, теперь F # файлы имели 0% -е обнаружение, а F # был последним классом после удаления SQL (т.е. С одним горячим вектором, где последняя позиция равна 1, а остальное - 0). Теперь, если сеть, которая была обучена 16, использовала против 15, она достигнет очень высокого уровня успеха 98,5%.
Код, который я использую, довольно прост и доступен в основном defs.py и data_helper.py
Вот результат обучения сети с 16 классами, протестированными против 16 классов:
Final result: 14827/16016 (0.925761738262)
xml: 995/1001 (0.994005994006)
fsharp: 974/1001 (0.973026973027)
clojure: 993/1001 (0.992007992008)
java: 996/1001 (0.995004995005)
scala: 990/1001 (0.989010989011)
python: 983/1001 (0.982017982018)
sql: 0/1001 (0.0)
js: 991/1001 (0.99000999001)
cpp: 988/1001 (0.987012987013)
css: 987/1001 (0.986013986014)
csharp: 994/1001 (0.993006993007)
go: 989/1001 (0.988011988012)
php: 998/1001 (0.997002997003)
ruby: 995/1001 (0.994005994006)
powershell: 992/1001 (0.991008991009)
bash: 962/1001 (0.961038961039)
И это результат той же сети (тренировался против 16), которая проходила против 15 классов:
Final result: 14827/15015 (0.987479187479)
xml: 995/1001 (0.994005994006)
fsharp: 974/1001 (0.973026973027)
clojure: 993/1001 (0.992007992008)
java: 996/1001 (0.995004995005)
scala: 990/1001 (0.989010989011)
python: 983/1001 (0.982017982018)
js: 991/1001 (0.99000999001)
cpp: 988/1001 (0.987012987013)
css: 987/1001 (0.986013986014)
csharp: 994/1001 (0.993006993007)
go: 989/1001 (0.988011988012)
php: 998/1001 (0.997002997003)
ruby: 995/1001 (0.994005994006)
powershell: 992/1001 (0.991008991009)
bash: 962/1001 (0.961038961039)
Кто-нибудь еще видел это? Как я могу обойти это?