[hotfix] fix vqvae output process (#9)

This commit is contained in:
Hongxin Liu 2024-02-23 13:11:24 +08:00 committed by GitHub
parent da9b00e808
commit 6f887f453b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 122 additions and 70 deletions

View file

@ -2,10 +2,13 @@ import os
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.utils import get_current_device
from datasets import Dataset as HFDataset
from datasets import dataset_dict, load_from_disk
from torch.utils.data import ConcatDataset, Dataset
from torchvision.io import read_video
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
@ -34,7 +37,9 @@ def video2col(video_4d: torch.Tensor, patch_size: int) -> torch.Tensor:
return torch.stack(out, dim=1).view(-1, c, patch_size, patch_size)
def col2video(patches: torch.Tensor, video_shape: Tuple[int, int, int, int]) -> torch.Tensor:
def col2video(
patches: torch.Tensor, video_shape: Tuple[int, int, int, int]
) -> torch.Tensor:
"""
Convert a 2D tensor of patches to a 4D video tensor.
@ -71,7 +76,10 @@ def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Te
"""
max_len = max([sequence.shape[0] for sequence in sequences])
padded_sequences = [
F.pad(sequence, [0] * (sequence.ndim - 1) * 2 + [0, max_len - sequence.shape[0]]) for sequence in sequences
F.pad(
sequence, [0] * (sequence.ndim - 1) * 2 + [0, max_len - sequence.shape[0]]
)
for sequence in sequences
]
padded_sequences = torch.stack(padded_sequences, dim=0)
padding_mask = torch.zeros(
@ -85,7 +93,9 @@ def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Te
return padded_sequences, padding_mask
def patchify_batch(videos: List[torch.Tensor], patch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
def patchify_batch(
videos: List[torch.Tensor], patch_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Patchify a batch of videos.
Args:
@ -115,7 +125,7 @@ def expand_mask_4d(q_mask: torch.Tensor, kv_mask: torch.Tensor) -> torch.Tensor:
return mask.unsqueeze(1)
def make_batch(samples: List[dict], patch_size: int) -> dict:
def make_batch(samples: List[dict], video_dir: str) -> dict:
"""Make a batch of samples.
Args:
@ -124,27 +134,67 @@ def make_batch(samples: List[dict], patch_size: int) -> dict:
Returns:
dict: A batch of samples.
"""
videos = [sample["video_latent_states"] for sample in samples]
videos, video_padding_mask = patchify_batch(videos, patch_size)
videos = [
read_video(os.path.join(video_dir, sample["video_file"]), pts_unit="sec")[0]
for sample in samples
]
texts = [sample["text_latent_states"] for sample in samples]
texts, text_padding_mask = pad_sequences(texts)
return {
"video_latent_states": videos,
"video_padding_mask": video_padding_mask,
"videos": videos,
"text_latent_states": texts,
"text_padding_mask": text_padding_mask,
"attention_mask": expand_mask_4d(video_padding_mask, text_padding_mask),
}
def load_datasets(dataset_paths: Union[PathType, List[PathType]], mode: str = "train") -> Optional[DatasetType]:
def normalize_video(video: torch.Tensor) -> torch.Tensor:
return video.float() / 255 - 0.5
def unnormalize_video(video: torch.Tensor) -> torch.Tensor:
return (video + 0.5) * 255
@torch.no_grad()
def preprocess_batch(
batch: dict, patch_size: int, vqvae: nn.Module, device=None
) -> dict:
if device is None:
device = get_current_device()
videos = []
for video in batch.pop("videos"):
video = video.to(device)
video = normalize_video(video)
# [T, H, W, C] -> [B, C, T, H, W]
video = video.permute(3, 0, 1, 2)
video = video.unsqueeze(0)
latent_indices, embeddings = vqvae.encode(video, include_embeddings=True)
# [B, C, T, H, W] -> [T, C, H, W]
embeddings = embeddings.squeeze(0).permute(1, 0, 2, 3)
videos.append(embeddings)
video_latent_states, video_padding_mask = patchify_batch(videos, patch_size)
# hack diffuser, [B, S, C, P, P] -> [B, C, S, P, P]
video_latent_states = video_latent_states.transpose(1, 2)
batch["video_latent_states"] = video_latent_states
batch["video_padding_mask"] = video_padding_mask
text_padding_mask = batch.pop("text_padding_mask").to(device)
batch["attention_mask"] = expand_mask_4d(video_padding_mask, text_padding_mask)
batch["text_latent_states"] = batch["text_latent_states"].to(device)
return batch
def load_datasets(
dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
) -> Optional[DatasetType]:
"""
Load pre-tokenized dataset.
Each instance of dataset is a dictionary with
`{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
"""
mode_map = {"train": "train", "dev": "validation", "test": "test"}
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
assert mode in tuple(
mode_map
), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
if isinstance(dataset_paths, (str, os.PathLike)):
dataset_paths = [dataset_paths]
@ -153,7 +203,9 @@ def load_datasets(dataset_paths: Union[PathType, List[PathType]], mode: str = "t
for ds_path in dataset_paths:
ds_path = os.path.abspath(ds_path)
assert os.path.exists(ds_path), f"Not existed file path {ds_path}"
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False).with_format("torch")
ds_dict = load_from_disk(
dataset_path=ds_path, keep_in_memory=False
).with_format("torch")
if isinstance(ds_dict, HFDataset):
datasets.append(ds_dict)
else:

View file

@ -210,7 +210,8 @@ class PatchEmbedder(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# [B, S, C, P, P] -> [B, S, C*P*P]
x = x.view(*x.shape[:2], -1)
# FIXME: hack diffusion and use view
x = x.reshape(*x.shape[:2], -1)
out = F.linear(
x, self.proj.weight.view(self.proj.weight.shape[0], -1), self.proj.bias
)
@ -327,7 +328,7 @@ class DiT(nn.Module):
def __init__(
self,
patch_size=2,
in_channels=1,
in_channels=256,
text_embed_dim=512,
hidden_size=1152,
depth=28,
@ -400,7 +401,6 @@ class DiT(nn.Module):
t,
text_latent_states=None,
attention_mask=None,
**kwargs,
):
"""
video_latent_states: [B, C, S, P, P]

View file

@ -1,33 +1,16 @@
import argparse
import io
import math
import os
import torch
from datasets import load_dataset
from torchvision.io import read_video
from transformers import AutoModel, AutoTokenizer, CLIPTextModel
from transformers import AutoTokenizer, CLIPTextModel
EMPTY_SAMPLE = {"video_file": [], "video_latent_states": [], "text_latent_states": []}
def preprocess_video(video):
# [T, H, W, C] to [C, T, H, W]
video = video.permute(3, 0, 1, 2)
video = video.to(dtype=torch.float, device="cuda")
# normalize
video = video / 255 - 0.5
return video.unsqueeze(0)
def process_video(video_path, vqvae):
video = read_video(video_path, pts_unit="sec")[0]
video = preprocess_video(video)
if video.size(2) > 600:
raise ValueError("Video is too long")
latent_states = vqvae.encode(video)
return latent_states.squeeze(0).tolist()
def process_text(text, tokenizer, text_model):
inputs = tokenizer(text, padding=True, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
@ -35,30 +18,27 @@ def process_text(text, tokenizer, text_model):
output_states = []
for i, x in enumerate(outputs.last_hidden_state):
valid_x = x[inputs["attention_mask"][i].bool()]
output_states.append(valid_x.tolist())
output_states.append(valid_x.cpu())
return output_states
@torch.no_grad()
def process_item(item, video_dir, tokenizer, text_model, vqvae):
def process_item(item, video_dir, tokenizer, text_model):
video_path = os.path.join(video_dir, item["file"])
try:
video_latent_states = process_video(video_path, vqvae)
except ValueError:
video = read_video(video_path, pts_unit="sec")[0]
if video.size(0) > 600:
return EMPTY_SAMPLE
torch.cuda.empty_cache()
text_latent_states = process_text(item["captions"], tokenizer, text_model)
torch.cuda.empty_cache()
return {
"video_file": [item["file"]] * len(text_latent_states),
"video_latent_states": [video_latent_states] * len(text_latent_states),
"text_latent_states": text_latent_states,
}
def process_batch(batch, video_dir, tokenizer, text_model, vqvae):
def process_batch(batch, video_dir, tokenizer, text_model):
item = {"file": batch["file"][0], "captions": batch["captions"][0]}
return process_item(item, video_dir, tokenizer, text_model, vqvae)
return process_item(item, video_dir, tokenizer, text_model)
def process_dataset(
@ -67,11 +47,9 @@ def process_dataset(
output_dir,
num_spliced_dataset_bins=10,
text_model="openai/clip-vit-base-patch32",
vae_model="hpcai-tech/vqvae",
):
tokenizer = AutoTokenizer.from_pretrained(text_model)
text_model = CLIPTextModel.from_pretrained(text_model).cuda().eval()
vqvae = AutoModel.from_pretrained(vae_model, trust_remote_code=True).cuda().eval()
if not os.path.exists(output_dir):
os.makedirs(output_dir)
@ -86,12 +64,24 @@ def process_dataset(
end = 100
train_splits.append(f"train[{start}%:{end}%]")
ds = load_dataset("json", data_files=captions_file, keep_in_memory=False, split=train_splits)
ds = load_dataset(
"json",
data_files=captions_file,
keep_in_memory=False,
split=train_splits,
num_proc=1,
)
for i, part_ds in enumerate(ds):
print(f"Processing part {i+1}/{len(ds)}")
part_ds = part_ds.with_format("torch")
part_ds = part_ds.map(
process_batch,
fn_kwargs={"video_dir": video_dir, "tokenizer": tokenizer, "text_model": text_model, "vqvae": vqvae},
fn_kwargs={
"video_dir": video_dir,
"tokenizer": tokenizer,
"text_model": text_model,
},
batched=True,
batch_size=1,
keep_in_memory=False,
@ -104,15 +94,30 @@ def process_dataset(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Preprocess data")
parser.add_argument(
"-c", "--captions-file", type=str, help="Path to the captions file. It should be a JSON file or a JSONL file"
"-c",
"--captions-file",
type=str,
help="Path to the captions file. It should be a JSON file or a JSONL file",
)
parser.add_argument("-v", "--video-dir", type=str, help="Path to the video directory")
parser.add_argument("-o", "--output_dir", type=str, help="Path to the output directory")
parser.add_argument(
"-n", "--num_spliced_dataset_bins", type=int, default=10, help="Number of bins for spliced dataset"
"-v", "--video-dir", type=str, help="Path to the video directory"
)
parser.add_argument(
"-o", "--output_dir", type=str, help="Path to the output directory"
)
parser.add_argument(
"-n",
"--num_spliced_dataset_bins",
type=int,
default=10,
help="Number of bins for spliced dataset",
)
parser.add_argument(
"--text_model",
type=str,
default="openai/clip-vit-base-patch32",
help="CLIP text model",
)
parser.add_argument("--text_model", type=str, default="openai/clip-vit-base-patch32", help="CLIP text model")
parser.add_argument("--vae_model", type=str, default="hpcai-tech/vqvae", help="VQ-VAE model")
args = parser.parse_args()
process_dataset(
args.captions_file,
@ -120,5 +125,4 @@ if __name__ == "__main__":
args.output_dir,
args.num_spliced_dataset_bins,
args.text_model,
args.vae_model,
)

View file

@ -26,8 +26,9 @@ from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from tqdm import tqdm
from transformers import AutoModel
from data_utils import load_datasets, make_batch
from data_utils import load_datasets, make_batch, preprocess_batch
from diffusion import create_diffusion
from models import DiT_models
@ -54,19 +55,6 @@ def requires_grad(model, flag=True):
p.requires_grad = flag
def collate_fn(batch, patch_size=2):
for sample in batch:
video = sample["video_latent_states"]
# [T, H, W] -> [T, C, H, W]
video = video.unsqueeze(1)
video = video.float() * 0.18215
sample["video_latent_states"] = video
batch = make_batch(batch, patch_size)
# hack diffuser, [B, S, C, P, P] -> [B, C, S, P, P]
batch["video_latent_states"] = batch["video_latent_states"].transpose(1, 2)
return batch
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
@ -92,7 +80,13 @@ def main(args):
os.makedirs(args.checkpoint_dir, exist_ok=True)
# Setup model
vqvae = (
AutoModel.from_pretrained(args.vqvae, trust_remote_code=True)
.to(get_current_device())
.eval()
)
model = DiT_models[args.model]().to(get_current_device())
patch_size = model.patch_size
ema = deepcopy(model)
requires_grad(ema, False)
model.train() # important! This enables embedding dropout for classifier-free guidance
@ -112,7 +106,7 @@ def main(args):
dataloader = plugin.prepare_dataloader(
dataset,
batch_size=args.batch_size,
collate_fn=partial(collate_fn, patch_size=model.patch_size),
collate_fn=partial(make_batch, video_dir=args.video_dir),
shuffle=True,
drop_last=True,
)
@ -133,7 +127,7 @@ def main(args):
total=len(dataloader),
) as pbar:
for step, batch in enumerate(dataloader):
batch = {k: v.to(get_current_device()) for k, v in batch.items()}
batch = preprocess_batch(batch, patch_size, vqvae)
video_inputs = batch.pop("video_latent_states")
mask = batch.pop("video_padding_mask")
t = torch.randint(
@ -192,10 +186,12 @@ if __name__ == "__main__":
"-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
)
parser.add_argument("-d", "--dataset", nargs="+", default=[])
parser.add_argument("-v", "--video_dir", type=str, required=True)
parser.add_argument("-e", "--epochs", type=int, default=10)
parser.add_argument("-b", "--batch_size", type=int, default=4)
parser.add_argument("-g", "--grad_checkpoint", action="store_true", default=False)
parser.add_argument("--save_interval", type=int, default=20)
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
parser.add_argument("--vqvae", default="hpcai-tech/vqvae")
args = parser.parse_args()
main(args)