[feat] complete training with feature

This commit is contained in:
zhengzangw 2024-05-21 07:20:14 +00:00
parent 066b0c9bb3
commit a50bec768f
7 changed files with 75 additions and 22 deletions

View file

@@ -15,6 +15,7 @@ model = dict(
enable_flash_attn=True,
enable_layernorm_kernel=True,
freeze_y_embedder=True,
skip_y_embedder=True,
)
scheduler = dict(
type="rflow",
@@ -22,6 +23,12 @@ scheduler = dict(
sample_method="logit-normal",
)
vae_out_channels = 4
model_max_length = 300
text_encoder_output_dim = 1152
load_video_features = True
load_text_features = True
# Mask settings
mask_ratios = {
"random": 0.2,
@@ -42,7 +49,7 @@ outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 1
ckpt_every = 500
# optimization settings
load = None

View file

@@ -89,15 +89,18 @@ def prepare_dataloader(
num_replicas=process_group.size(),
rank=process_group.rank(),
)
return DataLoader(
dataset,
batch_size=1,
sampler=sampler,
worker_init_fn=get_seed_worker(seed),
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_batch,
**_kwargs,
return (
DataLoader(
dataset,
batch_size=1,
sampler=sampler,
worker_init_fn=get_seed_worker(seed),
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_batch,
**_kwargs,
),
sampler,
)
else:
raise ValueError(f"Unsupported dataset type: {type(dataset)}")

View file

@@ -230,4 +230,14 @@ class BatchFeatureDataset(torch.utils.data.Dataset):
self._load_buffer(idx)
batch = self.cur_buffer[idx % self.len_buffer] # dict; keys are {'x', 'fps'} and text related
return batch
ret = {
"video": batch["x"],
"text": batch["y"],
"mask": batch["mask"],
"fps": batch["fps"],
"height": batch["height"],
"width": batch["width"],
"num_frames": batch["num_frames"],
}
return ret

View file

@@ -288,7 +288,7 @@ class VariableVideoBatchSampler(DistributedSampler):
class BatchDistributedSampler(DistributedSampler):
"""
Used with BatchFeatureDataset;
Used with BatchDataset;
Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
| buffer {i} | buffer {i+1}
------ | ------------------- | -------------------
@@ -297,13 +297,26 @@ class BatchDistributedSampler(DistributedSampler):
rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
"""
def __init__(self, dataset: Dataset, **kwargs):
super().__init__(dataset, **kwargs)
self.start_index = 0
def __iter__(self):
num_buffers = self.dataset.num_buffers
len_buffer = self.dataset.len_buffer
num_buffers_i = num_buffers // self.num_replicas
num_samples_i = len_buffer * num_buffers_i
indices_i = np.arange(num_samples_i) + self.rank * num_samples_i
indices_i = np.arange(self.start_index, num_samples_i) + self.rank * num_samples_i
indices_i = indices_i.tolist()
return iter(indices_i)
def reset(self):
self.start_index = 0
def state_dict(self, step) -> dict:
return {"start_index": step}
def load_state_dict(self, state_dict: dict):
self.start_index = state_dict["start_index"] + 1

View file

@@ -371,6 +371,8 @@ class STDiT3(PreTrainedModel):
# === get y embed ===
if self.config.skip_y_embedder:
y_lens = mask
if isinstance(y_lens, torch.Tensor):
y_lens = y_lens.long().tolist()
else:
y, y_lens = self.encode_text(y, mask)

View file

@@ -144,8 +144,15 @@ def main():
x = vae.encode(x)
with Timer("feature to cpu", log=log_time):
x = x.cpu()
fps = batch["fps"].to(dtype)
batch_dict = {"x": x, "fps": fps}
batch_dict = {
"x": x,
"text": y,
"fps": batch["fps"].to(dtype),
"height": batch["height"].to(dtype),
"width": batch["width"].to(dtype),
"num_frames": batch["num_frames"].to(dtype),
}
if save_text_features:
with Timer("text", log=log_time):

View file

@@ -109,22 +109,34 @@ def main():
logger.info("Building models...")
# == build text-encoder and vae ==
text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
vae = build_module(cfg.get("vae", None), MODELS).to(device, dtype).eval()
if text_encoder is not None:
text_encoder_output_dim = text_encoder.output_dim
text_encoder_model_max_length = text_encoder.model_max_length
else:
text_encoder_output_dim = cfg.get("text_encoder_output_dim", 1152)
text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)
# == build diffusion model ==
input_size = (dataset.num_frames, *dataset.image_size)
# == build vae ==
vae = build_module(cfg.get("vae", None), MODELS)
if vae is not None:
vae = vae.to(device, dtype).eval()
if vae is not None:
input_size = (dataset.num_frames, *dataset.image_size)
latent_size = vae.get_latent_size(input_size)
vae_out_channels = vae.out_channels
else:
latent_size = (None, None, None)
vae_out_channels = cfg.get("vae_out_channels", 4)
# == build diffusion model ==
model = (
build_module(
cfg.model,
MODELS,
input_size=latent_size,
in_channels=vae.out_channels,
caption_channels=text_encoder.output_dim,
model_max_length=text_encoder.model_max_length,
in_channels=vae_out_channels,
caption_channels=text_encoder_output_dim,
model_max_length=text_encoder_model_max_length,
)
.to(device, dtype)
.train()
@@ -223,7 +235,6 @@ def main():
for step, batch in pbar:
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
y = batch.pop("text")
breakpoint()
# == visual and text encoding ==
with torch.no_grad():