mirror of https://github.com/hpcaitech/Open-Sora.git
refactor datasets

commit 01728dc28a (parent c93390a001)
configs/opensora-v1-1/inference/Vx256x256.py  (new file, 38 lines)

@@ -0,0 +1,38 @@
num_frames = None
fps = 24 // 3
image_size = (256, 256)

# Define model
model = dict(
    type="STDiT2-XL/2",
    space_scale=0.5,
    from_pretrained="PixArt-XL-2-1024-MS.pth",
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
    micro_batch_size=4,
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,
)
scheduler = dict(
    type="iddpm",
    num_sampling_steps=100,
    cfg_scale=7.0,
    cfg_channel=3,  # or None
)
dtype = "fp16"

# Condition
prompt_path = "./assets/texts/t2v_samples.txt"
prompt = None  # prompt has higher priority than prompt_path

# Others
batch_size = 1
seed = 42
save_dir = "./outputs/samples/"
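These config files are plain Python modules, so a quick way to inspect one is to execute it into a dict. The snippet below is only an illustration: Open-Sora resolves configs through its own utilities, but the effect is similar.

import runpy

cfg = runpy.run_path("configs/opensora-v1-1/inference/Vx256x256.py")
print(cfg["model"]["type"])  # "STDiT2-XL/2"
print(cfg["num_frames"])     # None (left unset; see the dynamic-input change below)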
configs/opensora-v1-1/train/Vx256x256.py  (new file, 50 lines)

@@ -0,0 +1,50 @@
# Define dataset
data_path = "CSV_PATH"
num_frames = 16
frame_interval = 3
image_size = (256, 256)

# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1

# Define model
model = dict(
    type="STDiT2-XL/2",
    space_scale=0.5,
    from_pretrained="PixArt-XL-2-1024-MS.pth",
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
    micro_batch_size=4,
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,
    shardformer=True,
)
scheduler = dict(
    type="iddpm-speed",
    timestep_respacing="",
)

# Others
seed = 42
outputs = "outputs"
wandb = False

epochs = 1000
log_every = 10
ckpt_every = 1000
load = None

batch_size = 8
lr = 2e-5
grad_clip = 1.0
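The temporal settings here line up with the inference config's fps = 24 // 3. A worked check, assuming 24 fps source footage (which the fps expression suggests but the diff does not state):

num_frames, frame_interval, source_fps = 16, 3, 24
span = num_frames * frame_interval / source_fps         # 16 * 3 / 24 = 2.0 s of source video
playback = num_frames / (source_fps // frame_interval)  # 16 / 8 = 2.0 s at 8 fps playback
assert span == playback  # sampled clips replay at their original duration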
@@ -23,7 +23,7 @@ model = dict(
     enable_flashattn=True,
     enable_layernorm_kernel=True,
 )
-mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
+# mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
 vae = dict(
     type="VideoAutoencoderKL",
     from_pretrained="stabilityai/sd-vae-ft-ema",
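Illustrative aside: the five entries of mask_ratios sum to 1.0, so they read naturally as sampling probabilities over five masking strategies. That interpretation is an assumption; the diff itself does not name the strategies.

import random

mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
assert abs(sum(mask_ratios) - 1.0) < 1e-9
strategy = random.choices(range(len(mask_ratios)), weights=mask_ratios)[0]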
@@ -8,6 +8,7 @@ from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
 from . import video_transforms
 from .utils import VID_EXTENSIONS
+from .utils import get_transforms_image, get_transforms_video


 class DatasetFromCSV(torch.utils.data.Dataset):
@@ -24,53 +25,61 @@ class DatasetFromCSV(torch.utils.data.Dataset):
         csv_path,
         num_frames=16,
         frame_interval=1,
-        transform=None,
         root=None,
+        image_size=(256, 256),
     ):
         self.csv_path = csv_path
         self.data = pd.read_csv(csv_path)

         ext = self.data["path"][0].split(".")[-1]
         if ext.lower() in VID_EXTENSIONS:
             self.is_video = True
         else:
             assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
             self.is_video = False

-        self.transform = transform
         self.num_frames = num_frames
         self.frame_interval = frame_interval
         self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
         self.root = root
+        self.transforms = {
+            "image": get_transforms_image(image_size[0]),
+            "video": get_transforms_video(image_size[0]),
+        }
+
+    def get_type(self, path):
+        ext = path.split(".")[-1]
+        if ext.lower() in VID_EXTENSIONS:
+            return "video"
+        else:
+            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
+            return "image"

     def getitem(self, index):
         sample = self.data.iloc[index]
         path = sample["path"]
         if self.root:
             path = os.path.join(self.root, path)
         text = sample["text"]
+        file_type = self.get_type(path)

-        if self.is_video:
-            vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
-            total_frames = len(vframes)
+        if file_type == "video":
+            # loading
+            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+
+            # Sampling video frames
+            total_frames = len(vframes)
             start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
             assert (
                 end_frame_ind - start_frame_ind >= self.num_frames
             ), f"{path} with index {index} has not enough frames."
             frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
             video = vframes[frame_indice]
-            video = self.transform(video)  # T C H W
+
+            # transform
+            transform = self.transforms["video"]
+            video = transform(video)  # T C H W
         else:
+            # loading
             image = pil_loader(path)
-            image = self.transform(image)
+
+            # transform
+            transform = self.transforms["image"]
+            image = transform(image)
+
+            # repeat
             video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)

         # TCHW -> CTHW
         video = video.permute(1, 0, 2, 3)

         return {"video": video, "text": text}

     def __getitem__(self, index):
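A hypothetical usage of the refactored dataset, matching the signature above. The CSV path is a placeholder, and the CSV is assumed to carry the "path" and "text" columns the code reads.

from torch.utils.data import DataLoader

# DatasetFromCSV as defined above
dataset = DatasetFromCSV(
    "CSV_PATH",
    num_frames=16,
    frame_interval=3,
    root=None,
    image_size=(256, 256),
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
batch = next(iter(loader))
# batch["video"]: [8, 3, 16, 256, 256], i.e. B C T H W after the TCHW -> CTHW permute
# batch["text"]:  list of 8 caption strings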
@@ -195,7 +195,6 @@ class STDiT2(nn.Module):
        model_max_length=120,
        dtype=torch.float32,
        space_scale=1.0,
        time_scale=1.0,
        freeze=None,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
@@ -206,12 +205,6 @@ class STDiT2(nn.Module):
         self.in_channels = in_channels
         self.out_channels = in_channels * 2 if pred_sigma else in_channels
         self.hidden_size = hidden_size
-        self.patch_size = patch_size
-        self.input_size = input_size
-        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
-        self.num_patches = num_patches
-        self.num_temporal = input_size[0] // patch_size[0]
-        self.num_spatial = num_patches // self.num_temporal
         self.num_heads = num_heads
         self.dtype = dtype
         self.no_temporal_pos_emb = no_temporal_pos_emb
@@ -220,7 +213,15 @@ class STDiT2(nn.Module):
         self.enable_flashattn = enable_flashattn
         self.enable_layernorm_kernel = enable_layernorm_kernel
         self.space_scale = space_scale
         self.time_scale = time_scale

+        # support dynamic input
+        self.patch_size = patch_size
+        self.input_size = input_size
+        if input_size[0] is None:
+            self.num_temporal = None
+        else:
+            self.num_temporal = input_size[0] // patch_size[0]
+        self.num_spatial = np.prod(input_size[1:]) // np.prod(patch_size[1:])

         self.register_buffer("pos_embed", self.get_spatial_pos_embed())
         # self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
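A worked example of the patch-count arithmetic introduced above, with illustrative latent sizes (the values are not taken from this diff):

import numpy as np

input_size = (16, 32, 32)  # (T, H, W) of the latent video
patch_size = (1, 2, 2)
num_temporal = input_size[0] // patch_size[0]                     # 16
num_spatial = np.prod(input_size[1:]) // np.prod(patch_size[1:])  # (32*32)//(2*2) = 256
assert num_temporal * num_spatial == np.prod(input_size) // np.prod(patch_size)
# With input_size[0] = None, num_temporal stays None and is resolved per
# batch, which is what "support dynamic input" refers to.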
@@ -420,15 +421,6 @@ class STDiT2(nn.Module):
         pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
         return pos_embed

-    def get_temporal_pos_embed(self):
-        pos_embed = get_1d_sincos_pos_embed(
-            self.hidden_size,
-            self.input_size[0] // self.patch_size[0],
-            scale=self.time_scale,
-        )
-        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
-        return pos_embed
-
     def freeze_not_temporal(self):
         for n, p in self.named_parameters():
             if "attn_temp" not in n:
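For context, a minimal sketch of what a 1D sin-cos positional embedding like the removed get_temporal_pos_embed computes, following the common MAE-style recipe. The scale handling here is an assumption based on the scale=self.time_scale argument above; the repo's own get_1d_sincos_pos_embed may differ in details.

import numpy as np

def sincos_pos_embed_1d(dim, length, scale=1.0):
    pos = np.arange(length, dtype=np.float64) / scale               # (L,)
    omega = 1.0 / 10000 ** (np.arange(dim // 2) / (dim / 2.0))      # (D/2,) frequencies
    out = np.einsum("l,d->ld", pos, omega)                          # (L, D/2) phase table
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)       # (L, D)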
@@ -90,16 +90,10 @@ def main():
     # 3. build dataset and dataloader
     # ======================================================
     dataset = DatasetFromCSV(
-        cfg.data_path,
-        # TODO: change transforms
-        transform=(
-            get_transforms_video(cfg.image_size[0])
-            if not cfg.use_image_transform
-            else get_transforms_image(cfg.image_size[0])
-        ),
+        csv_path=cfg.data_path,
         num_frames=cfg.num_frames,
         frame_interval=cfg.frame_interval,
         root=cfg.root,
+        image_size=cfg.image_size,
     )

     # TODO: use plugin's prepare dataloader
@@ -222,7 +216,7 @@ def main():
            x = vae.encode(x)  # [B, C, T, H/P, W/P]
            # Prepare text inputs
            model_args = text_encoder.encode(y)

            # Mask
            if cfg.mask_ratios is not None:
                mask = mask_generator.get_masks(x)
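A hypothetical sketch of what mask_generator.get_masks(x) might produce: a per-frame keep/drop mask for each sample, with the masking strategy drawn according to cfg.mask_ratios. Names and behavior are assumptions for illustration, not the repo's actual MaskGenerator.

import torch

def get_masks(x, mask_ratios=(0.5, 0.29, 0.07, 0.07, 0.07)):
    # x: [B, C, T, H, W] latent video batch (shape convention from the diff)
    B, _, T, _, _ = x.shape
    probs = torch.tensor(mask_ratios)
    strategy = torch.multinomial(probs, B, replacement=True)  # one strategy per sample
    masks = torch.ones(B, T, dtype=torch.bool)                # True = frame is kept
    for i, s in enumerate(strategy.tolist()):
        if s > 0:                                             # strategy 0: keep all frames
            n_drop = torch.randint(1, max(T // 2, 2), (1,)).item()
            masks[i, torch.randperm(T)[:n_drop]] = False      # drop random frames
    return masks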