-
Notifications
You must be signed in to change notification settings - Fork 33
/
comyui_dataset.py
42 lines (36 loc) · 954 Bytes
/
comyui_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from os.path import join as opj
import cv2
import numpy as np
from torch.utils.data import Dataset
class Comfyui_Dataset(Dataset):
    """Single-sample map-style dataset wrapping inputs prepared by the caller.

    The constructor stores every input as-is: no file loading, resizing or
    normalisation happens in this class. ``__getitem__`` returns the stored
    objects themselves (not copies) in a dict, plus an empty ``txt`` entry.

    NOTE(review): the format (array vs tensor, shape, dtype) of ``agn``,
    ``agn_mask``, ``cloth``, ``image`` and ``image_densepose`` is decided by
    the caller -- TODO confirm against the model that consumes these samples.
    """

    def __init__(
        self,
        img_fn,
        cloth_fn,
        agn,
        agn_mask,
        cloth,
        image,
        image_densepose,
        **kwargs
    ):
        """Store the one sample this dataset exposes.

        Args:
            img_fn: Returned unchanged under the ``"img_fn"`` key
                (presumably the person-image file name -- verify with caller).
            cloth_fn: Returned unchanged under the ``"cloth_fn"`` key
                (presumably the garment-image file name -- verify with caller).
            agn: Returned unchanged under ``"agn"``.
            agn_mask: Returned unchanged under ``"agn_mask"``.
            cloth: Returned unchanged under ``"cloth"``.
            image: Returned unchanged under ``"image"``.
            image_densepose: Returned unchanged under ``"image_densepose"``.
            **kwargs: Accepted and ignored.
        """
        self.img_fn = img_fn
        self.cloth_fn = cloth_fn
        self.agn = agn
        self.agn_mask = agn_mask
        self.cloth = cloth
        self.image = image
        self.image_densepose = image_densepose

    def __len__(self):
        """Return 1: the dataset always holds exactly one sample."""
        return 1

    def __getitem__(self, idx):
        """Return the stored sample as a dict.

        Args:
            idx: Sample index; only ``0`` (or ``-1``) is valid.

        Returns:
            dict with keys ``agn``, ``agn_mask``, ``cloth``, ``image``,
            ``image_densepose``, ``txt`` (always ``""``), ``img_fn`` and
            ``cloth_fn``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        # torch's map-style Dataset defines no __iter__, so `for x in ds` /
        # `list(ds)` fall back to calling __getitem__(0), (1), ... until an
        # IndexError. Without this bound check that iteration never ends.
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(
                f"{type(self).__name__} index {idx} out of range (length {size})"
            )
        return dict(
            agn=self.agn,
            agn_mask=self.agn_mask,
            cloth=self.cloth,
            image=self.image,
            image_densepose=self.image_densepose,
            txt="",
            img_fn=self.img_fn,
            cloth_fn=self.cloth_fn,
        )