Я пытаюсь обучить сеть несбалансированными данными. У меня есть A (198 выборок), B (436 выборок), C (710 выборок), D (272 выборки), и я прочитал о "weighted_cross_entropy_with_logits", но все примеры, которые я нашел, предназначены для двоичной классификации, поэтому я не очень уверенный в том, как установить эти веса.
Всего образцов: 1616
A_weight: 198/1616 = 0.12?
Идея, если я понял, наказывает ошибки класса мэрии и более позитивно оценивает удары в меньшинстве, верно?
Мой фрагмент кода:
weights = tf.constant([0.12, 0.26, 0.43, 0.17])
cost = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=pred, targets=y, pos_weight=weights))
Я прочитал этот и другие примеры с бинарной классификацией, но все еще не очень ясен.
Спасибо заранее.