-pre install
1 | pip install torch |
PyTorch is a powerful deep learning framework that offers many tools and libraries to simplify the process of model creation and training. Its advantages include dynamic computation graphs, an easy-to-use API, and a vibrant community support. To install PyTorch, we can use the following pip command:
1 | import torch |
定義一個自定義的 Dataset 類別:
定義了一個叫做 custom 的類別,這個類別繼承了 PyTorch 的 Dataset 類別。
1 | class custom(Dataset): |
初始化方法 __init__(self, source_root, label_root): 這個方法會在創建類別的對象時被調用。它接受兩個參數,分別是圖片文件的路徑(source_root)和標籤文件的路徑(label_root)。它讀取了標籤文件並將其存儲在 self.label_df 中。
1 | #Initialize paths |
這個方法根據指定的索引返回一個數據樣本和它的標籤。它首先從標籤文件中讀取指定索引的圖片名稱和標籤,然後讀取並處理對應的圖片文件,最後返回這些信息。
1 | #Read from file(opencv, pandas, ...) and Return the data (e.g. image and label) |
這個方法返回數據集中的樣本數量,它的返回值會被 PyTorch 的 DataLoader 使用。
1 | #The total size of the dataset |
Main code
這段代碼創建了一個 custom 類別的對象,然後使用 PyTorch 的 DataLoader 加載數據。之後,它遍歷了整個數據集,並打印出每個樣本的標籤、圖片路徑和圖片名稱。
1 | if __name__ =='__main__': |

作者: 微風