Merge branch 'dev/v1.0.1' of github.com:hpcaitech/Open-Sora-dev into dev/Rflow

This commit is contained in:
tianyi 2024-04-19 11:18:01 +08:00
commit 0d98c1d02d
37 changed files with 398 additions and 126 deletions

1
.gitignore vendored
View file

@ -170,6 +170,7 @@ runs/
checkpoints/
outputs/
samples/
samples
logs/
pretrained_models/
*.swp

5
assets/texts/t2v_ref.txt Normal file
View file

@ -0,0 +1,5 @@
Drone view of waves crashing against the rugged cliffs along Big Surs garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliffs edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.
In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.
Pirate ship in a cosmic maelstrom nebula.
Drone view of waves crashing against the rugged cliffs along Big Surs garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliffs edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.
A sad small cactus with in the Sahara desert becomes happy.

View file

@ -0,0 +1,20 @@
A fat rabbit wearing a purple robe walking through a fantasy landscape
Waves crashing against a lone lighthouse, ominous lighting
A mystical forest showcasing the adventures of travelers who enter
A blue-haired mage singing
A surreal landscape with floating islands and waterfalls in the sky craft
A blue bird standing in water
A young man walks alone by the seaside
Pink rose on a glass surface with droplets, close-up
Drove viewpoint, a subway train coming out of a tunnel
Space with all planets green and pink color with background of bright white stars
A city floating in an astral space, with stars and nebulae
Sunrise on top of a high-rise building
Pink and cyan powder explosions
Deers in the woods gaze into the camera under the sunlight
In a flash of lightning, a wizard appeared from thin air, his long robes billowing in the wind
A futuristic cyberpunk cityscape at night with towering neon-lit skyscrapers
A scene where the trees, flowers, and animals come together to create a symphony of nature
A ghostly ship sailing through the clouds, navigating through a sea under a moonlit sky
A sunset with beautiful beach
A young man walking alone in the forest

View file

@ -28,4 +28,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/ucf101_labels.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -28,4 +28,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/imagenet_id.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -29,4 +29,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/imagenet_labels.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -27,4 +27,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/ucf101_id.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -28,4 +28,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/ucf101_labels.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -0,0 +1,54 @@
# Inference config (mmengine-style flat Python config) for loop-conditioned
# long-video sampling: 16 frames per loop at 240x426, chained over 10 loops.
num_frames = 16
fps = 24 // 3  # output fps: source 24 fps sampled every 3rd frame -> 8 fps
image_size = (240, 426)
multi_resolution = "STDiT2"
# Condition
prompt_path = None
prompt = None
loop = 10  # number of autoregressive generation loops to chain
condition_frame_length = 4  # tail frames of the previous loop reused as condition
reference_path = [
"assets/images/condition/cliff.png",
"assets/images/condition/wave.png",
]
# valid when reference_path is not None
# (loop id, ref id, ref start, length, target start)
mask_strategy = [
"0,0,0,1,0",
"0,0,0,1,0",
]
# Define model
model = dict(
type="STDiT2-XL/2",
from_pretrained=None,  # filled in via --ckpt-path at launch time
input_sq_size=512,
enable_flashattn=True,
enable_layernorm_kernel=True,
)
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
cache_dir=None, # "/mnt/hdd/cached_models",
micro_batch_size=4,  # encode/decode in chunks of 4 to bound VAE memory
)
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
cache_dir=None, # "/mnt/hdd/cached_models",
model_max_length=200,
)
scheduler = dict(
type="iddpm",
num_sampling_steps=100,
cfg_scale=7.0,
cfg_channel=3, # or None
)
dtype = "bf16"
# Others
batch_size = 1
seed = 42
save_dir = "./samples/samples/"

View file

