Metrics

Top-k 实现代码

NumPy 实现

def get_top_k_result(logits, k=3, sorted=True):
    """Return the k largest scores and their class indices along the last axis.

    Args:
        logits: Score array whose last axis indexes classes, e.g. shape
            ``(batch, num_classes)``. A 1-D array is treated as a single row.
        k: Number of top entries to keep. Values larger than the number of
            classes are clipped to it; ``k <= 0`` yields empty results.
        sorted: If True, entries are ordered from largest to smallest score.
            If False, they keep the ascending order produced by ``np.argsort``.

    Returns:
        A ``(values, indices)`` tuple. Both have the leading shape of
        ``logits`` and ``min(max(k, 0), num_classes)`` entries in the last axis.
    """
    logits = np.asarray(logits)
    num_classes = logits.shape[-1]
    # Slice from an explicit start index: ``[-k:]`` with k == 0 would be
    # ``[-0:]``, which selects every class instead of none.
    start = min(max(num_classes - k, 0), num_classes)
    indices = np.argsort(logits, axis=-1)[..., start:]
    if sorted:
        # np.argsort is ascending; flip so the largest score comes first.
        indices = indices[..., ::-1]
    # Gather the score belonging to each selected index, row by row.
    values = np.take_along_axis(logits, indices, axis=-1)
    return values, indices

# Two samples with four class scores each; y holds the true label per row.
logits = np.array(
    [
        [0.1, 0.3, 0.2, 0.4],
        [0.5, 0.01, 0.9, 0.4],
    ]
)
y = np.array([3, 0])
print(get_top_k_result(logits))

# Output:
# (array([[0.4, 0.3, 0.2],
#        [0.9, 0.5, 0.4]]), array([[3, 1, 2],
#        [2, 0, 3]], dtype=int64))

def calculate_top_k_accuracy(logits, targets, k=2):
    """Return the fraction of rows whose target label is among the top-k scores.

    Args:
        logits: 2-D score array, one row per sample and one column per class.
        targets: Integer class labels, one per row of ``logits``.
        k: Number of highest-scoring classes that count as a hit.

    Returns:
        Top-k accuracy as a NumPy float in [0, 1].
    """
    if k <= 0:
        # An empty top-k set can never contain the target.
        return np.float64(0.0)
    _, indices = get_top_k_result(logits, k=k, sorted=False)
    y = np.reshape(targets, [-1, 1])
    # Count a row as correct if any of its top-k indices equals its label.
    # Averaging per-row hits, rather than using mean(correct) * k, stays in
    # [0, 1] even when k exceeds the number of classes and fewer than k
    # indices come back per row.
    hits = np.any(y == indices, axis=-1)
    return np.mean(hits)

print(calculate_top_k_accuracy(logits, targets=y, k=2))  # expected: 1.0
print(calculate_top_k_accuracy(logits, targets=y, k=1))  # expected: 0.5

TensorFlow 实现

import tensorflow as tf

# Two samples with four class scores each, plus the true label of each row.
logits = tf.constant(
    [[0.1, 0.3, 0.2, 0.4], [0.5, 0.01, 0.9, 0.4]],
    dtype=tf.float32,
    shape=[2, 4],
)
y = tf.constant([3, 0], dtype=tf.int32)

def calculate_top_k_accuracy(logits, targets, k=2):
    """Build a scalar tensor holding the top-k accuracy of ``logits``.

    Args:
        logits: Float score tensor; assumed (batch, num_classes), with one row
            per sample.
        targets: Integer label tensor with one entry per row of ``logits``.
        k: Number of highest-scoring classes that count as a hit.

    Returns:
        A float32 scalar tensor with the fraction of rows whose label is
        among the top-k indices.
    """
    _, indices = tf.math.top_k(logits, k=k, sorted=True)
    # Cast labels to the index dtype so that, e.g., int64 labels can be
    # compared with the indices returned by top_k.
    y = tf.reshape(tf.cast(targets, indices.dtype), [-1, 1])
    # A row counts as correct if any of its top-k indices matches its label.
    hits = tf.reduce_any(tf.equal(y, indices), axis=-1)
    return tf.reduce_mean(tf.cast(hits, tf.float32))

# tf.Session is the TensorFlow 1.x graph-mode API. The `with` block closes the
# session and releases its resources once the results have been printed.
with tf.Session() as sess:
    print(sess.run(calculate_top_k_accuracy(logits, y, k=2)))  # 1.0
    print(sess.run(calculate_top_k_accuracy(logits, y, k=1)))  # 0.5

PyTorch 中的实现


import torch
# Two samples with four class scores each; y holds the true label per row.
logits = torch.tensor(
    [
        [0.1, 0.3, 0.2, 0.4],
        [0.5, 0.01, 0.9, 0.4],
    ]
)
y = torch.tensor([3, 0])
def calculate_top_k_accuracy(logits, targets, k=2):
    """Return the fraction of rows whose target label is among the top-k scores.

    Args:
        logits: Score tensor; assumed (batch, num_classes), with one row per sample.
        targets: Integer label tensor with one entry per row of ``logits``.
        k: Number of highest-scoring classes that count as a hit.

    Returns:
        A 0-dim float tensor holding the top-k accuracy in [0, 1].
    """
    _, indices = torch.topk(logits, k=k, sorted=True)
    y = torch.reshape(targets, [-1, 1])
    # A row counts as correct if any of its top-k indices matches its label;
    # averaging per-row hits is exact, unlike mean(matches) * k.
    hits = (y == indices).any(dim=-1)
    return hits.float().mean()

print(calculate_top_k_accuracy(logits, targets=y, k=2).item())  # expected: 1.0
print(calculate_top_k_accuracy(logits, targets=y, k=1).item())  # expected: 0.5