关于PyTorch加载数据的三种方式:cpu init、cpu getitem、gpu init

news/2024/7/8 5:01:54 标签: pytorch, 人工智能, python

文章目录

    • 一. (cpu,init)在 __init__ 函数中将全部图像加载到CPU内存,然后在 __getitem__ 中取图像
    • 二. (cpu,get_item)在 __getitem__ 函数中按需将图像载入CPU内存
    • 三. (gpu,init)在 __init__ 函数中将图像加载到GPU

跑多光谱估计的代码,参考:https://github.com/caiyuanhao1998/MST-plus-plus
原代码的Dataset在初始化时一次性将所有图像加载到CPU内存中。

一. (cpu,init)在 __init__ 函数中将全部图像加载到CPU内存,然后在 __getitem__ 中取图像

这种方法比较常用,读取数据的效率也高,但要求CPU内存足够容纳全部图像。

import random
import time

import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
class TrainDataset(Dataset):
    """Patch dataset that preloads every (RGB, hyperspectral) scene into CPU memory.

    All scenes listed in ``{data_root}/split_txt/train_list.txt`` are read once
    in ``__init__``; ``__getitem__`` only slices a ``crop_size`` x ``crop_size``
    patch out of the cached arrays. This is fast, but host RAM must hold the
    whole training set.

    Args:
        data_root: Root folder containing ``Train_spectral/``, ``Train_RGB/``
            and ``split_txt/train_list.txt`` (one scene name per line).
        crop_size: Side length of the square training patches.
        arg: If true, apply the same random rotation/flips to both arrays.
        bgr2rgb: If true, convert the OpenCV BGR image to RGB channel order.
        stride: Step in pixels between neighbouring patches.
        img_shape: ``(height, width)`` shared by every scene; only used to
            count patches. Defaults to ``(482, 512)``.
    """

    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8,
                 img_shape=(482, 512)):
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        h, w = img_shape
        self.stride = stride
        self.patch_per_line = (w - crop_size) // stride + 1
        self.patch_per_colum = (h - crop_size) // stride + 1
        self.patch_per_img = self.patch_per_line * self.patch_per_colum

        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'

        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            names = self._scene_names(fin)
        # Both file names are derived from the same sorted scene list, so the
        # RGB/spectral pairs cannot drift apart (sorting the two lists
        # separately could order them differently).
        hyper_list = [f'{name}.mat' for name in names]
        bgr_list = [f'{name}.jpg' for name in names]

        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i, (hyper_name, bgr_name) in enumerate(zip(hyper_list, bgr_list)):
            with h5py.File(hyper_data_path + hyper_name, 'r') as mat:
                hyper = np.float32(np.array(mat['cube']))
            # Swap axes 1 and 2 so the cube is laid out like the RGB array
            # below, i.e. (C, H, W).
            hyper = np.transpose(hyper, [0, 2, 1])

            bgr_path = bgr_data_path + bgr_name
            bgr = cv2.imread(bgr_path)
            if bgr is None:  # cv2.imread signals a missing/unreadable file with None
                raise FileNotFoundError(f'Cannot read RGB image: {bgr_path}')
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            # Min-max normalise the whole image to [0, 1].
            bgr = (bgr - bgr.min()) / (bgr.max() - bgr.min())
            bgr = np.transpose(bgr, [2, 0, 1])  # (3, H, W)
            self.hypers.append(hyper)
            self.bgrs.append(bgr)
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    @staticmethod
    def _scene_names(lines):
        """Return the sorted scene names listed one per line in a split file.

        Surrounding whitespace (newlines, CRLF endings, spaces) is stripped and
        blank lines are ignored. A missing trailing newline therefore no
        longer corrupts the last entry.
        """
        return sorted(name for name in (line.strip() for line in lines) if name)

    def arguement(self, img, rotTimes, vFlip, hFlip):
        """Rotate/flip a (C, H, W) array; call with identical params for both arrays.

        Args:
            img: Array whose axes 1 and 2 are the spatial axes.
            rotTimes: Number of 90-degree rotations in the (axis 1, axis 2) plane.
            vFlip: Number of reversals of axis 2.
            hFlip: Number of reversals of axis 1.

        Returns:
            The transformed array; a fresh copy whenever any op was applied.
        """
        for _ in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        for _ in range(vFlip):
            img = img[:, :, ::-1].copy()
        for _ in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):
        """Return the ``idx``-th (rgb, hyper) patch pair as contiguous arrays."""
        stride, crop_size = self.stride, self.crop_size
        img_idx, patch_idx = divmod(idx, self.patch_per_img)
        h_idx, w_idx = divmod(patch_idx, self.patch_per_line)
        top, left = h_idx * stride, w_idx * stride
        bgr = self.bgrs[img_idx][:, top:top + crop_size, left:left + crop_size]
        hyper = self.hypers[img_idx][:, top:top + crop_size, left:left + crop_size]
        # Parameters are drawn even when augmentation is off, so the global
        # `random` stream advances the same way in both cases.
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr = self.arguement(bgr, rotTimes, vFlip, hFlip)
            hyper = self.arguement(hyper, rotTimes, vFlip, hFlip)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        """Total number of patches over all loaded scenes."""
        return self.patch_per_img * self.img_num

二. (cpu,get_item)在 __getitem__ 函数中按需将图像载入CPU内存

这种方法可以处理大数据集:当所有图像占用的内存超过电脑内存时,可以使用这种方法。
但是由于读取图像放在了 __getitem__ 中,每取一个patch都要重新从磁盘读取图像,训练时加载数据会比较慢。

class TrainDataset_single(Dataset):
    """Patch dataset that reads each scene from disk lazily in ``__getitem__``.

    Only the file paths are kept in memory, so this works for datasets larger
    than host RAM, at the cost of disk I/O for every sample.

    Args:
        data_root: Root folder containing ``Train_spectral/``, ``Train_RGB/``
            and ``split_txt/train_list.txt`` (one scene name per line).
        crop_size: Side length of the square training patches.
        arg: If true, apply the same random rotation/flips to both arrays.
        bgr2rgb: If true, convert the OpenCV BGR image to RGB channel order.
        stride: Step in pixels between neighbouring patches.
        img_shape: ``(height, width)`` shared by every scene; only used to
            count patches. Defaults to ``(482, 512)``.
    """

    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8,
                 img_shape=(482, 512)):
        self.crop_size = crop_size
        self.hypers = []  # paths of the .mat spectral cubes
        self.bgrs = []    # paths of the matching RGB images (same order)
        self.arg = arg
        self.bgr2rgb = bgr2rgb
        h, w = img_shape
        self.stride = stride
        self.patch_per_line = (w - crop_size) // stride + 1
        self.patch_per_colum = (h - crop_size) // stride + 1
        self.patch_per_img = self.patch_per_line * self.patch_per_colum

        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'

        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            names = self._scene_names(fin)

        print(f'len(hyper) of ntire2022 dataset:{len(names)}')
        print(f'len(bgr) of ntire2022 dataset:{len(names)}')
        # Both paths come from the same scene name, so the pairs always match.
        for i, name in enumerate(names):
            self.hypers.append(f'{hyper_data_path}{name}.mat')
            self.bgrs.append(f'{bgr_data_path}{name}.jpg')
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    @staticmethod
    def _scene_names(lines):
        """Return the sorted scene names listed one per line in a split file.

        Surrounding whitespace (newlines, CRLF endings, spaces) is stripped and
        blank lines are ignored. A missing trailing newline therefore no
        longer corrupts the last entry.
        """
        return sorted(name for name in (line.strip() for line in lines) if name)

    def arguement(self, img, rotTimes, vFlip, hFlip):
        """Rotate/flip a (C, H, W) array; call with identical params for both arrays.

        Args:
            img: Array whose axes 1 and 2 are the spatial axes.
            rotTimes: Number of 90-degree rotations in the (axis 1, axis 2) plane.
            vFlip: Number of reversals of axis 2.
            hFlip: Number of reversals of axis 1.

        Returns:
            The transformed array; a fresh copy whenever any op was applied.
        """
        for _ in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        for _ in range(vFlip):
            img = img[:, :, ::-1].copy()
        for _ in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):
        """Read scene ``idx // patch_per_img`` from disk and return one patch pair."""
        stride, crop_size = self.stride, self.crop_size
        img_idx, patch_idx = divmod(idx, self.patch_per_img)
        h_idx, w_idx = divmod(patch_idx, self.patch_per_line)
        top, left = h_idx * stride, w_idx * stride
        hyper_path = self.hypers[img_idx]
        bgr_path = self.bgrs[img_idx]

        # Read only the needed hyperslab instead of the whole cube. On disk,
        # axes 1 and 2 are swapped relative to the returned (C, H, W) layout
        # (hence the [0, 2, 1] transpose), so the W range is given first.
        with h5py.File(hyper_path, 'r') as mat:
            hyper = np.float32(mat['cube'][:, left:left + crop_size, top:top + crop_size])
        hyper = np.transpose(hyper, [0, 2, 1])

        bgr = cv2.imread(bgr_path)
        if bgr is None:  # cv2.imread signals a missing/unreadable file with None
            raise FileNotFoundError(f'Cannot read RGB image: {bgr_path}')
        if self.bgr2rgb:
            bgr = bgr[..., ::-1]  # BGR -> RGB, same result as cv2.cvtColor
        bgr = np.float32(bgr)
        # Normalise with the min/max of the whole image (not of the patch),
        # so patches match those produced by the preloading dataset.
        bgr = (bgr - bgr.min()) / (bgr.max() - bgr.min())
        bgr = np.transpose(bgr, [2, 0, 1])[:, top:top + crop_size, left:left + crop_size]

        # Parameters are drawn even when augmentation is off, so the global
        # `random` stream advances the same way in both cases.
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr = self.arguement(bgr, rotTimes, vFlip, hFlip)
            hyper = self.arguement(hyper, rotTimes, vFlip, hFlip)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        """Total number of patches over all listed scenes."""
        return self.patch_per_img * self.img_num

三. (gpu,init)在 __init__ 函数中将图像加载到GPU

这种方法适用于CPU内存不够、无法使用方法一,又不想像方法二那样加载速度太慢的情况。
如果GPU显存比较大的时候,或者有多个GPU的时候,可以在init函数中将图像读取到若干个GPU中。

比如下面,将450张读取到gpu0, 另外450张读取到gpu1
这样TrainDataset_gpu[i] 返回的就是在gpu上的数据

class TrainDataset_gpu(Dataset):
    """Patch dataset whose scenes are preloaded onto one or more GPUs.

    Scenes are read in ``__init__`` and placed on ``devices`` in order, at most
    ``images_per_device`` scenes per device. Scenes beyond the total capacity
    are not loaded (a message reports how many). ``__getitem__`` returns
    ``torch`` tensors that stay on the device holding their scene.

    Because samples may live on different GPUs, this dataset cannot be used
    with a standard ``DataLoader``; batches have to be assembled manually.

    Args:
        data_root: Root folder containing ``Train_spectral/``, ``Train_RGB/``
            and ``split_txt/train_list.txt`` (one scene name per line).
        crop_size: Side length of the square training patches.
        arg: If true, apply the same random rotation/flips to both tensors.
        bgr2rgb: If true, convert the OpenCV BGR image to RGB channel order.
        stride: Step in pixels between neighbouring patches.
        img_shape: ``(height, width)`` shared by every scene; only used to
            count patches. Defaults to ``(482, 512)``.
        devices: Devices to fill, in order. Defaults to ``('cuda:0', 'cuda:1')``.
        images_per_device: Maximum number of scenes placed on each device.
    """

    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8,
                 img_shape=(482, 512), devices=('cuda:0', 'cuda:1'),
                 images_per_device=450):
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        self.bgr2rgb = bgr2rgb
        h, w = img_shape
        self.stride = stride
        self.patch_per_line = (w - crop_size) // stride + 1
        self.patch_per_colum = (h - crop_size) // stride + 1
        self.patch_per_img = self.patch_per_line * self.patch_per_colum

        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'

        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            names = self._scene_names(fin)
        print(f'len(hyper) of ntire2022 dataset:{len(names)}')
        print(f'len(bgr) of ntire2022 dataset:{len(names)}')

        devices = [torch.device(d) for d in devices]
        capacity = images_per_device * len(devices)
        for i, name in enumerate(names):
            if i >= capacity:
                # Do not read scenes that have no device to live on; report
                # them instead of dropping them silently.
                print(f'{len(names) - capacity} scene(s) skipped: all devices are full.')
                break
            device = devices[i // images_per_device]

            with h5py.File(f'{hyper_data_path}{name}.mat', 'r') as mat:
                hyper = np.float32(np.array(mat['cube']))
            # Swap axes 1 and 2 so the cube is laid out like the RGB array
            # below, i.e. (C, H, W).
            hyper = np.transpose(hyper, [0, 2, 1])

            bgr_path = f'{bgr_data_path}{name}.jpg'
            bgr = cv2.imread(bgr_path)
            if bgr is None:  # cv2.imread signals a missing/unreadable file with None
                raise FileNotFoundError(f'Cannot read RGB image: {bgr_path}')
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            # Min-max normalise the whole image to [0, 1].
            bgr = (bgr - bgr.min()) / (bgr.max() - bgr.min())
            bgr = np.transpose(bgr, [2, 0, 1])  # (3, H, W)

            self.hypers.append(torch.from_numpy(hyper).to(device))
            self.bgrs.append(torch.from_numpy(bgr).to(device))
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    @staticmethod
    def _scene_names(lines):
        """Return the sorted scene names listed one per line in a split file.

        Surrounding whitespace (newlines, CRLF endings, spaces) is stripped and
        blank lines are ignored. A missing trailing newline therefore no
        longer corrupts the last entry.
        """
        return sorted(name for name in (line.strip() for line in lines) if name)

    def arguement(self, img, hyper, rotTimes, vFlip, hFlip):
        """Apply the same rotation/flips to the RGB and spectral tensors.

        The flip axes match the NumPy datasets' ``arguement``: ``vFlip``
        reverses dim 2 and ``hFlip`` reverses dim 1.

        Args:
            img: RGB tensor whose dims 1 and 2 are spatial.
            hyper: Spectral tensor with the same spatial dims.
            rotTimes: Number of 90-degree rotations in the (dim 1, dim 2) plane.
            vFlip: Truthy to reverse dim 2.
            hFlip: Truthy to reverse dim 1.

        Returns:
            ``(img, hyper)`` after the transforms.
        """
        if rotTimes:
            img = torch.rot90(img, rotTimes, [1, 2])
            hyper = torch.rot90(hyper, rotTimes, [1, 2])
        if vFlip:
            img = torch.flip(img, dims=[2])
            hyper = torch.flip(hyper, dims=[2])
        if hFlip:
            img = torch.flip(img, dims=[1])
            hyper = torch.flip(hyper, dims=[1])
        return img, hyper

    def __getitem__(self, idx):
        """Return the ``idx``-th (rgb, hyper) patch pair as tensors on their scene's device."""
        stride, crop_size = self.stride, self.crop_size
        img_idx, patch_idx = divmod(idx, self.patch_per_img)
        h_idx, w_idx = divmod(patch_idx, self.patch_per_line)
        top, left = h_idx * stride, w_idx * stride
        bgr = self.bgrs[img_idx][:, top:top + crop_size, left:left + crop_size]
        hyper = self.hypers[img_idx][:, top:top + crop_size, left:left + crop_size]
        # Parameters are drawn even when augmentation is off, so the global
        # `random` stream advances the same way in both cases.
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr, hyper = self.arguement(bgr, hyper, rotTimes, vFlip, hFlip)
        return bgr, hyper

    def __len__(self):
        """Total number of patches over all scenes that were loaded."""
        return self.patch_per_img * self.img_num

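但是数据读取到GPU之后,训练时不能直接使用DataLoader,否则容易报错。原因是:默认的collate函数会用torch.stack把样本拼成一个batch,而分布在不同GPU上的张量不能直接拼接;另外当num_workers>0时,工作子进程通常也无法使用父进程中已初始化的CUDA。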

这时可以自己实现批处理和shuffle的逻辑,例如:

# 1. Build the dataset (scenes are preloaded onto the GPUs).
#    NOTE(review): `opt` is presumably the training script's argparse
#    namespace; it is not defined in this snippet.
train_data = TrainDataset_gpu(data_root=opt.data_root, crop_size=opt.patch_size, bgr2rgb=True, arg=True, stride=opt.stride)

# 2. Shuffle all patch indices once and drop the tail so the number of samples
#    is a multiple of batch_size; each row of `batches` is one batch.
num_samples = len(train_data)
remainder = num_samples % opt.batch_size
usable = num_samples - remainder
batches = np.random.permutation(num_samples)[:usable].reshape(-1, opt.batch_size)  # (num_batches, batch_size)
print(num_samples, remainder, batches.shape)

# 3. Assemble every batch on the training GPU.
target_device = torch.device('cuda')  # current CUDA device, same as .cuda()
for i, batch_indices in enumerate(batches):
    t0 = time.time()
    samples = [train_data[j] for j in batch_indices]
    # Samples may live on different GPUs. Copy each one straight to the
    # training device (device-to-device) instead of bouncing it through host
    # memory, then stack along a new batch dimension.
    images = torch.stack([image.to(target_device) for image, _ in samples])
    labels = torch.stack([label.to(target_device) for _, label in samples])

http://www.niftyadmin.cn/n/5536490.html

相关文章

QT+OpenCV在Android上实现人脸实时检测与目标检测

一、功能介绍 在当今的移动应用领域，随着技术的飞速发展和智能设备的普及，将先进的计算机视觉技术集成到移动平台，特别是Android系统中，已成为提升用户体验、拓展应用功能的关键。其中，目标检测与人脸识别作为计算机视…

内存生产全速推进:产能逼近峰值,超越成熟节点晶圆厂

随着内存价格和需求的增长，内存制造商南亚科技和华邦电子已经恢复了正常生产，不再像去年那样减产。根据自由时报网络引述集邦咨询和业内消息来源的报告，内存出货量在第三季度将持续复苏。 据报道，内存制造商的产能利用率已达到90…

推荐 2个功能强大的黑科技工具,真的会让你直呼卧槽

Waifu2X Waifu2x 是一个基于深度学习的开源项目，主要用于处理二次元动漫风格的图像。它使用卷积神经网络（CNN）进行超分辨率处理和降噪，能够将图像放大2倍或更多，同时显著提高清晰度和减少噪声。Waifu2x 特别针对日系漫…

JAVA案例模拟电影信息系统

一案例要求： 二具体代码(需要在同一个包下创建三个类) Ⅰ：实现类 package 重修;import java.util.Random; import java.util.Scanner;public class first {public static void main(String[] args) {javabean[]moviesnew javabean[4];movies[0] new ja…

【算法】(C语言):冒泡排序、选择排序、插入排序

冒泡排序 从第一个数据开始到第n-1个数据，依次和后面一个数据两两比较，数值小的在前。最终，最后一个数据（第n个数据）为最大值。从第一个数据开始到第n-2个数据，依次和后面一个数据两两比较，数值…

ubuntu20.04换源

一、概述 重新在联想电脑上安装ubuntu20.04系统后，在安装ROS过程中，出现了不少问题，其中在使用下面命令时候，发现如下问题。 sudo apt-get update 使用update更新当前所安装软件版本时候，发现报出错误，无法…

Vue组件化、单文件组件以及使用vue-cli(脚手架)

文章目录 1.Vue组件化1.1 什么是组件1.2 组件的使用1.3 组件的名字1.4 嵌套组件 2.单文件组件2.1 vue 组件组成结构2.1.1 template -> 组件的模板结构2.1.2 组件的 script 节点2.1.3 组件的 style 节点 2.2 Vue组件的使用步骤2.2.1 组件之间的父子关系2.2.2 使用组件的三个步…

倘若你的B端系统如此漂亮,还担心拿不出手吗,尤其是面对客户

如果你的B端系统设计如此漂亮，那么通常来说，你不太需要担心在客户那里拿不出手。一个漂亮和易用的设计可以提升用户体验，增加客户对系统的满意度。 然而，还是有一些因素需要考虑，以确保你的B端系统在客户那里能够得到良…