-pre install

1
2
3
4
pip install torch
pip install opencv-python
pip install pandas
pip install torchvision
PyTorch is a powerful deep learning framework that offers many tools and libraries to simplify the process of model creation and training. Its advantages include dynamic computation graphs, an easy-to-use API, and a vibrant community support. To install PyTorch, we can use the following pip command:
1
2
3
4
5
6
import torch
import cv2
import os
import pandas as pd
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

定義一個自定義的 Dataset 類別:
定義了一個叫做 custom 的類別,這個類別繼承了 PyTorch 的 Dataset 類別。

1
class custom(Dataset):

初始化方法 __init__(self, source_root, label_root): 這個方法會在創建類別的對象時被調用。它接受兩個參數,分別是圖片文件的路徑(source_root)和標籤文件的路徑(label_root)。它讀取了標籤文件並將其存儲在 self.label_df 中。
1
2
3
4
5
#Initialize paths
def __init__(self, source_root, label_root):
super().__init__()
self.source_root = source_root
self.label_df = pd.read_csv(label_root)

這個方法根據指定的索引返回一個數據樣本和它的標籤。它首先從標籤文件中讀取指定索引的圖片名稱和標籤,然後讀取並處理對應的圖片文件,最後返回這些信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#Read from file(opencv, pandas, ...) and Return the data (e.g. image and label)
def __getitem__(self, index):
# read csv from dataframe->(file name/label)
img_filename=self.label_df.iat[index,0]
img_label=self.label_df.iat[index,1]
img_path=os.path.join(self.source_root,img_filename)

# img normalize
img=cv2.imread(img_path)
img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB) # color transform (BGR to RGB)
img=torch.tensor(img).float()/255 # img-> 0~1
img=img.permute(2,0,1).contiguous()
img=transforms.Resize((256,256))(img) # img size adjustment
img=transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))(img) # (mean, std)

# return values
return img_label,img_path,img_filename,img

這個方法返回數據集中的樣本數量,它的返回值會被 PyTorch 的 DataLoader 使用。

1
2
3
#The total size of the dataset
def __len__(self):
return len(self.label_df)

Main code

這段代碼創建了一個 custom 類別的對象,然後使用 PyTorch 的 DataLoader 加載數據。之後,它遍歷了整個數據集,並打印出每個樣本的標籤、圖片路徑和圖片名稱。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
if __name__ =='__main__':
source_root = 'flower_dataset/image'
label_root = 'flower_dataset/labels.csv'
mydataset = custom(source_root, label_root)

# dataloader
dataloader=DataLoader(mydataset, batch_size=2, shuffle=True, num_workers=10)

#main
for batch_of_index , (img_label,img_path,img_filename,img) in enumerate(dataloader):
print(f"index->{batch_of_index + 1}")

for i in range(len(img_path)):
print(f"img_path:{img_path[i]}")
print(f"img_filename:{img_filename[i]}")
print(f"img_label:{img_label[i]}")

print("-"*20)

print(f"total data:{mydataset.__len__()}")



作者: 微風