mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
[feat] complete training with feature
This commit is contained in:
parent
066b0c9bb3
commit
a50bec768f
|
|
@ -15,6 +15,7 @@ model = dict(
|
|||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
freeze_y_embedder=True,
|
||||
skip_y_embedder=True,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="rflow",
|
||||
|
|
@ -22,6 +23,12 @@ scheduler = dict(
|
|||
sample_method="logit-normal",
|
||||
)
|
||||
|
||||
vae_out_channels = 4
|
||||
model_max_length = 300
|
||||
text_encoder_output_dim = 1152
|
||||
load_video_features = True
|
||||
load_text_features = True
|
||||
|
||||
# Mask settings
|
||||
mask_ratios = {
|
||||
"random": 0.2,
|
||||
|
|
@ -42,7 +49,7 @@ outputs = "outputs"
|
|||
wandb = False
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 1
|
||||
ckpt_every = 500
|
||||
|
||||
# optimization settings
|
||||
load = None
|
||||
|
|
|
|||
|
|
@ -89,15 +89,18 @@ def prepare_dataloader(
|
|||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
sampler=sampler,
|
||||
worker_init_fn=get_seed_worker(seed),
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_batch,
|
||||
**_kwargs,
|
||||
return (
|
||||
DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
sampler=sampler,
|
||||
worker_init_fn=get_seed_worker(seed),
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_batch,
|
||||
**_kwargs,
|
||||
),
|
||||
sampler,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|
||||
|
|
|
|||
|
|
@ -230,4 +230,14 @@ class BatchFeatureDataset(torch.utils.data.Dataset):
|
|||
self._load_buffer(idx)
|
||||
|
||||
batch = self.cur_buffer[idx % self.len_buffer] # dict; keys are {'x', 'fps'} and text related
|
||||
return batch
|
||||
|
||||
ret = {
|
||||
"video": batch["x"],
|
||||
"text": batch["y"],
|
||||
"mask": batch["mask"],
|
||||
"fps": batch["fps"],
|
||||
"height": batch["height"],
|
||||
"width": batch["width"],
|
||||
"num_frames": batch["num_frames"],
|
||||
}
|
||||
return ret
|
||||
|
|
|
|||
|
|
@ -288,7 +288,7 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
|
||||
class BatchDistributedSampler(DistributedSampler):
|
||||
"""
|
||||
Used with BatchFeatureDataset;
|
||||
Used with BatchDataset;
|
||||
Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
|
||||
| buffer {i} | buffer {i+1}
|
||||
------ | ------------------- | -------------------
|
||||
|
|
@ -297,13 +297,26 @@ class BatchDistributedSampler(DistributedSampler):
|
|||
rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset, **kwargs):
|
||||
super().__init__(dataset, **kwargs)
|
||||
self.start_index = 0
|
||||
|
||||
def __iter__(self):
|
||||
num_buffers = self.dataset.num_buffers
|
||||
len_buffer = self.dataset.len_buffer
|
||||
num_buffers_i = num_buffers // self.num_replicas
|
||||
num_samples_i = len_buffer * num_buffers_i
|
||||
|
||||
indices_i = np.arange(num_samples_i) + self.rank * num_samples_i
|
||||
indices_i = np.arange(self.start_index, num_samples_i) + self.rank * num_samples_i
|
||||
indices_i = indices_i.tolist()
|
||||
|
||||
return iter(indices_i)
|
||||
|
||||
def reset(self):
|
||||
self.start_index = 0
|
||||
|
||||
def state_dict(self, step) -> dict:
|
||||
return {"start_index": step}
|
||||
|
||||
def load_state_dict(self, state_dict: dict):
|
||||
self.start_index = state_dict["start_index"] + 1
|
||||
|
|
|
|||
|
|
@ -371,6 +371,8 @@ class STDiT3(PreTrainedModel):
|
|||
# === get y embed ===
|
||||
if self.config.skip_y_embedder:
|
||||
y_lens = mask
|
||||
if isinstance(y_lens, torch.Tensor):
|
||||
y_lens = y_lens.long().tolist()
|
||||
else:
|
||||
y, y_lens = self.encode_text(y, mask)
|
||||
|
||||
|
|
|
|||
|
|
@ -144,8 +144,15 @@ def main():
|
|||
x = vae.encode(x)
|
||||
with Timer("feature to cpu", log=log_time):
|
||||
x = x.cpu()
|
||||
fps = batch["fps"].to(dtype)
|
||||
batch_dict = {"x": x, "fps": fps}
|
||||
|
||||
batch_dict = {
|
||||
"x": x,
|
||||
"text": y,
|
||||
"fps": batch["fps"].to(dtype),
|
||||
"height": batch["height"].to(dtype),
|
||||
"width": batch["width"].to(dtype),
|
||||
"num_frames": batch["num_frames"].to(dtype),
|
||||
}
|
||||
|
||||
if save_text_features:
|
||||
with Timer("text", log=log_time):
|
||||
|
|
|
|||
|
|
@ -109,22 +109,34 @@ def main():
|
|||
logger.info("Building models...")
|
||||
# == build text-encoder and vae ==
|
||||
text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
|
||||
vae = build_module(cfg.get("vae", None), MODELS).to(device, dtype).eval()
|
||||
if text_encoder is not None:
|
||||
text_encoder_output_dim = text_encoder.output_dim
|
||||
text_encoder_model_max_length = text_encoder.model_max_length
|
||||
else:
|
||||
text_encoder_output_dim = cfg.get("text_encoder_output_dim", 1152)
|
||||
text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)
|
||||
|
||||
# == build diffusion model ==
|
||||
input_size = (dataset.num_frames, *dataset.image_size)
|
||||
# == build vae ==
|
||||
vae = build_module(cfg.get("vae", None), MODELS)
|
||||
if vae is not None:
|
||||
vae = vae.to(device, dtype).eval()
|
||||
if vae is not None:
|
||||
input_size = (dataset.num_frames, *dataset.image_size)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
vae_out_channels = vae.out_channels
|
||||
else:
|
||||
latent_size = (None, None, None)
|
||||
vae_out_channels = cfg.get("vae_out_channels", 4)
|
||||
|
||||
# == build diffusion model ==
|
||||
model = (
|
||||
build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
input_size=latent_size,
|
||||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
in_channels=vae_out_channels,
|
||||
caption_channels=text_encoder_output_dim,
|
||||
model_max_length=text_encoder_model_max_length,
|
||||
)
|
||||
.to(device, dtype)
|
||||
.train()
|
||||
|
|
@ -223,7 +235,6 @@ def main():
|
|||
for step, batch in pbar:
|
||||
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
|
||||
y = batch.pop("text")
|
||||
breakpoint()
|
||||
|
||||
# == visual and text encoding ==
|
||||
with torch.no_grad():
|
||||
|
|
|
|||
Loading…
Reference in a new issue