@ -1,6 +1,6 @@
num_frames = 16
fps = 24 // 3
image_size = (256, 256)
image_size = (240, 426)
multi_resolution = "STDiT2"
# Define model
@ -21,7 +21,7 @@ text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
cache_dir=None, # "/mnt/hdd/cached_models",
model_max_length=300,
model_max_length=200,
)
scheduler = dict(
type="iddpm",
@ -38,4 +38,4 @@ prompt = None # prompt has higher priority than prompt_path
# Others
batch_size = 1
seed = 42
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -18,10 +18,14 @@ bucket_config = { # 6s/it
}
mask_ratios = {
"mask_no": 0.9,
"mask_random": 0.06,
"mask_head": 0.01,
"mask_tail": 0.01,
"mask_head_tail": 0.02,
"mask_quarter_random": 0.01,
"mask_quarter_head": 0.01,
"mask_quarter_tail": 0.01,
"mask_quarter_head_tail": 0.02,
"mask_image_random": 0.01,
"mask_image_head": 0.01,
"mask_image_tail": 0.01,
"mask_image_head_tail": 0.02,
}
# Define acceleration

View file

@ -52,4 +52,4 @@ mask_strategy = [
# Others
batch_size = 2
seed = 42
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -47,4 +47,4 @@ mask_strategy = ["0,0,0,1,0", "0,0,0,1,0"] # valid when reference_path is not N
# Others
batch_size = 2
seed = 42
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -50,4 +50,4 @@ mask_strategy = ["0,0,0,1,0;0,0,0,1,-1", "0,0,0,1,0;0,1,0,1,-1"] # valid when r
# Others
batch_size = 2
seed = 42
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -47,4 +47,4 @@ mask_strategy = ["0,0,0,1,0", "0,0,0,1,0"] # valid when reference_path is not N
# Others
batch_size = 2
seed = 42
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -36,4 +36,4 @@ prompt = None # prompt has higher priority than prompt_path
# Others
batch_size = 1
seed = 42
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -32,4 +32,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/t2v_samples.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -32,4 +32,4 @@ dtype = "bf16"
batch_size = 1
seed = 42
prompt_path = "./assets/texts/t2v_samples.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -29,4 +29,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/t2v_samples.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -31,4 +31,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/t2i_samples.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -30,4 +30,4 @@ dtype = "bf16"
batch_size = 2
seed = 42
prompt_path = "./assets/texts/t2i_samples.txt"
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -36,4 +36,4 @@ prompt = [
# Others
batch_size = 2
seed = 42
save_dir = "./outputs/samples/"
save_dir = "./samples/samples/"

View file

@ -71,6 +71,19 @@ vae = dict(
)
```
### Evaluation
Use the following commands to generate predefined samples.
```bash
# image
bash scripts/misc/sample.sh /path/to/ckpt --image
# video
bash scripts/misc/sample.sh /path/to/ckpt --video
# video edit
bash scripts/misc/sample.sh /path/to/ckpt --video-edit
```
## Training
To resume training, run the following command. ``--load`` is different from ``--ckpt-path`` in that it also loads the optimizer and dataloader states.
@ -109,6 +122,7 @@ python tools/datasets/split.py YOUR_CSV_PATH -o YOUR_SUBSET_CSV_PATH -c configs/
If you want to control the batch size search more granularly, you can configure batch size start, end, and step in the config file.
Bucket config format:
1. `{ resolution: {num_frames: (prob, batch_size)} }`, in this case batch_size is ignored when searching
2. `{ resolution: {num_frames: (prob, (max_batch_size, ))} }`, batch_size is searched in the range `[batch_size_start, max_batch_size)`, batch_size_start is configured via CLI
3. `{ resolution: {num_frames: (prob, (min_batch_size, max_batch_size))} }`, batch_size is searched in the range `[min_batch_size, max_batch_size)`
@ -135,4 +149,4 @@ bucket_config = {
}
```
It will print the best batch size (and corresponding step time) for each bucket and save the output config file.
It will print the best batch size (and corresponding step time) for each bucket and save the output config file.

View file

@ -20,6 +20,8 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"OPEN_SORA_HOME = \"/home/zhaowangbo/zangwei/opensora/\"\n",
"\n",
"\n",
@ -47,6 +49,7 @@
" )\n",
" return \" && \".join(commands), output_file\n",
"\n",
"\n",
"def get_video_info_torchvision(input_file):\n",
" commands = []\n",
" base, ext = os.path.splitext(input_file)\n",
@ -173,9 +176,7 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import paramiko\n",
"import time\n",
"\n",
"HOSTS = [\"h800-80\", \"h800-81\", \"h800-82\", \"h800-83\", \"h800-84\", \"h800-85\", \"h800-86\", \"h800-170\", \"h800-171\"]\n",
"\n",

View file

@ -92,7 +92,8 @@ class VideoTextDataset(torch.utils.data.Dataset):
try:
return self.getitem(index)
except Exception as e:
print(e)
path = self.data.iloc[index]["path"]
print(f"data {path}: {e}")
index = np.random.randint(len(self))
raise RuntimeError("Too many bad data.")

View file

@ -86,32 +86,32 @@ def get_transforms_image(name="center", image_size=(256, 256)):
return transform
def read_image_from_path(path, transform=None, num_frames=1, image_size=(256, 256)):
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
image = pil_loader(path)
if transform is None:
transform = get_transforms_image(image_size=image_size)
transform = get_transforms_image(image_size=image_size, name=transform_name)
image = transform(image)
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
video = video.permute(1, 0, 2, 3)
return video
def read_video_from_path(path, transform=None, image_size=(256, 256)):
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
if transform is None:
transform = get_transforms_video(image_size=image_size)
transform = get_transforms_video(image_size=image_size, name=transform_name)
video = transform(vframes) # T C H W
video = video.permute(1, 0, 2, 3)
return video
def read_from_path(path, image_size):
def read_from_path(path, image_size, transform_name="center"):
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
return read_video_from_path(path, image_size=image_size)
return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
else:
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
return read_image_from_path(path, image_size=image_size)
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):

View file

@ -356,7 +356,6 @@ class STDiT2(nn.Module):
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
# mask[:, 100:] = 0
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()

View file

@ -32,31 +32,40 @@ pretrained_models = {
def reparameter(ckpt, name=None, model=None):
if "DiT-XL" in name and "STDiT" not in name:
if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]:
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
elif "Latte" in name:
if name in ["Latte-XL-2-256x256-ucf101.pt"]:
ckpt = ckpt["ema"]
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
del ckpt["temp_embed"]
elif "PixArt" in name:
if name in ["PixArt-XL-2-256x256.pth", "PixArt-XL-2-SAM-256x256.pth", "PixArt-XL-2-512x512.pth"]:
ckpt = ckpt["state_dict"]
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
# different text length
if "y_embedder.y_embedding" in ckpt:
if ckpt["y_embedder.y_embedding"].shape[0] < model.y_embedder.y_embedding.shape[0]:
additional_length = model.y_embedder.y_embedding.shape[0] - ckpt["y_embedder.y_embedding"].shape[0]
new_y_embedding = torch.randn(additional_length, model.y_embedder.y_embedding.shape[1])
ckpt["y_embedder.y_embedding"] = torch.cat([ckpt["y_embedder.y_embedding"], new_y_embedding], dim=0)
elif ckpt["y_embedder.y_embedding"].shape[0] > model.y_embedder.y_embedding.shape[0]:
ckpt["y_embedder.y_embedding"] = ckpt["y_embedder.y_embedding"][: model.y_embedder.y_embedding.shape[0]]
# no need pos_embed
if "pos_embed_temporal" in ckpt:
del ckpt["pos_embed_temporal"]
if "pos_embed" in ckpt:
del ckpt["pos_embed"]
# different text length
if "y_embedder.y_embedding" in ckpt:
if ckpt["y_embedder.y_embedding"].shape[0] < model.y_embedder.y_embedding.shape[0]:
print(
f"Extend y_embedding from {ckpt['y_embedder.y_embedding'].shape[0]} to {model.y_embedder.y_embedding.shape[0]}"
)
additional_length = model.y_embedder.y_embedding.shape[0] - ckpt["y_embedder.y_embedding"].shape[0]
new_y_embedding = torch.zeros(additional_length, model.y_embedder.y_embedding.shape[1])
new_y_embedding[:] = ckpt["y_embedder.y_embedding"][-1]
ckpt["y_embedder.y_embedding"] = torch.cat([ckpt["y_embedder.y_embedding"], new_y_embedding], dim=0)
elif ckpt["y_embedder.y_embedding"].shape[0] > model.y_embedder.y_embedding.shape[0]:
print(
f"Shrink y_embedding from {ckpt['y_embedder.y_embedding'].shape[0]} to {model.y_embedder.y_embedding.shape[0]}"
)
ckpt["y_embedder.y_embedding"] = ckpt["y_embedder.y_embedding"][: model.y_embedder.y_embedding.shape[0]]
return ckpt

View file

@ -30,9 +30,15 @@ def parse_args(training=False):
# Inference
# ======================================================
if not training:
# output
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
parser.add_argument("--sample-name", default=None, type=str, help="sample name, default is sample_idx")
parser.add_argument("--start-index", default=None, type=int, help="start index for sample name")
parser.add_argument("--end-index", default=None, type=int, help="end index for sample name")
# prompt
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
parser.add_argument("--prompt", default=None, type=str, nargs="+", help="prompt list")
# image/video
parser.add_argument("--num-frames", default=None, type=int, help="number of frames")
@ -42,6 +48,12 @@ def parse_args(training=False):
# hyperparameters
parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps")
parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond")
# reference
parser.add_argument("--loop", default=None, type=int, help="loop")
parser.add_argument("--condition-frame-length", default=None, type=int, help="condition frame length")
parser.add_argument("--reference-path", default=None, type=str, nargs="+", help="reference path")
parser.add_argument("--mask-strategy", default=None, type=str, nargs="+", help="mask strategy")
# ======================================================
# Training
# ======================================================
@ -57,31 +69,40 @@ def merge_args(cfg, args, training=False):
if args.ckpt_path is not None:
cfg.model["from_pretrained"] = args.ckpt_path
args.ckpt_path = None
if training and args.data_path is not None:
cfg.dataset["data_path"] = args.data_path
args.data_path = None
for k, v in vars(args).items():
if k in cfg and v is not None:
if v is not None:
cfg[k] = v
if not training:
# Inference only
# - Allow not set
if "reference_path" not in cfg:
cfg["reference_path"] = None
if "loop" not in cfg:
cfg["loop"] = 1
# - Prompt handling
if "prompt" not in cfg or cfg["prompt"] is None:
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
cfg["prompt"] = load_prompts(cfg["prompt_path"])
if args.start_index is not None and args.end_index is not None:
cfg["prompt"] = cfg["prompt"][args.start_index : args.end_index]
elif args.start_index is not None:
cfg["prompt"] = cfg["prompt"][args.start_index :]
elif args.end_index is not None:
cfg["prompt"] = cfg["prompt"][: args.end_index]
else:
# Training only
if args.data_path is not None:
cfg.dataset["data_path"] = args.data_path
args.data_path = None
# - Allow not set
if "mask_ratios" not in cfg:
cfg["mask_ratios"] = None
if "transform_name" not in cfg.dataset:
cfg.dataset["transform_name"] = "center"
if "bucket_config" not in cfg:
cfg["bucket_config"] = None
if "transform_name" not in cfg.dataset:
cfg.dataset["transform_name"] = "center"
# Both training and inference
if "multi_resolution" not in cfg:

View file

@ -35,7 +35,17 @@ def update_ema(
class MaskGenerator:
def __init__(self, mask_ratios):
valid_mask_names = ["mask_no", "mask_random", "mask_head", "mask_tail", "mask_head_tail"]
valid_mask_names = [
"mask_no",
"mask_quarter_random",
"mask_quarter_head",
"mask_quarter_tail",
"mask_quarter_head_tail",
"mask_image_random",
"mask_image_head",
"mask_image_tail",
"mask_image_head_tail",
]
assert all(
mask_name in valid_mask_names for mask_name in mask_ratios.keys()
), f"mask_name should be one of {valid_mask_names}, got {mask_ratios.keys()}"
@ -70,21 +80,34 @@ class MaskGenerator:
if num_frames <= 1:
return mask
if mask_name == "mask_random":
if mask_name == "mask_quarter_random":
random_size = random.randint(1, condition_frames_max)
random_pos = random.randint(0, x.shape[2] - random_size)
mask[random_pos : random_pos + random_size] = 0
return mask
elif mask_name == "mask_head":
elif mask_name == "mask_image_random":
random_size = 1
random_pos = random.randint(0, x.shape[2] - random_size)
mask[random_pos : random_pos + random_size] = 0
elif mask_name == "mask_quarter_head":
random_size = random.randint(1, condition_frames_max)
mask[:random_size] = 0
elif mask_name == "mask_tail":
elif mask_name == "mask_image_head":
random_size = 1
mask[:random_size] = 0
elif mask_name == "mask_quarter_tail":
random_size = random.randint(1, condition_frames_max)
mask[-random_size:] = 0
elif mask_name == "mask_head_tail":
elif mask_name == "mask_image_tail":
random_size = 1
mask[-random_size:] = 0
elif mask_name == "mask_quarter_head_tail":
random_size = random.randint(1, condition_frames_max)
mask[:random_size] = 0
mask[-random_size:] = 0
elif mask_name == "mask_image_head_tail":
random_size = 1
mask[:random_size] = 0
mask[-random_size:] = 0
return mask

View file

@ -13,3 +13,4 @@ timm
tqdm
transformers
wandb
rotary_embedding_torch

View file

@ -21,7 +21,7 @@ def collect_references_batch(reference_paths, vae, image_size):
ref_path = reference_path.split(";")
ref = []
for r_path in ref_path:
r = read_from_path(r_path, image_size)
r = read_from_path(r_path, image_size, transform_name="resize_crop")
r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
r_x = r_x.squeeze(0)
ref.append(r_x)
@ -85,13 +85,18 @@ def main():
print(cfg)
# init distributed
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if os.environ.get("WORLD_SIZE", None):
use_dist = True
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
else:
enable_sequence_parallelism = False
else:
use_dist = False
enable_sequence_parallelism = False
# ======================================================
@ -159,6 +164,7 @@ def main():
# 4. inference
# ======================================================
sample_idx = 0
sample_name = cfg.sample_name if cfg.sample_name is not None else "sample"
save_dir = cfg.save_dir
os.makedirs(save_dir, exist_ok=True)
@ -206,14 +212,14 @@ def main():
# 4.7. save video
if loop_i == cfg.loop - 1:
if coordinator.is_master():
if not use_dist or coordinator.is_master():
for idx in range(len(video_clips[0])):
video_clips_i = [video_clips[0][idx]] + [
video_clips[i][idx][:, cfg.condition_frame_length :] for i in range(1, cfg.loop)
]
video = torch.cat(video_clips_i, dim=1)
print(f"Prompt: {prompts[i + idx]}")
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
save_sample(video, fps=cfg.fps, save_path=save_path)
sample_idx += 1

View file

@ -22,13 +22,18 @@ def main():
print(cfg)
# init distributed
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if os.environ.get("WORLD_SIZE", None):
use_dist = True
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
else:
enable_sequence_parallelism = False
else:
use_dist = False
enable_sequence_parallelism = False
# ======================================================
@ -91,6 +96,7 @@ def main():
# 4. inference
# ======================================================
sample_idx = 0
sample_name = cfg.sample_name if cfg.sample_name is not None else "sample"
save_dir = cfg.save_dir
os.makedirs(save_dir, exist_ok=True)
@ -112,10 +118,10 @@ def main():
)
samples = vae.decode(samples.to(dtype))
if coordinator.is_master():
if not use_dist or coordinator.is_master():
for idx, sample in enumerate(samples):
print(f"Prompt: {batch_prompts[idx]}")
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
sample_idx += 1

141
scripts/misc/sample.sh Normal file
View file

@ -0,0 +1,141 @@
#!/bin/bash
# Batch sampling driver: generates the predefined evaluation samples
# (images, videos, video edits) from a single checkpoint.
# Usage: bash scripts/misc/sample.sh /path/to/ckpt [--image] [--video] [--video-edit]
set -x
set -e
CKPT=$1
CMD="python scripts/inference.py configs/opensora-v1-1/inference/sample.py"
CMD_REF="python scripts/inference-long.py configs/opensora-v1-1/inference/sample.py"
# EMA checkpoints live inside a parent dir; tag the output dir with "_ema"
# so EMA and raw samples from the same run do not collide.
if [[ $CKPT == *"ema"* ]]; then
parentdir=$(dirname $CKPT)
CKPT_BASE=$(basename $parentdir)_ema
else
CKPT_BASE=$(basename $CKPT)
fi
OUTPUT="./samples/samples_${CKPT_BASE}"
start=$(date +%s)
### Functions
# Text-to-image sweep: several fixed resolutions plus a 720p aspect-ratio matrix.
function run_image() {
# 1.1 1024x1024
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --image-size 1024 1024 --sample-name 1024x1024
# 1.2 240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --image-size 240 426 --sample-name 240x426 --end-index 3
# 1.3 512x512
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --image-size 512 512 --sample-name 512x512 --end-index 3
# 1.4 720p multi-resolution
# 1:1
PROMPT="Bright scene, aerial view,ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens."
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 1 --image-size 960 960 --sample-name 720p_1_1
# 16:9
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 1 --image-size 720 1280 --sample-name 720p_16_9
# 9:16
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 1 --image-size 1280 720 --sample-name 720p_9_16
# 4:3
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 1 --image-size 832 1108 --sample-name 720p_4_3
# 3:4
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 1 --image-size 1108 832 --sample-name 720p_3_4
# 1:2
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 1 --image-size 1358 600 --sample-name 720p_1_2
# 2:1
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 1 --image-size 600 1358 --sample-name 720p_2_1
}
# Text-to-video, part 1: 16-frame clips at 240x426 plus the 720p aspect matrix.
function run_video_1() {
# 2.1.1 16x240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 16 --image-size 240 426 --sample-name sample_16x240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_short.txt --save-dir $OUTPUT --num-frames 16 --image-size 240 426 --sample-name short_16x240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_sora.txt --save-dir $OUTPUT --num-frames 16 --image-size 240 426 --sample-name sora_16x240x426
# 2.1.2 16x720p multi-resolution
# 1:1
PROMPT="A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures."
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 16 --image-size 960 960 --sample-name 720p_1_1
# 16:9
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 16 --image-size 720 1280 --sample-name 720p_16_9
# 9:16
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 16 --image-size 1280 720 --sample-name 720p_9_16
# 4:3
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 16 --image-size 832 1108 --sample-name 720p_4_3
# 3:4
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 16 --image-size 1108 832 --sample-name 720p_3_4
# 1:2
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 16 --image-size 1358 600 --sample-name 720p_1_2
# 2:1
eval $CMD --ckpt-path $CKPT --prompt \"$PROMPT\" --save-dir $OUTPUT --num-frames 16 --image-size 600 1358 --sample-name 720p_2_1
}
# Text-to-video, part 2: longer clips and higher resolutions (short prompts only).
function run_video_2() {
# 2.2.1 64x240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_short.txt --save-dir $OUTPUT --num-frames 64 --image-size 240 426 --sample-name short_64x240x426
# 2.2.2 128x240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_short.txt --save-dir $OUTPUT --num-frames 128 --image-size 240 426 --sample-name short_128x240x426
# 2.2.3 16x480x854
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_short.txt --save-dir $OUTPUT --num-frames 16 --image-size 480 854 --sample-name short_16x480x854
# 2.2.4 64x480x854
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_short.txt --save-dir $OUTPUT --num-frames 64 --image-size 480 854 --sample-name short_64x480x854
# 2.2.5 16x720x1280
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_short.txt --save-dir $OUTPUT --num-frames 16 --image-size 720 1280 --sample-name short_16x720x1280
}
# Video editing / image-conditioned generation via the long-inference script.
# Semicolons inside --reference-path / --mask-strategy are escaped because the
# command line goes through eval.
function run_video_edit() {
# 3.1 image-conditioned long video generation
eval $CMD_REF --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L10C4_16x240x426 \
--prompt-path assets/texts/t2v_ref.txt --start-index 0 --end-index 3 \
--num-frames 16 --image-size 240 426 \
--loop 5 --condition-frame-length 4 \
--reference-path assets/images/condition/cliff.png assets/images/condition/wave.png \
--mask-strategy "0,0,0,1,0" "0,0,0,1,0"
eval $CMD_REF --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L10C4_64x240x426 \
--prompt-path assets/texts/t2v_ref.txt --start-index 0 --end-index 3 \
--num-frames 64 --image-size 240 426 \
--loop 5 --condition-frame-length 16 \
--reference-path assets/images/condition/cliff.png assets/images/condition/wave.png \
--mask-strategy "0,0,0,1,0" "0,0,0,1,0"
# 3.2
eval $CMD_REF --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L1_128x240x426 \
--prompt-path assets/texts/t2v_ref.txt --start-index 3 --end-index 5 \
--num-frames 128 --image-size 240 426 \
--loop 1 \
--reference-path assets/images/condition/cliff.png "assets/images/condition/cactus-sad.png\;assets/images/condition/cactus-happy.png" \
--mask-strategy "0,0,0,1,0\;0,0,0,1,-1" "0,0,0,1,0\;0,1,0,1,-1"
}
### Main
# Dispatch on CLI flags; note --video intentionally triggers BOTH video parts.
# NOTE(review): each echo fires only after its run_* call completes, so the
# "Running ..." messages are really completion markers — consider moving them
# above the calls.
for arg in "$@"; do
if [[ "$arg" = -1 ]] || [[ "$arg" = --image ]]; then
run_image
echo "Running image samples..."
fi
if [[ "$arg" = -2_1 ]] || [[ "$arg" = --video ]]; then
run_video_1
echo "Running video samples 1..."
fi
if [[ "$arg" = -2_2 ]] || [[ "$arg" = --video ]]; then
run_video_2
echo "Running video samples 2..."
fi
if [[ "$arg" = -3 ]] || [[ "$arg" = --video-edit ]]; then
run_video_edit
echo "Running video edit samples..."
fi
done
### End
end=$(date +%s)
runtime=$((end - start))
echo "Runtime: $runtime seconds"

View file

@ -25,12 +25,7 @@ from opensora.datasets import prepare_variable_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import model_sharding
from opensora.utils.config_utils import merge_args, parse_configs
from opensora.utils.misc import (
format_numel_str,
get_model_numel,
requires_grad,
to_torch_dtype,
)
from opensora.utils.misc import format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import MaskGenerator, update_ema
@ -66,9 +61,7 @@ class BColors:
def parse_configs():
parser = argparse.ArgumentParser()
parser.add_argument("config", help="model config file path")
parser.add_argument(
"-o", "--output", help="output config file path", default="output_config.py"
)
parser.add_argument("-o", "--output", help="output config file path", default="output_config.py")
parser.add_argument("--seed", default=42, type=int, help="generation seed")
parser.add_argument(
@ -76,24 +69,14 @@ def parse_configs():
type=str,
help="path to model ckpt; will overwrite cfg.ckpt_path if specified",
)
parser.add_argument(
"--data-path", default=None, type=str, help="path to data csv", required=True
)
parser.add_argument("--data-path", default=None, type=str, help="path to data csv", required=True)
parser.add_argument("--warmup-steps", default=1, type=int, help="warmup steps")
parser.add_argument("--active-steps", default=1, type=int, help="active steps")
parser.add_argument(
"--base-resolution", default="240p", type=str, help="base resolution"
)
parser.add_argument("--base-resolution", default="240p", type=str, help="base resolution")
parser.add_argument("--base-frames", default=128, type=int, help="base frames")
parser.add_argument(
"--batch-size-start", default=2, type=int, help="batch size start"
)
parser.add_argument(
"--batch-size-end", default=256, type=int, help="batch size end"
)
parser.add_argument(
"--batch-size-step", default=2, type=int, help="batch size step"
)
parser.add_argument("--batch-size-start", default=2, type=int, help="batch size start")
parser.add_argument("--batch-size-end", default=256, type=int, help="batch size end")
parser.add_argument("--batch-size-step", default=2, type=int, help="batch size step")
args = parser.parse_args()
cfg = Config.fromfile(args.config)
cfg = merge_args(cfg, args, training=True)
@ -116,9 +99,7 @@ def main():
# ======================================================
cfg, args = parse_configs()
print(cfg)
assert (
cfg.dataset.type == "VariableVideoTextDataset"
), "Only VariableVideoTextDataset is supported"
assert cfg.dataset.type == "VariableVideoTextDataset", "Only VariableVideoTextDataset is supported"
# ======================================================
# 2. runtime variables & colossalai launch
@ -223,10 +204,7 @@ def main():
model_sharding(ema)
buckets = [
(res, f)
for res, d in cfg.bucket_config.items()
for f, (p, bs) in d.items()
if bs is not None and p > 0.0
(res, f) for res, d in cfg.bucket_config.items() for f, (p, bs) in d.items() if bs is not None and p > 0.0
]
output_bucket_cfg = deepcopy(cfg.bucket_config)
# find the base batch size
@ -248,15 +226,11 @@ def main():
optimizer,
ema,
)
update_bucket_config_bs(
output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size
)
update_bucket_config_bs(output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size)
coordinator.print_on_master(
f"{BColors.OKBLUE}Base resolution: {args.base_resolution}, Base frames: {args.base_frames}, Batch size: {base_batch_size}, Base step time: {base_step_time}{BColors.ENDC}"
)
result_table = [
f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}"
]
result_table = [f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}"]
for resolution, frames in buckets:
try:
batch_size, step_time = benchmark(
@ -280,9 +254,7 @@ def main():
f"{BColors.OKBLUE}Resolution: {resolution}, Frames: {frames}, Batch size: {batch_size}, Step time: {step_time}{BColors.ENDC}"
)
update_bucket_config_bs(output_bucket_cfg, resolution, frames, batch_size)
result_table.append(
f"{resolution}, {frames}, {batch_size}, {step_time:.2f}"
)
result_table.append(f"{resolution}, {frames}, {batch_size}, {step_time:.2f}")
except RuntimeError:
pass
result_table = "\n".join(result_table)
@ -367,10 +339,7 @@ def benchmark(
raise RuntimeError("No valid batch size found")
if target_step_time is None:
# find the fastest batch size
throughputs = [
batch_size / step_time
for step_time, batch_size in zip(step_times, batch_sizes)
]
throughputs = [batch_size / step_time for step_time, batch_size in zip(step_times, batch_sizes)]
max_throughput = max(throughputs)
target_batch_size = batch_sizes[throughputs.index(max_throughput)]
step_time = step_times[throughputs.index(max_throughput)]
@ -419,13 +388,9 @@ def train(
**dataloader_args,
)
dataloader_iter = iter(dataloader)
num_steps_per_epoch = (
dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
)
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
assert (
num_steps_per_epoch >= total_steps
), f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}"
assert num_steps_per_epoch >= total_steps, f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}"
duration = 0
# this is essential for the first iteration after OOM
optimizer._grad_store.reset_all_gradients()

View file

@ -4,7 +4,6 @@ from pprint import pprint
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
@ -12,6 +11,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
get_data_parallel_group,
@ -99,6 +99,7 @@ def main():
dataset=dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
seed=cfg.seed,
shuffle=True,
drop_last=True,
pin_memory=True,

View file

@ -17,7 +17,7 @@ def read_file(input_path):
def parse_args():
    """Parse command-line arguments for the dataset-info script.

    Returns:
        argparse.Namespace: parsed arguments with fields:
            input (str): path to the input dataset file.
            save_img (str): directory where sample images are saved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Path to the input dataset")
    # NOTE: register --save-img exactly once; a duplicate add_argument for the
    # same option string raises argparse.ArgumentError at import/run time.
    # Default follows the repo-wide `samples/` output convention.
    parser.add_argument("--save-img", type=str, default="samples/samples/infos/", help="Path to save the image")
    return parser.parse_args()