Python中metrics干什么的_python里面mean
chenpack 2025-05-25 13:30 479 浏览 0 评论
作者:PyTorch Lightning team
编译:ronghuaiyang
导读
非常简单实用的PyTorch模型的分布式指标度量库,配合PyTorch Lighting实用更加方便。
找出你需要评估的指标是深度学习的关键。有各种各样的指标,我们可以评估ML算法的性能。TorchMetrics是一个PyTorch度量的实现的集合,是PyTorch Lightning高性能深度学习的框架的一部分。在本文中,我们将介绍如何使用TorchMetrics评估你的深度学习模型,甚至使用一个简单易用的API创建你自己的度量。
什么是TorchMetrics?
TorchMetrics是一个开源的PyTorch原生的函数和度量模块的集合,用于简单的性能评估。你可以使用开箱即用的实现来实现常见的指标,如准确性,召回率,精度,AUROC, RMSE, R2等,或者创建你自己的指标。我们目前支持超过个指标,并不断增加更多的通用任务和特定领域的标准(目标检测,NLP等)。
TorchMetrics最初是作为Pytorch Lightning (PL)的一部分创建的,被设计为分布式硬件兼容,并在默认情况下与DistributedDataParalel(DDP)一起工作。所有指标都在cpu和gpu上经过严格测试。
使用TorchMetrics
安装
这个包可以通过以下方式从PyPI简单安装:
pip install torchmetrics
或者直接从GitHub仓库的源代码安装:
# with git
pip install git+https://github.com/PytorchLightning/metrics.git@master
函数形式的metrics
类似于torch.nn,大多数度量指标都有基于模块和函数的版本。函数版本实现了计算每个度量所需的基本操作。它们是作为输入的简单的python函数。并返回相应的torch.tensor的指标。下面的代码片段展示了一个使用函数接口计算精度的简单示例:
模块形式的metrics
几乎所有函数metrics都有一个对应的基于模块的metrics,该度量将其称为底层的函数等价模块。基于模块的度量的特点是有一个或多个内部度量状态(类似于PyTorch模块的参数),允许它们提供额外的功能:
- 多批次积累
- 多台设备间自动同步
- 度量算法
下面的代码展示了如何使用基于模块的接口:
每次调用度量的forward函数时,我们同时计算当前看到的一批数据上的度量值,并更新内部度量状态,以跟踪到目前为止看到的所有数据。内部状态需要在不同时期之间重置,不应该在训练、验证和测试之间混合。因此我们强烈建议按如下方式重新初始化度量:
Lightning中使用TorchMetrics
下面的例子展示了如何在你的LightningModule中使用metric :
虽然TorchMetrics被构建为与原生的PyTorch一起使用,但TorchMetrics与Lightning一起使用提供了额外的好处:
- 当在LightningModule中正确定义模块metrics 时,模块metrics会自动放置在正确的设备上。这意味着你的数据将始终与你的metrics 放在相同的设备上。
- 在Lightning中支持使用原生的self.log,Lightning会根据on_step 和on_epoch标志来记录metric,如果on_epoch=True,logger 会在epoch结束的时候自动调用.compute()。
- metric 的.reset()方法的度量在一个epoch结束后自动被调用。
Lightning的转换
已经熟悉Lightning的metric接口的用户应该能够轻松地适应TorchMetrics。简单地替换:
from pytorchlightning import metrics
with:
import torchmetrics
注意,在版本之前,metrics将是PyTorchLightning的一部分,但不再接收任何更新。我们强烈建议用户切换到TorchMetrics,以得到我们可能实现的所有的bug修复和增强。
实现自己的metrics
如果你想使用一个还不被支持的指标,你可以使用TorchMetrics的API来实现你自己的自定义指标,只需子类化torchmetrics.Metric并实现以下方法:
- __init__():每个状态变量都应该使用self.add_state(…)调用。
- update():任何需要更新内部度量状态的代码。
- compute():从度量值的状态计算一个最终值。
例子:均方根误差
均方根误差是一个很好的例子,说明了为什么许多度量计算需要划分为两个函数。定义为:
为了正确地计算RMSE,我们需要两个度量状态:sum_squared_error来跟踪目标y和预测y之间的平方误差,以及n_observations来知道我们有多少观测结果。
因为sqrt(a+b) != sqrt(a) + sqrt(b),我们不能把这个度量实现为每个batch计算的RMSE分数的简单平均值,而是需要实现更新步骤中需要在平方根之前发生的所有逻辑,以及在compute步骤中需要实现剩余的逻辑。
为你的模型选择正确的度量
选择正确的度量对于确定你的模型是否按照应该的方式运行,或者是否有什么地方出了问题非常重要。
预测冠状病毒
假设你的任务是建立一个分类网络,可以通过一套非侵入性测量来确定患者是否是冠状病毒阳性。你会得到数千份观察报告,并使用你最喜欢的网络架构,优化以正确识别哪些患者感染了冠状病毒。这种模式可用于确保检测呈阳性的患者被隔离,以避免传播病毒并迅速得到治疗。
为了评估你的模型,你计算了4个指标:准确性、混淆矩阵、精确度和召回率。你得到了以下结果:
准确率: %
混淆矩阵:
精确率:
召回率:
评估得分
你怎么看?这个模型足够好吗?让我们更深入地了解这些指标的含义。在分类中,准确率是指我们的模型得到正确预测的比例。
我们的模型得到了非常高的准确率:%。看来网络正在做你要求它做的事情,你可以准确地检测到患者是否感染了冠状病毒。
对于二元分类,另一个有用的度量是混淆矩阵,这给了我们下面的真、假阳性和阴性的组合。
我们可以从混淆矩阵中快速确定两件事:
- 阴性患者的数量远远少于阳性患者的数量 —> 这意味着你的数据集是高度不平衡的。
- 有5名患者检测失败
从准确性来看,这个模型似乎表现得很好,但考虑到混淆矩阵,我们发现这个模型过于专注于预测阴性患者,而未能预测阳性患者。在这种设置下,它应该清楚正确识别新冠患者和正确识别非新冠患者之间的巨大的区别,正确识别患者将确保患者得到早期治疗,最重要的是隔离,不要传染给别人。
为什么准确率指标没有显示出模型有什么问题?准确率捕获了整体性能,以正确地预测所有类,在这种情况下,我们感兴趣的是捕获我们预测的ground truth的情况有多好。因此,你可以将注意力转向精确率和召回率。
精确率定义为实际正确的正样本的比例。
其中TP和FP分别表示true p positive个数,false positive个数。一个有0个误报的模型的精确率为,而一个模型输出的结果都是阳性,而实际上都是假的模型的精度分数为0。
Recall定义为真实的阳性被正确识别的比例。
其中TP和FN分别表示true positives数,false negatives数。类似地,如果没有错误否定,一个模型的召回分数将为。从定义上我们可以得出结论,精确率聚焦于在不能识别所有假阳性的“成本”上,而召回率聚焦在不能识别所有假阴性的“成本”上。因为我们在这里感兴趣的是假阴性,所以我们应该在recall metric下重新评估我们的模型,现在我们得到了的分数。现在,你已经量化了模型的性能不佳,并且在训练机器学习算法时可能需要处理数据集中存在的巨大类不平衡。
这个小例子展示了选择正确度量来评估机器学习算法的重要性。通常,建议使用一组度量标准来评估算法,因为它们都关注数据和模型预测的不同方面。
—END—
英文原文:
https://pytorch-lightning.medium.com/torchmetrics-pytorch-metrics-built-to-scale-7091b1bec919
相关推荐
- printf使用详解_printfi
-
C语言的调试利器-printf大法,无坚不摧,攻无不破程序输入与输出当我们操作一个linux终端的时候,执行linux命令程序,可以看到命令的输出信息,或者要求输入数据。那么,这些操作就是lin...
- 在java项目中怎么查看maven版本
-
你还不知道java程序maven打包后如何查看jar的编译版本,过来看看有时候我们需要知道一个jar是编译是依赖JAVA哪个版本的需求。如:我们知道该程sjjcb-dev-lyq-example-...
- mysql多行合并成一行_mysql多行拼接
-
多行合并成一行sql函数group_concat和stuff一、MySQL中group_concat函数完整的语法如下:group_concat([DISTINCT]要连接的字段[OrderB...
- pythonfor循环求1!+2!+3!+....+n!的和 python用循环求1到100的和
-
python经典案例:求1到之和问题:求1到之和方法1:利用循环求和#!/usr/bin/python#coding:utf-8#author:菜就多练呀total=0foriinran...
- python批量查询备案号_python 批量查询
-
批量查询ip对应域名、备案信息、百度权重ip2domain-批量查询ip对应域名、备案信息、百度权重本工具二开自https://github.com/sma11new/ip2domain在批量挖S...
- motionbuilder镜像动画 motionbuilder插件
-
如果把《哪吒2》制作全部交给AI《哪吒2》作为现象级动画电影,其成功确实依赖于产业链的高效协作。随着AI技术逐渐渗透动画制作全流程,未来动画产业将呈现"AI全链参与+人类创意主导"的深度融合模式。以下...
- 逻辑运算符、算术运算符、赋值运算符等等归纳
-
运算符(算术、比较、逻辑等)1、算数运算符如:a=,b=+加如:a+b=-减如:a-b=-乘如:a*b=/除如:a/b=%取余如:a/b=**幂x**y返回x的y次方,如...
- 如何生成HTML5页面代码_如何用html5制作一个网页
-
vscode快速生成html代码技巧快速生成Html5骨架在Html文件中输入html:5按下回车键,可快速生成HTML5页面模板:Docu...
- 国家海洋局第二海洋研究所(海洋二所)考研答疑
-
上海海洋大学发布年硕士研究生招生章程,来看详情→近日,上海海洋大学发布年硕士研究生招生章程一起来看看吧!一、学校简介上海海洋大学建校于年,是一所多科性应用研究型大学,是上海市人民政府与国家海洋局、农业...
- vue-awesome-swiper轮播图实现
-
swiper在vue中正确的使用方法swiper是网页中非常强大的一款轮播插件,说是轮播插件都不恰当,因为它能做的事情太多了,swiper在vue下也是能用的,需要依赖专门的vue-swiper插件,...
- yarn下载安装教程_yarn安装axios
-
yarn的安装和使用一、安装在windows下(1)下载node.js,使用npm安装npminstall-gyarn(2)查看版本yarn--versionYarn淘宝源安装&nbs...
- freemodbus 主机源码 freemodbus stm32
-
STM32单片机移植FreeModbus详细过程modbus是一个非常好的串口协议(当然也能用在网口上),它简洁、规范、强大。可以满足大部分的工业、嵌入式需求。这里详细说下如何将freemodbus...
- python if 多条件并列判断_python多个if并列怎么运行
-
pythonif多条件并列判断的三种方法pythonif多条件并列判断的三种方法如果使用python的if进行多个条件表达式的判断呢?下面介绍三种方法:使用and或or来连接多个条件表达式,比如条...
- driver memory和executor memory怎么设置 memory odd ratio怎么设置
-
RocketMQ原理—2.源码设计简单分析一大纲1.NameServer的启动脚本2.NameServer启动时会解析哪些配置3.NameServer如何初始化Netty网络服务器4.NameServ...
- FPGA编程架构_fpga的编程语言是什么
-
深入浅出带你了解FPGA架构数字集成电路有两种类型:ASIC和FPGA(现场可编程门阵列)。专用集成电路(ASIC)有一个预先定义的特定硬件功能,在生产后不能重新编程。但FPGA可以在制造后可无限编程...
你 发表评论:
欢迎- 一周热门
-
-
维基百科Wikipedia镜像网站列表
-
超炫html+css+javascript幻化3D相册 (含背景音乐)程序员表白必备
-
不能读取文件“itunes.library.itl”因为它是由更高级别的itunes所创建的
-
6款图片查看器,丝滑干净无广告!(图片查看器软件)
-
用java编写一个QQ群发信息_用java语言写qq聊天程序
-
StreamReader StringReader 区别 reader和inputstream的区别
-
Windows Server 2003 详细安装与配置
-
作为一名独立开发者,我是如何建立我的科技创业公司的
-
计算机集成制造系统有哪些_计算机集成制造系统有哪些类型
-
虚拟化测评 虚拟化测试方案_虚拟化测试工程师招聘
-
- 最近发表
- 标签列表
-
- int.tryparse (62)
- list转list (108)
- repeat函数 (66)
- git force (69)
- springboot /error (71)
- mysql 更新 (74)
- save as pdf (63)
- lock tables (66)
- 同步 异步 阻塞 非阻塞 (62)
- rsyslog (66)
- querystring (63)
- c++ override (70)
- css 动画库 (61)
- vsphere web client (65)
- int32_t (63)
- c# task.run (68)
- find -size (64)
- golang flag包 (70)
- 二维数组作为参数传入函数 (62)
- sudo su root (60)
- crontab 安装 (61)
- c# 数组转成list (60)
- 下拉按钮 (64)
- 滚动条美化 (61)
- stringutils (61)