以安全正确的方式使用RandomForestClassifier的predict_proba()函数。

我正在使用Scikit-learn在我的数据集上应用机器学习算法。有时我需要标签类的概率,而不是标签类本身。我不希望用SpamNot Spam作为邮件的标签,而是希望只用0.78的概率来表示某封邮件是Spam。

为了达到这个目的,我使用了 predict_proba() 与RandomForestClassifier如下。

clf = RandomForestClassifier(n_estimators=10, max_depth=None,
    min_samples_split=1, random_state=0)
scores = cross_val_score(clf, X, y)
print(scores.mean())

classifier = clf.fit(X,y)
predictions = classifier.predict_proba(Xtest)
print(predictions)

我得到了这些结果。

 [ 0.4  0.6]
 [ 0.1  0.9]
 [ 0.2  0.8]
 [ 0.7  0.3]
 [ 0.3  0.7]
 [ 0.3  0.7]
 [ 0.7  0.3]
 [ 0.4  0.6]

其中第二列是类: 垃圾邮件. 然而,我对结果有两个主要问题,我对这些结果没有信心。第一个问题是,结果代表了标签的概率,而没有受到我的数据大小的影响?第二个问题是结果只显示一个数字,在某些情况下,0.701的概率和0.708的概率相差很大,这不是很具体。比如说有什么办法可以得到下一个5位数吗?

解决方案:

  1. 在我的结果中,我得到了超过一个数字,你确定这不是由于你的数据集吗? (例如,使用一个非常小的数据集将产生简单的决策树,所以 “简单 “的概率)。否则可能只是显示一个数字,但是试着打印出 predictions[0,0].

  2. 我不太明白你说的 “概率不受数据大小的影响 “是什么意思。如果您担心的是您不想预测,例如,太多的垃圾邮件,通常做的是使用一个阈值的 t 使你预测1,如果 proba(label==1) > t. 这样你就可以使用阈值来平衡你的预测,例如限制垃圾邮件的全局概率。如果您想在全球范围内分析您的模型,我们通常会计算接收者操作特征(ROC)曲线下的面积(AUC)(见维基百科文章《全球垃圾信息的概率》)。此处). 基本上,ROC曲线是对你的预测的描述,它取决于阈值 t.

希望对大家有所帮助!

给TA打赏
共{{data.count}}人
人已打赏
未分类

如何从.net核心工程中运行 "dotnet myDll.dll"?

2022-9-9 7:53:19

未分类

mapKit苹果。地名和用户没有焦点地图的错误。

2022-9-9 7:53:21

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索