8 多输出预测与多标签分类pytorch网络搭建

news/2024/8/25 6:51:52 标签: 分类, pytorch, 数据挖掘

文章目录

  • 前言
  • 一、多输出预测(回归)
    • 1 坐标数据生成
    • 2 网络搭建训练预测
  • 二、多标签分类
    • 1 多标签数据生成
    • 2 网络搭建训练
  • 总结


前言

前面我们搭建的无论是分类还是回归都只能预测一个标签,这显然效果很局限。下面我们想做到下面这两种效果:

  • 多输出预测(回归):例如训练网络拟合北东天坐标转机体坐标的关系,输入是三坐标,输出也是三坐标
  • 多标签分类:例如,输入图像数据,训练网络判断图片里面有猫,有狗,还是只有其中一种这样

【注】:在介绍pytorch的内置损失函数博客中已经介绍了pytorch的损失函数是支持这个功能的。

一、多输出预测(回归)

1 坐标数据生成

# 本示例演示如何使用 PyTorch 实现多标签回归模型。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 构建数据集
# 假设您有一些经纬高度和对应的地心地固坐标的数据
# 这里只是一个示例,您需要根据实际情况准备您自己的数据集
X = np.random.rand(100, 3)  # 100个样本,每个样本有3个特征(经度、纬度、高度)
y = np.random.rand(100, 3)  # 每个样本有3个目标值(地心地固坐标)
print('y:\n',y)

在这里插入图片描述

2 网络搭建训练预测

# 转换数据为 PyTorch 的 Tensor 类型
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# 定义模型
class MultiLabelRegressionModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(MultiLabelRegressionModel, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        
    def forward(self, x):
        out = self.fc(x)
        return out

# 初始化模型
input_size = 3   # 输入特征的数量
output_size = 

http://www.niftyadmin.cn/n/5557326.html

相关文章

150个pb网站模板(都是成品网站,上传php空间即可使用),建站必备

一网友提供的150个pb网站模板,其实就是成品网站,上传php空间即可使用,属于建站公司或者建站开发人员必备的资源。 一共150个基于pb的成品网站,基本上都可以找到适应你手头客户需要的一款,简单修改一下即可交活收钱了。…

【Linux】安装PHP扩展-igbinary

说明 本文档是在centos7.6的环境下,安装PHP7.4之后,安装对应的PHP扩展igbinary。 一、igbinary简述 igbinary 是一个 PHP 扩展,主要用于序列化和反序列化数据,其设计目的是为了提高序列化过程中的性能和内存效率。 优点&#…

LeetCode第257题:二叉树的所有路径的Java实现

摘要 LeetCode第257题要求生成二叉树的所有从根节点到叶子节点的路径。本文将介绍两种Java解决方案:迭代法和递归法。 1. 问题描述 给定一个二叉树的根节点,按照从根到叶的顺序遍历所有路径,并将它们作为列表的列表返回。 2. 示例分析 输…

【Qt】之【Bug】MaintenanceTool qt安装组件 无法下载存档

解决 参考:qt更新组件时,提示无法下载存档 进入MaintenanceTool.exe所在目录,使用命令行,镜像源打开程序,进行更新或添加组件 .\MaintenanceTool.exe --mirror https://mirrors.cloud.tencent.com/qt/顺利

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(九)-无人机服务区分离

引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…

【人工智能】-- 迁移学习

个人主页:欢迎来到 Papicatch的博客 课设专栏 :学生成绩管理系统 专业知识专栏: 专业知识 文章目录 🍉引言 🍉迁移学习 🍈基本概念 🍍定义 🍌归纳迁移学习(Induct…

【qt】正则表达式来判断是否为邮箱登录

正则表达式是用来匹配字符串的神器. 在Qt中我们需要使用到QRegExp这个类 用exactMatch来进行匹配. [] 使用方括号 [] 来定义字符类,表示匹配方括号内的任意一个字符 A-Za-z0-9是字符的匹配范围. 是用于指定字符或字符类出现的次数,常见的如下 *(匹配 0…

感应灯光画纯电路开源版本

前言 之前那版灯光画用的从垃圾佬淘的电路板拼出来的,功能不全,显示效果不太好而且无法固定到相框上,这次改版用的嘉立创smt,贴了5片板子(19元),功能上的改进是加了无极触摸调光、添加了黄白两…