数据处理

版本1
#数据处理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np
#定义自己的数据集合
class DogCat(data.Dataset):
  def __init__(self,root):
    #所有图片的绝对路径
    imgs=os.listdir(root)
    self.imgs=[os.path.join(root,k) for k in imgs]
  def __getitem__(self, index):
    img_path=self.imgs[index]
    #dog-> 1 cat ->0
    label=1 if 'dog' in img_path.split('/')[-1] else 0
    pil_img=Image.open(img_path)
    array=np.asarray(pil_img)
    data=torch.from_numpy(array)
    return data,label
  def __len__(self):
    return len(self.imgs)
dataSet=DogCat('./data/dogcat')
print(dataSet[0])
            Copyright © 2009-2022 www.wtcwzsj.com 青羊区广皓图文设计工作室(个体工商户) 版权所有 蜀ICP备19037934号