refactor datasets

Zangwei Zheng 2024-03-26 16:50:36 +08:00
parent c93390a001
commit 01728dc28a
6 changed files with 131 additions and 48 deletions

View file

@@ -0,0 +1,38 @@
num_frames = None
fps = 24 // 3  # 8 fps: 24 fps sources sampled every 3rd frame (matches the training frame_interval of 3)
image_size = (256, 256)
# Define model
model = dict(
type="STDiT2-XL/2",
space_scale=0.5,
from_pretrained="PixArt-XL-2-1024-MS.pth",
enable_flashattn=True,
enable_layernorm_kernel=True,
)
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
micro_batch_size=4,
)
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=120,
)
scheduler = dict(
type="iddpm",
num_sampling_steps=100,
cfg_scale=7.0,
cfg_channel=3, # or None
)
dtype = "fp16"
# Condition
prompt_path = "./assets/texts/t2v_samples.txt"
prompt = None # prompt has higher priority than prompt_path
# Others
batch_size = 1
seed = 42
save_dir = "./outputs/samples/"
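
These config files are plain Python modules. A minimal sketch of reading such a module-style config (a generic loader shown only for illustration, not the repo's own config utility; the path below is hypothetical):

import importlib.util

def load_config(path):
    # Generic loader for a module-style config file; illustration only.
    spec = importlib.util.spec_from_file_location("config", path)
    cfg = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(cfg)
    return cfg

cfg = load_config("configs/sample.py")  # hypothetical path
print(cfg.model["type"], cfg.image_size)  # STDiT2-XL/2 (256, 256)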

View file

@@ -0,0 +1,50 @@
# Define dataset
data_path = "CSV_PATH"
num_frames = 16
frame_interval = 3
image_size = (256, 256)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
model = dict(
type="STDiT2-XL/2",
space_scale=0.5,
from_pretrained="PixArt-XL-2-1024-MS.pth",
enable_flashattn=True,
enable_layernorm_kernel=True,
)
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
micro_batch_size=4,
)
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=120,
shardformer=True,
)
scheduler = dict(
type="iddpm-speed",
timestep_respacing="",
)
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 1000
load = None
batch_size = 8
lr = 2e-5
grad_clip = 1.0
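
A quick sanity check on the sampling settings above: with num_frames=16 and frame_interval=3, each training clip spans 48 consecutive source frames, i.e. 2 seconds of 24 fps footage (24 fps is an assumption carried over from the inference config's fps = 24 // 3).

num_frames = 16
frame_interval = 3
span = num_frames * frame_interval  # 48 source frames per clip
seconds = span / 24                 # 2.0 s, assuming 24 fps sources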

View file

@@ -23,7 +23,7 @@ model = dict(
enable_flashattn=True,
enable_layernorm_kernel=True,
)
-mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
+# mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",

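The mask_ratios list commented out above sums to 1.0, which suggests it weights a set of masking strategies. A hypothetical sketch of drawing one strategy from such weights (the strategy names are placeholders, not taken from the repo):

import random

mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
strategies = ["none", "random", "head", "tail", "head_tail"]  # placeholder names
strategy = random.choices(strategies, weights=mask_ratios, k=1)[0]
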
View file

@@ -8,6 +8,7 @@ from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import VID_EXTENSIONS
+from .utils import get_transforms_image, get_transforms_video
class DatasetFromCSV(torch.utils.data.Dataset):
@@ -24,53 +25,61 @@ class DatasetFromCSV(torch.utils.data.Dataset):
        csv_path,
        num_frames=16,
        frame_interval=1,
-        transform=None,
        root=None,
+        image_size=(256, 256),
    ):
        self.csv_path = csv_path
        self.data = pd.read_csv(csv_path)
-        ext = self.data["path"][0].split(".")[-1]
-        if ext.lower() in VID_EXTENSIONS:
-            self.is_video = True
-        else:
-            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
-            self.is_video = False
-        self.transform = transform
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
        self.root = root
+        self.transforms = {
+            "image": get_transforms_image(image_size[0]),
+            "video": get_transforms_video(image_size[0]),
+        }

+    def get_type(self, path):
+        ext = path.split(".")[-1]
+        if ext.lower() in VID_EXTENSIONS:
+            return "video"
+        else:
+            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
+            return "image"

    def getitem(self, index):
        sample = self.data.iloc[index]
        path = sample["path"]
        if self.root:
            path = os.path.join(self.root, path)
        text = sample["text"]
+        file_type = self.get_type(path)
-        if self.is_video:
-            vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
-            total_frames = len(vframes)
+        if file_type == "video":
+            # loading
+            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+            # Sampling video frames
+            total_frames = len(vframes)
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            assert (
                end_frame_ind - start_frame_ind >= self.num_frames
            ), f"{path} with index {index} does not have enough frames."
            frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
            video = vframes[frame_indices]
-            video = self.transform(video)  # T C H W
+            # transform
+            transform = self.transforms["video"]
+            video = transform(video)  # T C H W
        else:
+            # loading
            image = pil_loader(path)
-            image = self.transform(image)
+            # transform
+            transform = self.transforms["image"]
+            image = transform(image)
+            # repeat
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)
        return {"video": video, "text": text}

    def __getitem__(self, index):

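After the refactor, callers no longer pass a transform; the dataset resolves one per item from the file extension. A minimal usage sketch (the CSV path is a placeholder; the file needs "path" and "text" columns, as read in getitem):

from torch.utils.data import DataLoader

dataset = DatasetFromCSV(
    "data/videos.csv",  # placeholder path
    num_frames=16,
    frame_interval=3,
    image_size=(256, 256),
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
batch = next(iter(loader))  # batch["video"]: [B, C, T, H, W]; batch["text"]: list of captions
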
View file

@@ -195,7 +195,6 @@ class STDiT2(nn.Module):
model_max_length=120,
dtype=torch.float32,
space_scale=1.0,
-        time_scale=1.0,
freeze=None,
enable_flashattn=False,
enable_layernorm_kernel=False,
@@ -206,12 +205,6 @@ class STDiT2(nn.Module):
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.hidden_size = hidden_size
-        self.patch_size = patch_size
-        self.input_size = input_size
-        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
-        self.num_patches = num_patches
-        self.num_temporal = input_size[0] // patch_size[0]
-        self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.no_temporal_pos_emb = no_temporal_pos_emb
@@ -220,7 +213,15 @@ class STDiT2(nn.Module):
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
-        self.time_scale = time_scale
+        # support dynamic input
+        self.patch_size = patch_size
+        self.input_size = input_size
+        if input_size[0] is None:
+            self.num_temporal = None
+        else:
+            self.num_temporal = input_size[0] // patch_size[0]
+        self.num_spatial = np.prod(input_size[1:]) // np.prod(patch_size[1:])
        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
-        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
+        # self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
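
For intuition about the dynamic-input bookkeeping above, assume a 256x256 frame encoded to a 32x32 latent (the VAE's usual 8x spatial downsampling; these numbers are illustrative assumptions) and patch_size=(1, 2, 2), with the temporal length left dynamic:

import numpy as np

input_size = (None, 32, 32)  # (T, H, W) in latent space; T unknown until runtime
patch_size = (1, 2, 2)
num_temporal = None if input_size[0] is None else input_size[0] // patch_size[0]
num_spatial = int(np.prod(input_size[1:]) // np.prod(patch_size[1:]))  # 16 * 16 = 256 tokens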
@@ -420,15 +421,6 @@ class STDiT2(nn.Module):
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
-    def get_temporal_pos_embed(self):
-        pos_embed = get_1d_sincos_pos_embed(
-            self.hidden_size,
-            self.input_size[0] // self.patch_size[0],
-            scale=self.time_scale,
-        )
-        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
-        return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:

View file

@@ -90,16 +90,10 @@ def main():
# 3. build dataset and dataloader
# ======================================================
dataset = DatasetFromCSV(
-        cfg.data_path,
-        # TODO: change transforms
-        transform=(
-            get_transforms_video(cfg.image_size[0])
-            if not cfg.use_image_transform
-            else get_transforms_image(cfg.image_size[0])
-        ),
+        csv_path=cfg.data_path,
        num_frames=cfg.num_frames,
        frame_interval=cfg.frame_interval,
        root=cfg.root,
+        image_size=cfg.image_size,
    )
# TODO: use plugin's prepare dataloader
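
Until the plugin's dataloader preparation is wired in (see the TODO above), a plain PyTorch loader along these lines is the likely stand-in; a sketch built from config fields defined earlier, with a DistributedSampler since training runs multi-process:

from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(dataset, shuffle=True)  # assumes torch.distributed is initialized
dataloader = DataLoader(
    dataset,
    batch_size=cfg.batch_size,
    sampler=sampler,
    num_workers=cfg.num_workers,
    pin_memory=True,
    drop_last=True,
)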
@@ -222,7 +216,7 @@ def main():
x = vae.encode(x) # [B, C, T, H/P, W/P]
# Prepare text inputs
model_args = text_encoder.encode(y)
# Mask
if cfg.mask_ratios is not None:
mask = mask_generator.get_masks(x)