Точность оценки в pyTorch LSTM

Я запускал этот учебник LSTM на wikigold.conll NER набор данных

training_data содержит список кортежей последовательностей и тегов, например:

training_data = [
    ("They also have a song called \" wake up \"".split(), ["O", "O", "O", "O", "O", "O", "I-MISC", "I-MISC", "I-MISC", "I-MISC"]),
    ("Major General John C. Scheidt Jr.".split(), ["O", "O", "I-PER", "I-PER", "I-PER"])
]

И я записал эту функцию

def predict(indices):
    """Gets a list of indices of training_data, and returns a list of predicted lists of tags"""
    for index in indicies:
        inputs = prepare_sequence(training_data[index][0], word_to_ix)
        tag_scores = model(inputs)
        values, target = torch.max(tag_scores, 1)
        yield target

Таким образом, я могу получить предсказанные метки для определенных индексов в данных обучения.

Однако, как я могу оценить показатель точности во всех учебных данных.

Точность: количество слов, правильно классифицированных по всем предложениям, разделенным на количество слов.

Вот что я придумал, что очень медленно и уродливо:

y_pred = list(predict([s for s, t in training_data]))
y_true = [t for s, t in training_data]
c=0
s=0
for i in range(len(training_data)):
    n = len(y_true[i])
    #super ugly and ineffiicient
    s+=(sum(sum(list(y_true[i].view(-1, n) == y_pred[i].view(-1, n).data))))
    c+=n

print ('Training accuracy:{a}'.format(a=float(s)/c))

Как это можно сделать эффективно в pytorch?

P.S: Я пытался безуспешно использовать sklearn precision_score

Ответ 1

Я бы использовал numpy, чтобы не перебирать список в чистом питоне.

Результаты те же, но они работают намного быстрее

def accuracy_score(y_true, y_pred):
    y_pred = np.concatenate(tuple(y_pred))
    y_true = np.concatenate(tuple([[t for t in y] for y in y_true])).reshape(y_pred.shape)
    return (y_true == y_pred).sum() / float(len(y_true))

И вот как это использовать:

#original code:
y_pred = list(predict([s for s, t in training_data]))
y_true = [t for s, t in training_data]
#numpy accuracy score
print(accuracy_score(y_true, y_pred))