From a50bec768f8a8d9fc447bd25d70c91bafc403eaa Mon Sep 17 00:00:00 2001
From: zhengzangw
Date: Tue, 21 May 2024 07:20:14 +0000
Subject: [PATCH] [feat] complete training with pre-extracted features

---
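Notes: this patch completes the feature-based training path end to end.
scripts/misc/extract_feat.py now saves the caption text, fps, height,
width and num_frames alongside the VAE latents; BatchFeatureDataset
remaps the buffered keys to the batch layout the training loop expects;
the batch-feature branch of prepare_dataloader returns the sampler
together with the DataLoader so the caller can checkpoint it; and
BatchDistributedSampler gains start_index bookkeeping
(reset/state_dict/load_state_dict) so feature training can resume
mid-epoch. On the model side, STDiT3 accepts a tensor mask as y_lens
when skip_y_embedder is set, and train.py falls back to the new config
keys (text_encoder_output_dim, text_encoder_model_max_length,
vae_out_channels) when the text encoder or VAE is not built.

A minimal sketch of the resumable-sampler API, assuming the patched
opensora package is importable; _StubFeatureDataset is illustrative and
only stands in for BatchFeatureDataset, exposing the attributes the
sampler actually reads:

    from opensora.datasets.sampler import BatchDistributedSampler

    class _StubFeatureDataset:
        num_buffers = 6  # constants from the sampler docstring example
        len_buffer = 5

        def __len__(self):
            return self.num_buffers * self.len_buffer

    # num_replicas/rank are given explicitly, so no process group is needed
    sampler = BatchDistributedSampler(_StubFeatureDataset(), num_replicas=3, rank=0)
    print(list(sampler))                # rank 0 reads [0, 1, ..., 9]

    state = sampler.state_dict(step=3)  # persist after finishing step 3
    sampler.load_state_dict(state)      # start_index becomes 3 + 1
    print(list(sampler))                # resumes at [4, 5, ..., 9]
    sampler.reset()                     # full shard again next epoch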
 configs/opensora-v1-2/train/stage1_feat.py |  9 ++++++++-
 opensora/datasets/dataloader.py            | 21 ++++++++++---------
 opensora/datasets/datasets.py              | 14 ++++++++++++--
 opensora/datasets/sampler.py               | 15 ++++++++++++++-
 opensora/models/stdit/stdit3.py            |  2 ++
 scripts/misc/extract_feat.py               | 11 +++++++++--
 scripts/train.py                           | 25 ++++++++++++++++++-------
 7 files changed, 75 insertions(+), 22 deletions(-)

diff --git a/configs/opensora-v1-2/train/stage1_feat.py b/configs/opensora-v1-2/train/stage1_feat.py
index ce96d30..d82d4f6 100644
--- a/configs/opensora-v1-2/train/stage1_feat.py
+++ b/configs/opensora-v1-2/train/stage1_feat.py
@@ -15,6 +15,7 @@ model = dict(
     enable_flash_attn=True,
     enable_layernorm_kernel=True,
     freeze_y_embedder=True,
+    skip_y_embedder=True,
 )
 scheduler = dict(
     type="rflow",
@@ -22,6 +23,12 @@ scheduler = dict(
     sample_method="logit-normal",
 )
 
+vae_out_channels = 4
+text_encoder_model_max_length = 300
+text_encoder_output_dim = 1152
+load_video_features = True
+load_text_features = True
+
 # Mask settings
 mask_ratios = {
     "random": 0.2,
@@ -42,7 +49,7 @@ outputs = "outputs"
 wandb = False
 epochs = 1000
 log_every = 10
-ckpt_every = 1
+ckpt_every = 500
 
 # optimization settings
 load = None
diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py
index 8516951..8bcaed9 100644
--- a/opensora/datasets/dataloader.py
+++ b/opensora/datasets/dataloader.py
@@ -89,15 +89,18 @@ def prepare_dataloader(
             num_replicas=process_group.size(),
             rank=process_group.rank(),
         )
-        return DataLoader(
-            dataset,
-            batch_size=1,
-            sampler=sampler,
-            worker_init_fn=get_seed_worker(seed),
-            pin_memory=pin_memory,
-            num_workers=num_workers,
-            collate_fn=collate_fn_batch,
-            **_kwargs,
+        return (
+            DataLoader(
+                dataset,
+                batch_size=1,
+                sampler=sampler,
+                worker_init_fn=get_seed_worker(seed),
+                pin_memory=pin_memory,
+                num_workers=num_workers,
+                collate_fn=collate_fn_batch,
+                **_kwargs,
+            ),
+            sampler,
         )
     else:
         raise ValueError(f"Unsupported dataset type: {type(dataset)}")
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index 7050c98..6f98ce3 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -230,4 +230,14 @@ class BatchFeatureDataset(torch.utils.data.Dataset):
         self._load_buffer(idx)
 
-        batch = self.cur_buffer[idx % self.len_buffer]  # dict; keys are {'x', 'fps'} and text related
-        return batch
+        batch = self.cur_buffer[idx % self.len_buffer]  # keys from extract_feat.py: x/text/fps/height/width/num_frames (+ y/mask text features)
+
+        ret = {
+            "video": batch["x"],
+            "text": batch["y"],
+            "mask": batch["mask"],
+            "fps": batch["fps"],
+            "height": batch["height"],
+            "width": batch["width"],
+            "num_frames": batch["num_frames"],
+        }
+        return ret
diff --git a/opensora/datasets/sampler.py b/opensora/datasets/sampler.py
index 64c36e2..511579f 100644
--- a/opensora/datasets/sampler.py
+++ b/opensora/datasets/sampler.py
@@ -288,20 +288,33 @@ class VariableVideoBatchSampler(DistributedSampler):
 
 class BatchDistributedSampler(DistributedSampler):
     """
     Used with BatchFeatureDataset;
     Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
            | buffer {i}          | buffer {i+1}
     ------ | ------------------- | -------------------
     rank 0 |  0,  1,  2,  3,  4, |  5,  6,  7,  8,  9
     rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
     rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
     """
 
+    def __init__(self, dataset: Dataset, **kwargs):
+        super().__init__(dataset, **kwargs)
+        self.start_index = 0
+
     def __iter__(self):
         num_buffers = self.dataset.num_buffers
         len_buffer = self.dataset.len_buffer
         num_buffers_i = num_buffers // self.num_replicas
         num_samples_i = len_buffer * num_buffers_i
-        indices_i = np.arange(num_samples_i) + self.rank * num_samples_i
+        indices_i = np.arange(self.start_index, num_samples_i) + self.rank * num_samples_i
         indices_i = indices_i.tolist()
         return iter(indices_i)
+
+    def reset(self):
+        self.start_index = 0
+
+    def state_dict(self, step) -> dict:
+        return {"start_index": step}
+
+    def load_state_dict(self, state_dict: dict):
+        self.start_index = state_dict["start_index"] + 1
diff --git a/opensora/models/stdit/stdit3.py b/opensora/models/stdit/stdit3.py
index 68bb60e..a524932 100644
--- a/opensora/models/stdit/stdit3.py
+++ b/opensora/models/stdit/stdit3.py
@@ -371,6 +371,8 @@ class STDiT3(PreTrainedModel):
         # === get y embed ===
         if self.config.skip_y_embedder:
             y_lens = mask
+            if isinstance(y_lens, torch.Tensor):
+                y_lens = y_lens.long().tolist()
         else:
             y, y_lens = self.encode_text(y, mask)
 
diff --git a/scripts/misc/extract_feat.py b/scripts/misc/extract_feat.py
index 47fee72..bddfa77 100644
--- a/scripts/misc/extract_feat.py
+++ b/scripts/misc/extract_feat.py
@@ -144,8 +144,15 @@ def main():
                 x = vae.encode(x)
             with Timer("feature to cpu", log=log_time):
                 x = x.cpu()
-            fps = batch["fps"].to(dtype)
-            batch_dict = {"x": x, "fps": fps}
+
+            batch_dict = {
+                "x": x,
+                "text": y,
+                "fps": batch["fps"].to(dtype),
+                "height": batch["height"].to(dtype),
+                "width": batch["width"].to(dtype),
+                "num_frames": batch["num_frames"].to(dtype),
+            }
 
             if save_text_features:
                 with Timer("text", log=log_time):
diff --git a/scripts/train.py b/scripts/train.py
index 1818782..b1862ee 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -109,22 +109,34 @@
     logger.info("Building models...")
     # == build text-encoder and vae ==
     text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
-    vae = build_module(cfg.get("vae", None), MODELS).to(device, dtype).eval()
+    if text_encoder is not None:
+        text_encoder_output_dim = text_encoder.output_dim
+        text_encoder_model_max_length = text_encoder.model_max_length
+    else:
+        text_encoder_output_dim = cfg.get("text_encoder_output_dim", 1152)
+        text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)
 
-    # == build diffusion model ==
-    input_size = (dataset.num_frames, *dataset.image_size)
+    # == build vae ==
+    vae = build_module(cfg.get("vae", None), MODELS)
     if vae is not None:
+        vae = vae.to(device, dtype).eval()
+    if vae is not None:
+        input_size = (dataset.num_frames, *dataset.image_size)
         latent_size = vae.get_latent_size(input_size)
+        vae_out_channels = vae.out_channels
     else:
         latent_size = (None, None, None)
+        vae_out_channels = cfg.get("vae_out_channels", 4)
+
+    # == build diffusion model ==
     model = (
         build_module(
             cfg.model,
             MODELS,
             input_size=latent_size,
-            in_channels=vae.out_channels,
-            caption_channels=text_encoder.output_dim,
-            model_max_length=text_encoder.model_max_length,
+            in_channels=vae_out_channels,
+            caption_channels=text_encoder_output_dim,
+            model_max_length=text_encoder_model_max_length,
         )
         .to(device, dtype)
         .train()
@@ -223,7 +235,6 @@ def main():
         for step, batch in pbar:
             x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
             y = batch.pop("text")
-            breakpoint()
 
             # == visual and text encoding ==
             with torch.no_grad():
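-- 
Aside: the docstring table in sampler.py falls out of the index
arithmetic in __iter__. A standalone NumPy sketch using the same
constants as the table (len_buffer == 5, num_buffers == 6, 3 ranks):

    import numpy as np

    len_buffer, num_buffers, num_replicas = 5, 6, 3
    num_buffers_i = num_buffers // num_replicas  # 2 whole buffers per rank
    num_samples_i = len_buffer * num_buffers_i   # 10 samples per rank

    for rank in range(num_replicas):
        start_index = 0  # > 0 only when resuming via load_state_dict
        indices_i = np.arange(start_index, num_samples_i) + rank * num_samples_i
        print(rank, indices_i.tolist())
    # rank 0 -> [0 .. 9], rank 1 -> [10 .. 19], rank 2 -> [20 .. 29]:
    # each rank owns contiguous whole buffers, matching the table above.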