
Comments (19)

iamstarlee commented on June 27, 2024

Are these results on the test set? The test set has no ground truth, so the validation set has to serve double duty for both validation and testing.

qianmeng19 commented on June 27, 2024

Yes, these are test-set results. I used Train13 as both the training set and the test set with the Xie2019 model, but the results came out as shown in the figure. Shouldn't the correct outputs be 0 or 1?

iamstarlee commented on June 27, 2024

The model's outputs are passed through a sigmoid activation, which maps them into the range 0~1. The values 0 and 1 are the ground truth; a good model's predictions will only get very close to 0 or 1, and will almost never be exactly 0 or 1.

qianmeng19 commented on June 27, 2024

Then what kind of results count as good predictions?

qianmeng19 commented on June 27, 2024

Could you show us your own test results?

iamstarlee commented on June 27, 2024

```python
Ng[k] = np.sum(tmp_targets == 1)
Np[k] = np.sum(tmp_scores >= threshold)  # when >= 0 for the raw input, the sigmoid value will be >= 0.5
Nc[k] = np.sum(tmp_targets * (tmp_scores >= threshold))
```

You first need to run inference with inference.py, then compute the metrics on the inference results with calculate_results.py to judge how good the model is. When metrics.py computes the metrics, scores above the threshold are counted as positive predictions and scores below it as negative.

iamstarlee commented on June 27, 2024

[Screenshot: Snipaste_2024-06-06_09-39-48]

I had already used a threshold to set the results to 0 or 1 before computing the metrics.

qianmeng19 commented on June 27, 2024

Thanks for sharing. How do you use a threshold to set the results to 0 or 1 before computing the metrics? And what value is the threshold usually set to?

iamstarlee commented on June 27, 2024

```python
def evaluation(scores, targets, weights, threshold=0.5):
    assert scores.shape == targets.shape, "The input and targets do not have the same size: Input: {} - Targets: {}".format(scores.shape, targets.shape)

    _, n_class = scores.shape

    # Arrays to hold binary classification information, size n_class + 1 to also hold the implicit normal class
    Nc = np.zeros(n_class + 1)  # Nc = Number of Correct Predictions - True Positives
    Np = np.zeros(n_class + 1)  # Np = Total number of Predictions - True Positives + False Positives
    Ng = np.zeros(n_class + 1)  # Ng = Total number of Ground Truth occurrences

    # False Positives = Np - Nc
    # False Negatives = Ng - Nc
    # True Positives = Nc
    # True Negatives = n_examples - Np - (Ng - Nc)

    # Array to hold the average precision metric, only size n_class, since it is not possible to calculate for the implicit normal class
    ap = np.zeros(n_class)

    for k in range(n_class):
        tmp_scores = scores[:, k]
        tmp_targets = targets[:, k]
        tmp_targets[tmp_targets == -1] = 0  # Necessary if using MultiLabelSoftMarginLoss, instead of BCEWithLogitsLoss

        Ng[k] = np.sum(tmp_targets == 1)
        Np[k] = np.sum(tmp_scores >= threshold)  # when >= 0 for the raw input, the sigmoid value will be >= 0.5
        Nc[k] = np.sum(tmp_targets * (tmp_scores >= threshold))
        ap[k] = average_precision(tmp_scores, tmp_targets)
    # print("the Ng is {}".format(Ng))
```

On line 125 of metrics.py, Np[k] is the total number of predictions for class k; this line simply counts the predictions whose scores are above the threshold (usually 0.5).
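To make the counter comments concrete, here is a toy computation of the derived confusion-matrix quantities (made-up numbers for a single class):

```python
n_examples = 100         # toy number of samples
Nc, Np, Ng = 30, 40, 50  # toy counter values for one class

tp = Nc                           # True Positives  = 30
fp = Np - Nc                      # False Positives = 10
fn = Ng - Nc                      # False Negatives = 20
tn = n_examples - Np - (Ng - Nc)  # True Negatives  = 40
assert tp + fp + fn + tn == n_examples
```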

qianmeng19 commented on June 27, 2024

Thanks for the answer. But, like you, I'd like to apply the evaluation function from metrics.py to the inference in inference.py. How should I modify the code?

iamstarlee commented on June 27, 2024

Running inference.py produces the inference results, and then running calculate_results.py gives you the metrics on those results; it calls metrics.py internally.

qianmeng19 commented on June 27, 2024

Could it happen that inference.py produces poor inference results, but calculate_results.py still reports decent metrics? I'd like to know the exact formulas calculate_results.py uses. Are they in the paper?

qianmeng19 commented on June 27, 2024

Did you train on the entire dataset?

iamstarlee commented on June 27, 2024

calculate_results.py computes its metrics from the output of inference.py, and its formulas are all in metrics.py; they are also given in the appendix of the paper.
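For reference, a minimal sketch of the per-class metrics as they follow from the Nc/Np/Ng counters in the evaluation function above, using the F2 formula that appears in metrics.py (the numbers are made up):

```python
import numpy as np

Nc = np.array([8.0, 3.0])   # true positives per class (toy values)
Np = np.array([10.0, 4.0])  # all positive predictions per class
Ng = np.array([12.0, 3.0])  # ground-truth positives per class

precision_k = Nc / Np
recall_k = Nc / Ng
F2_k = (5 * precision_k * recall_k) / (4 * precision_k + recall_k)
print(precision_k, recall_k, F2_k)
```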

iamstarlee commented on June 27, 2024

I trained on the entire dataset.

qianmeng19 commented on June 27, 2024

Have you tried training with only the single file train13? My results from training on one file don't seem to be any good.

qianmeng19 commented on June 27, 2024

When I run calculate_results.py I hit the warning below. Does it mean there is a problem with the results?

```
E:\Multi-label-Sewer-Classification-main\metrics.py:173: RuntimeWarning: invalid value encountered in scalar divide
  F2_normal = (5 * precision_k[-1] * recall_k[-1])/(4*precision_k[-1] + recall_k[-1])
```

iamstarlee commented on June 27, 2024

That looks like a division by zero. Try printing the precision and recall in the denominator.

iamstarlee commented on June 27, 2024

train13 is a subset I extracted from the original dataset; the poor results are probably because it contains too few samples.
