mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-21 11:59:01 +02:00
[hotfix] fix vqvae output process (#9)
This commit is contained in:
parent
da9b00e808
commit
6f887f453b
|
|
@ -2,10 +2,13 @@ import os
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.utils import get_current_device
|
||||
from datasets import Dataset as HFDataset
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from torchvision.io import read_video
|
||||
|
||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
PathType = Union[str, os.PathLike]
|
||||
|
|
@ -34,7 +37,9 @@ def video2col(video_4d: torch.Tensor, patch_size: int) -> torch.Tensor:
|
|||
return torch.stack(out, dim=1).view(-1, c, patch_size, patch_size)
|
||||
|
||||
|
||||
def col2video(patches: torch.Tensor, video_shape: Tuple[int, int, int, int]) -> torch.Tensor:
|
||||
def col2video(
|
||||
patches: torch.Tensor, video_shape: Tuple[int, int, int, int]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a 2D tensor of patches to a 4D video tensor.
|
||||
|
||||
|
|
@ -71,7 +76,10 @@ def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Te
|
|||
"""
|
||||
max_len = max([sequence.shape[0] for sequence in sequences])
|
||||
padded_sequences = [
|
||||
F.pad(sequence, [0] * (sequence.ndim - 1) * 2 + [0, max_len - sequence.shape[0]]) for sequence in sequences
|
||||
F.pad(
|
||||
sequence, [0] * (sequence.ndim - 1) * 2 + [0, max_len - sequence.shape[0]]
|
||||
)
|
||||
for sequence in sequences
|
||||
]
|
||||
padded_sequences = torch.stack(padded_sequences, dim=0)
|
||||
padding_mask = torch.zeros(
|
||||
|
|
@ -85,7 +93,9 @@ def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Te
|
|||
return padded_sequences, padding_mask
|
||||
|
||||
|
||||
def patchify_batch(videos: List[torch.Tensor], patch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def patchify_batch(
|
||||
videos: List[torch.Tensor], patch_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Patchify a batch of videos.
|
||||
|
||||
Args:
|
||||
|
|
@ -115,7 +125,7 @@ def expand_mask_4d(q_mask: torch.Tensor, kv_mask: torch.Tensor) -> torch.Tensor:
|
|||
return mask.unsqueeze(1)
|
||||
|
||||
|
||||
def make_batch(samples: List[dict], patch_size: int) -> dict:
|
||||
def make_batch(samples: List[dict], video_dir: str) -> dict:
|
||||
"""Make a batch of samples.
|
||||
|
||||
Args:
|
||||
|
|
@ -124,27 +134,67 @@ def make_batch(samples: List[dict], patch_size: int) -> dict:
|
|||
Returns:
|
||||
dict: A batch of samples.
|
||||
"""
|
||||
videos = [sample["video_latent_states"] for sample in samples]
|
||||
videos, video_padding_mask = patchify_batch(videos, patch_size)
|
||||
videos = [
|
||||
read_video(os.path.join(video_dir, sample["video_file"]), pts_unit="sec")[0]
|
||||
for sample in samples
|
||||
]
|
||||
texts = [sample["text_latent_states"] for sample in samples]
|
||||
texts, text_padding_mask = pad_sequences(texts)
|
||||
return {
|
||||
"video_latent_states": videos,
|
||||
"video_padding_mask": video_padding_mask,
|
||||
"videos": videos,
|
||||
"text_latent_states": texts,
|
||||
"text_padding_mask": text_padding_mask,
|
||||
"attention_mask": expand_mask_4d(video_padding_mask, text_padding_mask),
|
||||
}
|
||||
|
||||
|
||||
def load_datasets(dataset_paths: Union[PathType, List[PathType]], mode: str = "train") -> Optional[DatasetType]:
|
||||
def normalize_video(video: torch.Tensor) -> torch.Tensor:
    """Cast ``video`` to float and map each value ``x`` to ``x / 255 - 0.5``.

    An input value of 0 becomes -0.5 and 255 becomes 0.5, so 8-bit pixel
    data ends up centred on zero.

    Args:
        video: Tensor of any shape; presumably raw uint8 frames (the caller
            in this file passes decoded video tensors).

    Returns:
        A float tensor with the same shape as ``video``.
    """
    as_float = video.float()
    scaled = as_float / 255
    return scaled - 0.5
|
||||
|
||||
|
||||
def unnormalize_video(video: torch.Tensor) -> torch.Tensor:
    """Map each value ``y`` of ``video`` to ``(y + 0.5) * 255``.

    This is the algebraic inverse of ``x / 255 - 0.5``. No clamping,
    rounding or dtype conversion is applied to the result.

    Args:
        video: Tensor of any shape.

    Returns:
        A tensor with the same shape as ``video``.
    """
    shifted = video + 0.5
    return shifted * 255
|
||||
|
||||
|
||||
@torch.no_grad()
def preprocess_batch(
    batch: dict, patch_size: int, vqvae: nn.Module, device=None
) -> dict:
    """Turn the raw videos of a collated batch into patchified VQ-VAE embeddings.

    ``batch`` is modified in place and also returned. The ``"videos"`` and
    ``"text_padding_mask"`` entries are removed; ``"video_latent_states"``,
    ``"video_padding_mask"`` and ``"attention_mask"`` are written, and
    ``"text_latent_states"`` is moved to ``device``.

    Args:
        batch: Dict holding ``"videos"`` (an iterable of per-sample video
            tensors), ``"text_padding_mask"`` and ``"text_latent_states"``.
        patch_size: Patch size forwarded to ``patchify_batch``.
        vqvae: Model whose ``encode(video, include_embeddings=True)`` returns
            a pair; only the second element (the embeddings) is used.
        device: Target device. When ``None``, ``get_current_device()`` is used.

    Returns:
        The same ``batch`` dict, updated as described above.
    """
    target = get_current_device() if device is None else device

    def _encode(frames: torch.Tensor) -> torch.Tensor:
        # [T, H, W, C] -> [B, C, T, H, W] with a singleton batch dimension
        clip = normalize_video(frames.to(target)).permute(3, 0, 1, 2).unsqueeze(0)
        # The first element of the pair (latent indices) is not needed here.
        _, latents = vqvae.encode(clip, include_embeddings=True)
        # [B, C, T, H, W] -> [T, C, H, W]
        return latents.squeeze(0).permute(1, 0, 2, 3)

    encoded = [_encode(frames) for frames in batch.pop("videos")]
    patches, video_padding_mask = patchify_batch(encoded, patch_size)

    # hack diffuser, [B, S, C, P, P] -> [B, C, S, P, P]
    batch["video_latent_states"] = patches.transpose(1, 2)
    batch["video_padding_mask"] = video_padding_mask

    text_padding_mask = batch.pop("text_padding_mask").to(target)
    batch["attention_mask"] = expand_mask_4d(video_padding_mask, text_padding_mask)
    batch["text_latent_states"] = batch["text_latent_states"].to(target)
    return batch
|
||||
|
||||
|
||||
def load_datasets(
|
||||
dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
|
||||
) -> Optional[DatasetType]:
|
||||
"""
|
||||
Load pre-tokenized dataset.
|
||||
Each instance of dataset is a dictionary with
|
||||
`{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
|
||||
"""
|
||||
mode_map = {"train": "train", "dev": "validation", "test": "test"}
|
||||
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
|
||||
assert mode in tuple(
|
||||
mode_map
|
||||
), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
|
||||
|
||||
if isinstance(dataset_paths, (str, os.PathLike)):
|
||||
dataset_paths = [dataset_paths]
|
||||
|
|
@ -153,7 +203,9 @@ def load_datasets(dataset_paths: Union[PathType, List[PathType]], mode: str = "t
|
|||
for ds_path in dataset_paths:
|
||||
ds_path = os.path.abspath(ds_path)
|
||||
assert os.path.exists(ds_path), f"Not existed file path {ds_path}"
|
||||
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False).with_format("torch")
|
||||
ds_dict = load_from_disk(
|
||||
dataset_path=ds_path, keep_in_memory=False
|
||||
).with_format("torch")
|
||||
if isinstance(ds_dict, HFDataset):
|
||||
datasets.append(ds_dict)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -210,7 +210,8 @@ class PatchEmbedder(nn.Module):
|
|||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# [B, S, C, P, P] -> [B, S, C*P*P]
|
||||
x = x.view(*x.shape[:2], -1)
|
||||
# FIXME: hack diffusion and use view
|
||||
x = x.reshape(*x.shape[:2], -1)
|
||||
out = F.linear(
|
||||
x, self.proj.weight.view(self.proj.weight.shape[0], -1), self.proj.bias
|
||||
)
|
||||
|
|
@ -327,7 +328,7 @@ class DiT(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
patch_size=2,
|
||||
in_channels=1,
|
||||
in_channels=256,
|
||||
text_embed_dim=512,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
|
|
@ -400,7 +401,6 @@ class DiT(nn.Module):
|
|||
t,
|
||||
text_latent_states=None,
|
||||
attention_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
video_latent_states: [B, C, S, P, P]
|
||||
|
|
|
|||
|
|
@ -1,33 +1,16 @@
|
|||
import argparse
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torchvision.io import read_video
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPTextModel
|
||||
from transformers import AutoTokenizer, CLIPTextModel
|
||||
|
||||
EMPTY_SAMPLE = {"video_file": [], "video_latent_states": [], "text_latent_states": []}
|
||||
|
||||
|
||||
def preprocess_video(video, device="cuda"):
    """Prepare a decoded video tensor for the VQ-VAE encoder.

    The tensor is permuted from ``[T, H, W, C]`` to ``[C, T, H, W]``, cast to
    float on ``device``, mapped through ``x / 255 - 0.5`` and given a leading
    batch dimension of size 1.

    Args:
        video: 4-D tensor laid out as ``[T, H, W, C]``.
        device: Device the float copy is placed on. Defaults to ``"cuda"``
            (the value that used to be hard-coded), so existing callers
            behave as before; pass ``"cpu"`` to run without a GPU.

    Returns:
        Float tensor of shape ``[1, C, T, H, W]`` on ``device``.
    """
    # [T, H, W, C] to [C, T, H, W]
    video = video.permute(3, 0, 1, 2)
    video = video.to(dtype=torch.float, device=device)
    # normalize
    video = video / 255 - 0.5
    return video.unsqueeze(0)
|
||||
|
||||
|
||||
def process_video(video_path, vqvae):
    """Decode a video file and encode it with ``vqvae``.

    Args:
        video_path: Path passed to ``read_video``.
        vqvae: Model exposing ``encode(video)``.

    Returns:
        The encoder output with its leading batch dimension squeezed out,
        converted to nested Python lists.

    Raises:
        ValueError: If the video has more than 600 frames.
    """
    frames = read_video(video_path, pts_unit="sec")[0]
    # Reject over-long clips before the float cast and GPU transfer.
    # preprocess_video moves the frame axis from dim 0 to dim 2, so this
    # checks the same quantity the old post-preprocessing test did.
    if frames.size(0) > 600:
        raise ValueError("Video is too long")
    video = preprocess_video(frames)
    latent_states = vqvae.encode(video)
    return latent_states.squeeze(0).tolist()
|
||||
|
||||
|
||||
def process_text(text, tokenizer, text_model):
|
||||
inputs = tokenizer(text, padding=True, return_tensors="pt")
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
|
|
@ -35,30 +18,27 @@ def process_text(text, tokenizer, text_model):
|
|||
output_states = []
|
||||
for i, x in enumerate(outputs.last_hidden_state):
|
||||
valid_x = x[inputs["attention_mask"][i].bool()]
|
||||
output_states.append(valid_x.tolist())
|
||||
output_states.append(valid_x.cpu())
|
||||
return output_states
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def process_item(item, video_dir, tokenizer, text_model, vqvae):
|
||||
def process_item(item, video_dir, tokenizer, text_model):
|
||||
video_path = os.path.join(video_dir, item["file"])
|
||||
try:
|
||||
video_latent_states = process_video(video_path, vqvae)
|
||||
except ValueError:
|
||||
video = read_video(video_path, pts_unit="sec")[0]
|
||||
if video.size(0) > 600:
|
||||
return EMPTY_SAMPLE
|
||||
torch.cuda.empty_cache()
|
||||
text_latent_states = process_text(item["captions"], tokenizer, text_model)
|
||||
torch.cuda.empty_cache()
|
||||
return {
|
||||
"video_file": [item["file"]] * len(text_latent_states),
|
||||
"video_latent_states": [video_latent_states] * len(text_latent_states),
|
||||
"text_latent_states": text_latent_states,
|
||||
}
|
||||
|
||||
|
||||
def process_batch(batch, video_dir, tokenizer, text_model, vqvae):
|
||||
def process_batch(batch, video_dir, tokenizer, text_model):
|
||||
item = {"file": batch["file"][0], "captions": batch["captions"][0]}
|
||||
return process_item(item, video_dir, tokenizer, text_model, vqvae)
|
||||
return process_item(item, video_dir, tokenizer, text_model)
|
||||
|
||||
|
||||
def process_dataset(
|
||||
|
|
@ -67,11 +47,9 @@ def process_dataset(
|
|||
output_dir,
|
||||
num_spliced_dataset_bins=10,
|
||||
text_model="openai/clip-vit-base-patch32",
|
||||
vae_model="hpcai-tech/vqvae",
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_model)
|
||||
text_model = CLIPTextModel.from_pretrained(text_model).cuda().eval()
|
||||
vqvae = AutoModel.from_pretrained(vae_model, trust_remote_code=True).cuda().eval()
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
|
@ -86,12 +64,24 @@ def process_dataset(
|
|||
end = 100
|
||||
train_splits.append(f"train[{start}%:{end}%]")
|
||||
|
||||
ds = load_dataset("json", data_files=captions_file, keep_in_memory=False, split=train_splits)
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
data_files=captions_file,
|
||||
keep_in_memory=False,
|
||||
split=train_splits,
|
||||
num_proc=1,
|
||||
)
|
||||
|
||||
for i, part_ds in enumerate(ds):
|
||||
print(f"Processing part {i+1}/{len(ds)}")
|
||||
part_ds = part_ds.with_format("torch")
|
||||
part_ds = part_ds.map(
|
||||
process_batch,
|
||||
fn_kwargs={"video_dir": video_dir, "tokenizer": tokenizer, "text_model": text_model, "vqvae": vqvae},
|
||||
fn_kwargs={
|
||||
"video_dir": video_dir,
|
||||
"tokenizer": tokenizer,
|
||||
"text_model": text_model,
|
||||
},
|
||||
batched=True,
|
||||
batch_size=1,
|
||||
keep_in_memory=False,
|
||||
|
|
@ -104,15 +94,30 @@ def process_dataset(
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Preprocess data")
|
||||
parser.add_argument(
|
||||
"-c", "--captions-file", type=str, help="Path to the captions file. It should be a JSON file or a JSONL file"
|
||||
"-c",
|
||||
"--captions-file",
|
||||
type=str,
|
||||
help="Path to the captions file. It should be a JSON file or a JSONL file",
|
||||
)
|
||||
parser.add_argument("-v", "--video-dir", type=str, help="Path to the video directory")
|
||||
parser.add_argument("-o", "--output_dir", type=str, help="Path to the output directory")
|
||||
parser.add_argument(
|
||||
"-n", "--num_spliced_dataset_bins", type=int, default=10, help="Number of bins for spliced dataset"
|
||||
"-v", "--video-dir", type=str, help="Path to the video directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, help="Path to the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--num_spliced_dataset_bins",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of bins for spliced dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_model",
|
||||
type=str,
|
||||
default="openai/clip-vit-base-patch32",
|
||||
help="CLIP text model",
|
||||
)
|
||||
parser.add_argument("--text_model", type=str, default="openai/clip-vit-base-patch32", help="CLIP text model")
|
||||
parser.add_argument("--vae_model", type=str, default="hpcai-tech/vqvae", help="VQ-VAE model")
|
||||
args = parser.parse_args()
|
||||
process_dataset(
|
||||
args.captions_file,
|
||||
|
|
@ -120,5 +125,4 @@ if __name__ == "__main__":
|
|||
args.output_dir,
|
||||
args.num_spliced_dataset_bins,
|
||||
args.text_model,
|
||||
args.vae_model,
|
||||
)
|
||||
|
|
|
|||
28
train.py
28
train.py
|
|
@ -26,8 +26,9 @@ from colossalai.cluster import DistCoordinator
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModel
|
||||
|
||||
from data_utils import load_datasets, make_batch
|
||||
from data_utils import load_datasets, make_batch, preprocess_batch
|
||||
from diffusion import create_diffusion
|
||||
from models import DiT_models
|
||||
|
||||
|
|
@ -54,19 +55,6 @@ def requires_grad(model, flag=True):
|
|||
p.requires_grad = flag
|
||||
|
||||
|
||||
def collate_fn(batch, patch_size=2):
    """Collate dataset samples into one batch for the diffusion model.

    Each sample's ``"video_latent_states"`` tensor is rewritten in place
    (channel axis inserted, cast to float, scaled by 0.18215) before the
    samples are handed to ``make_batch``.

    Args:
        batch: List of sample dicts, each with a ``"video_latent_states"``
            tensor; presumably ``[T, H, W]`` latent indices, per the inline
            shape note.
        patch_size: Patch size forwarded to ``make_batch``.

    Returns:
        The dict produced by ``make_batch`` with dims 1 and 2 of
        ``"video_latent_states"`` swapped.
    """
    for sample in batch:
        # [T, H, W] -> [T, C, H, W], then cast to float and apply the scale
        sample["video_latent_states"] = (
            sample["video_latent_states"].unsqueeze(1).float() * 0.18215
        )
    collated = make_batch(batch, patch_size)
    # hack diffuser, [B, S, C, P, P] -> [B, C, S, P, P]
    collated["video_latent_states"] = collated["video_latent_states"].transpose(1, 2)
    return collated
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
|
|
@ -92,7 +80,13 @@ def main(args):
|
|||
os.makedirs(args.checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Setup model
|
||||
vqvae = (
|
||||
AutoModel.from_pretrained(args.vqvae, trust_remote_code=True)
|
||||
.to(get_current_device())
|
||||
.eval()
|
||||
)
|
||||
model = DiT_models[args.model]().to(get_current_device())
|
||||
patch_size = model.patch_size
|
||||
ema = deepcopy(model)
|
||||
requires_grad(ema, False)
|
||||
model.train() # important! This enables embedding dropout for classifier-free guidance
|
||||
|
|
@ -112,7 +106,7 @@ def main(args):
|
|||
dataloader = plugin.prepare_dataloader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=partial(collate_fn, patch_size=model.patch_size),
|
||||
collate_fn=partial(make_batch, video_dir=args.video_dir),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
|
@ -133,7 +127,7 @@ def main(args):
|
|||
total=len(dataloader),
|
||||
) as pbar:
|
||||
for step, batch in enumerate(dataloader):
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items()}
|
||||
batch = preprocess_batch(batch, patch_size, vqvae)
|
||||
video_inputs = batch.pop("video_latent_states")
|
||||
mask = batch.pop("video_padding_mask")
|
||||
t = torch.randint(
|
||||
|
|
@ -192,10 +186,12 @@ if __name__ == "__main__":
|
|||
"-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
|
||||
)
|
||||
parser.add_argument("-d", "--dataset", nargs="+", default=[])
|
||||
parser.add_argument("-v", "--video_dir", type=str, required=True)
|
||||
parser.add_argument("-e", "--epochs", type=int, default=10)
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=4)
|
||||
parser.add_argument("-g", "--grad_checkpoint", action="store_true", default=False)
|
||||
parser.add_argument("--save_interval", type=int, default=20)
|
||||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
|
||||
parser.add_argument("--vqvae", default="hpcai-tech/vqvae")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
|
|
|||
Loading…
Reference in a new issue