本文简要介绍python语言中 torchvision.models.detection.keypointrcnn_resnet50_fpn
的用法。
用法:
torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=2, num_keypoints=17, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs)
pretrained(bool) -如果为 True,则返回在 COCO train2017 上预训练的模型
progress(bool) -如果为 True,则显示下载到 stderr 的进度条
num_classes(int) -模型的输出类数(包括背景)
num_keypoints(int) -关键点数,默认 17
pretrained_backbone(bool) -如果为 True,则返回一个在 Imagenet 上预训练过主干的模型
trainable_backbone_layers(int) -从最终块开始的可训练(未冻结)resnet 层数。有效值介于 0 和 5 之间,其中 5 表示所有主干层都是可训练的。
使用 ResNet-50-FPN 主干构造 Keypoint R-CNN 模型。
参考:“Mask R-CNN”。
模型的输入应该是一个张量列表,每个形状为
[C, H, W]
,每个图像一个,并且应该在0-1
范围内。不同的图像可以有不同的尺寸。模型的行为取决于它是处于训练模式还是评估模式。
在训练期间,模型需要输入张量以及目标(字典列表),其中包含:
框 (
FloatTensor[N, 4]
):[x1, y1, x2, y2]
格式的 ground-truth 框,包含0 <= x1 < x2 <= W
和0 <= y1 < y2 <= H
。labels (
Int64Tensor[N]
):每个ground-truth框的类标签关键点 (
FloatTensor[N, K, 3]
):每个N
实例的K
关键点位置,格式为[x, y, visibility]
,其中visibility=0
表示关键点不可见。
该模型在训练期间返回
Dict[Tensor]
,包含 RPN 和 R-CNN 的分类和回归损失,以及关键点损失。在推理过程中,模型只需要输入张量,并将后处理的预测作为
List[Dict[Tensor]]
返回,每个输入图像一个。Dict
的字段如下,其中N
是检测到的实例数:框 (
FloatTensor[N, 4]
):[x1, y1, x2, y2]
格式的预测框,包含0 <= x1 < x2 <= W
和0 <= y1 < y2 <= H
。labels (
Int64Tensor[N]
):每个实例的预测标签分数 (
Tensor[N]
):每个实例的分数关键点 (
FloatTensor[N, K, 3]
):预测关键点的位置,采用[x, y, v]
格式。
有关输出的更多详细信息,您可以参考实例分割模型。
关键点 R-CNN 可导出到 ONNX 以用于固定批量大小,输入图像大小固定。
例子:
>>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) >>> >>> # optionally, if you want to export the model to ONNX: >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
使用
keypointrcnn_resnet50_fpn
的示例:
参数:
相关用法
- Python PyTorch kron用法及代码示例
- Python PyTorch kaiming_normal_用法及代码示例
- Python PyTorch kthvalue用法及代码示例
- Python PyTorch kaiming_uniform_用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
- Python PyTorch cholesky用法及代码示例
- Python PyTorch vdot用法及代码示例
- Python PyTorch ELU用法及代码示例
- Python PyTorch ScaledDotProduct.__init__用法及代码示例
- Python PyTorch gumbel_softmax用法及代码示例
- Python PyTorch get_tokenizer用法及代码示例
- Python PyTorch saved_tensors_hooks用法及代码示例
- Python PyTorch positive用法及代码示例
- Python PyTorch renorm用法及代码示例
- Python PyTorch AvgPool2d用法及代码示例
- Python PyTorch MaxUnpool3d用法及代码示例
- Python PyTorch Bernoulli用法及代码示例
- Python PyTorch Tensor.unflatten用法及代码示例
- Python PyTorch Sigmoid用法及代码示例
- Python PyTorch Tensor.register_hook用法及代码示例
- Python PyTorch ShardedEmbeddingBagCollection.named_parameters用法及代码示例
- Python PyTorch sqrt用法及代码示例
- Python PyTorch PackageImporter.id用法及代码示例
- Python PyTorch column_stack用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchvision.models.detection.keypointrcnn_resnet50_fpn。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。