Я пытаюсь реализовать FCNN для классификации изображений, которые могут принимать входы переменного размера. Модель построена в Keras с поддержкой TensorFlow.
Рассмотрим следующий пример игрушки:
model = Sequential()
# width and height are None because we want to process images of variable size
# nb_channels is either 1 (grayscale) or 3 (rgb)
model.add(Convolution2D(32, 3, 3, input_shape=(nb_channels, None, None), border_mode='same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(32, 3, 3, border_mode='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(16, 1, 1))
model.add(Activation('relu'))
model.add(Convolution2D(8, 1, 1))
model.add(Activation('relu'))
# reduce the number of dimensions to the number of classes
model.add(Convolution2D(nb_classses, 1, 1))
model.add(Activation('relu'))
# do global pooling to yield one value per class
model.add(GlobalAveragePooling2D())
model.add(Activation('softmax'))
Эта модель работает нормально, но я столкнулся с проблемой производительности. Обучение изображениям переменной величины требует неоправданно долгого времени по сравнению с обучением на входах фиксированного размера. Если я изменю размер всех изображений на максимальный размер в наборе данных, для обучения модели по-прежнему требуется гораздо меньше времени, чем обучение на входе с переменным размером. Так что input_shape=(nb_channels, None, None)
правильный способ указать ввод переменных размера? И есть ли способ смягчить эту проблему производительности?
Обновить
model.summary()
для модели с 3 классами и полутоновыми изображениями:
Layer (type) Output Shape Param # Connected to
====================================================================================================
convolution2d_1 (Convolution2D) (None, 32, None, None 320 convolution2d_input_1[0][0]
____________________________________________________________________________________________________
activation_1 (Activation) (None, 32, None, None 0 convolution2d_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D) (None, 32, None, None 0 activation_1[0][0]
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D) (None, 32, None, None 9248 maxpooling2d_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_2 (MaxPooling2D) (None, 32, None, None 0 convolution2d_2[0][0]
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D) (None, 16, None, None 528 maxpooling2d_2[0][0]
____________________________________________________________________________________________________
activation_2 (Activation) (None, 16, None, None 0 convolution2d_3[0][0]
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D) (None, 8, None, None) 136 activation_2[0][0]
____________________________________________________________________________________________________
activation_3 (Activation) (None, 8, None, None) 0 convolution2d_4[0][0]
____________________________________________________________________________________________________
convolution2d_5 (Convolution2D) (None, 3, None, None) 27 activation_3[0][0]
____________________________________________________________________________________________________
activation_4 (Activation) (None, 3, None, None) 0 convolution2d_5[0][0]
____________________________________________________________________________________________________
globalaveragepooling2d_1 (Global (None, 3) 0 activation_4[0][0]
____________________________________________________________________________________________________
activation_5 (Activation) (None, 3) 0 globalaveragepooling2d_1[0][0]
====================================================================================================
Total params: 10,259
Trainable params: 10,259
Non-trainable params: 0