[hotfix] fix sample and update training script (#15)

* [hotfix] fix sample

* [hotfix] fix sample
Hongxin Liu 2024-02-28 15:01:53 +08:00 committed by GitHub
parent 14db4566e1
commit e02f6286de
2 changed files with 64 additions and 26 deletions

Changed file 1 of 2: the sampling script

@@ -19,7 +19,7 @@ from transformers import AutoModel, AutoTokenizer, CLIPTextModel
 from open_sora.diffusion import create_diffusion
 from open_sora.modeling import DiT_models
-from open_sora.utils.data import col2video
+from open_sora.utils.data import col2video, unnormalize_video
 
 
 def main(args):
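
Note: the new unnormalize_video import pairs with the uint8 fix further down. write_video expects uint8 frames in [0, 255], while the sampler works on normalized floats. A minimal sketch of what such a helper could look like, assuming the data pipeline maps [0, 255] to [-1, 1] (the real definition lives in open_sora.utils.data and may differ):

    import torch

    def unnormalize_video(video: torch.Tensor) -> torch.Tensor:
        # Assumed inverse of a [0, 255] -> [-1, 1] normalization.
        return ((video.clamp(-1, 1) + 1) * 127.5).round()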
@@ -34,16 +34,20 @@ def main(args):
             .eval()
         )
         in_channels = vqvae.embedding_dim
+        w_h_factor = 4
+        t_factor = 2
     else:
         # disable VQ-VAE if not provided, just use raw video frames
         vqvae = None
         in_channels = 3
+        w_h_factor = 1
+        t_factor = 1
 
     text_model = CLIPTextModel.from_pretrained(args.text_model).to(device).eval()
     tokenizer = AutoTokenizer.from_pretrained(args.text_model)
     model = DiT_models[args.model](in_channels=in_channels).to(device).eval()
     patch_size = model.patch_size
-    # model.load_state_dict(torch.load(args.ckpt))
+    model.load_state_dict(torch.load(args.ckpt))
     diffusion = create_diffusion(str(args.num_sampling_steps))
 
     # Create sampling noise:
@@ -54,9 +58,9 @@ def main(args):
     num_frames = args.fps * args.sec
     z = torch.randn(
         1,
-        (args.height // patch_size // 4)
-        * (args.width // patch_size // 4)
-        * (num_frames // 2),
+        (args.height // patch_size // w_h_factor)
+        * (args.width // patch_size // w_h_factor)
+        * (num_frames // t_factor),
         in_channels,
         patch_size,
         patch_size,
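
The hardcoded 4 and 2 were the VQ-VAE's spatial and temporal downsampling factors; with the VQ-VAE disabled they produced a noise tensor that did not match the raw-frame video shape, which is the bug this hotfix addresses. A quick sanity check of the token count with illustrative values (height = width = 256, patch_size = 8, 16 frames):

    height = width = 256
    patch_size = 8
    num_frames = 16

    # With the VQ-VAE (H/W downsampled by 4, T by 2): 8 * 8 * 8 = 512 tokens
    with_vqvae = (height // patch_size // 4) * (width // patch_size // 4) * (num_frames // 2)

    # Raw frames (factors of 1): 32 * 32 * 16 = 16384 tokens
    raw_frames = (height // patch_size) * (width // patch_size) * num_frames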
@@ -87,7 +91,12 @@ def main(args):
     samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
     samples = col2video(
         samples.squeeze(),
-        (num_frames // 2, in_channels, args.height // 4, args.width // 4),
+        (
+            num_frames // t_factor,
+            in_channels,
+            args.height // w_h_factor,
+            args.width // w_h_factor,
+        ),
     )
     if vqvae is not None:
         # [T, C, H, W] -> [B, C, T, H, W]
@@ -98,6 +107,7 @@ def main(args):
     else:
         # [T, C, H, W] -> [T, H, W, C]
         samples = samples.permute(0, 2, 3, 1)
+    samples = unnormalize_video(samples).to(torch.uint8)
 
     write_video("sample.mp4", samples.cpu(), args.fps)
@@ -105,12 +115,12 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
+        "-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
     )
     parser.add_argument(
         "--text",
         type=str,
-        default="two ladies laughing by seeing some thing another lady throw dresses and keep it back by reverse motion",
+        default="a cartoon animals runs through an ice cave in a video game",
     )
     parser.add_argument("--cfg-scale", type=float, default=4.0)
     parser.add_argument("--num-sampling-steps", type=int, default=250)

Changed file 2 of 2: the training script

@@ -64,6 +64,18 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     return tensor
 
 
+def save_checkpoints(booster, model, optimizer, ema, save_path, coordinator):
+    os.makedirs(save_path, exist_ok=True)
+    booster.save_model(model, os.path.join(save_path, "model"), shard=False)
+    booster.save_optimizer(optimizer, os.path.join(save_path, "optimizer"), shard=True)
+    if coordinator.is_master():
+        ema_state_dict = ema.state_dict()
+        for k, v in ema_state_dict.items():
+            ema_state_dict[k] = v.cpu()
+        torch.save(ema_state_dict, os.path.join(save_path, "ema.pt"))
+    dist.barrier()
+
+
 #################################################################################
 #                                  Training Loop                                #
 #################################################################################
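
save_checkpoints factors out the save logic the training loop previously inlined (the duplicate is removed further down): the booster writes the model unsharded and the optimizer sharded, only the master rank serializes the CPU-offloaded EMA weights, and dist.barrier() keeps the other ranks from racing ahead. Note the model is now saved with shard=False (a single file) where the old inline code used shard=True. A hedged sketch of consuming the EMA file later, assuming it is a plain state dict:

    # Load the EMA weights saved by save_checkpoints (path illustrative)
    ema_state = torch.load(os.path.join(save_path, "ema.pt"), map_location="cpu")
    model.load_state_dict(ema_state)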
@@ -89,7 +101,11 @@ def main(args):
 
     # Step 3: Create VQ-VAE
     if len(args.vqvae) > 0:
-        vqvae = AutoModel.from_pretrained(args.vqvae, trust_remote_code=True).to(get_current_device()).eval()
+        vqvae = (
+            AutoModel.from_pretrained(args.vqvae, trust_remote_code=True)
+            .to(get_current_device())
+            .eval()
+        )
         model_kwargs = {"in_channels": vqvae.embedding_dim}
     else:
         # disable VQ-VAE if not provided, just use raw video frames
@@ -110,7 +126,9 @@ def main(args):
         model.enable_gradient_checkpointing()
 
     # Step 5: create diffusion pipeline
-    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
+    diffusion = create_diffusion(
+        timestep_respacing=""
+    )  # default: 1000 steps, linear noise schedule
 
     # Step 6: setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
     opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0)
@@ -128,6 +146,10 @@ def main(args):
     # Step 8: setup booster
     model, opt, _, dataloader, _ = booster.boost(model, opt, dataloader=dataloader)
+    if args.load_model is not None:
+        booster.load_model(model, args.load_model)
+    if args.load_optimizer is not None:
+        booster.load_optimizer(opt, args.load_optimizer)
 
     logger.info(
         f"Booster init max device memory: {get_accelerator().max_memory_allocated() / 1024 ** 2:.2f} MB",
         ranks=[0],
@@ -154,7 +176,9 @@ def main(args):
                     (video_inputs.shape[0],),
                     device=video_inputs.device,
                 )
-                loss_dict = diffusion.training_losses(model, video_inputs, t, batch, mask=mask)
+                loss_dict = diffusion.training_losses(
+                    model, video_inputs, t, batch, mask=mask
+                )
                 loss = loss_dict["loss"].mean() / args.accumulation_steps
                 total_loss.add_(loss.data)
                 booster.backward(loss, opt)
@@ -167,7 +191,9 @@ def main(args):
                     all_reduce_mean(total_loss)
                     pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
                     if coordinator.is_master():
-                        global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
+                        global_step = (epoch * num_steps_per_epoch) + (
+                            step + 1
+                        ) // args.accumulation_steps
                         writer.add_scalar(
                             tag="Loss",
                             scalar_value=total_loss.item(),
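
The reformatting is behavior-preserving: the integer division means the TensorBoard step advances once per accumulation window. With illustrative values:

    num_steps_per_epoch = 100
    accumulation_steps = 4
    epoch, step = 0, 7  # logging fires when (step + 1) % accumulation_steps == 0
    global_step = (epoch * num_steps_per_epoch) + (step + 1) // accumulation_steps  # 2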
@@ -177,22 +203,20 @@ def main(args):
                     total_loss.zero_()
 
                 # Save DiT checkpoint:
-                if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
-                    step + 1
-                ) == len(dataloader):
-                    save_path = os.path.join(args.checkpoint_dir, f"epoch-{epoch}-step-{step}")
-                    os.makedirs(save_path, exist_ok=True)
-                    booster.save_model(model, os.path.join(save_path, "model"), shard=True)
-                    booster.save_optimizer(opt, os.path.join(save_path, "optimizer"), shard=True)
-                    if coordinator.is_master():
-                        ema_state_dict = ema.state_dict()
-                        for k, v in ema_state_dict.items():
-                            ema_state_dict[k] = v.cpu()
-                        torch.save(ema_state_dict, os.path.join(save_path, "ema.pt"))
-                    dist.barrier()
+                if args.save_interval > 0 and (
+                    (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+                    or (step + 1) == len(dataloader)
+                ):
+                    save_path = os.path.join(
+                        args.checkpoint_dir, f"epoch-{epoch}-step-{step}"
+                    )
+                    save_checkpoints(booster, model, opt, ema, save_path, coordinator)
                     logger.info(f"Saved checkpoint to {save_path}", ranks=[0])
 
         get_accelerator().empty_cache()
 
+    final_save_path = os.path.join(args.checkpoint_dir, "final")
+    save_checkpoints(booster, model, opt, ema, final_save_path, coordinator)
+    logger.info(f"Saved checkpoint to {final_save_path}", ranks=[0])
     logger.info(
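
Besides routing through save_checkpoints, the restructured condition changes behavior at epoch boundaries: previously the (step + 1) == len(dataloader) clause fired even when save_interval was non-positive, whereas now save_interval > 0 gates both triggers and a final checkpoint is written unconditionally after training. With illustrative values:

    save_interval, accumulation_steps = 0, 1
    step, num_batches = 99, 100

    old = (save_interval > 0 and (step + 1) % (save_interval * accumulation_steps) == 0) or (step + 1) == num_batches
    new = save_interval > 0 and ((step + 1) % (save_interval * accumulation_steps) == 0 or (step + 1) == num_batches)
    print(old, new)  # True False -- epoch-end saves now also require save_interval > 0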
@@ -202,7 +226,9 @@ def main(args):
 if __name__ == "__main__":
     # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
     parser = argparse.ArgumentParser()
-    parser.add_argument("-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8")
+    parser.add_argument(
+        "-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
+    )
     parser.add_argument("-d", "--dataset", nargs="+", default=[])
     parser.add_argument("-v", "--video_dir", type=str, required=True)
     parser.add_argument("-e", "--epochs", type=int, default=10)
@@ -214,5 +240,7 @@ if __name__ == "__main__":
     parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
     parser.add_argument("--tensorboard_dir", type=str, default="runs")
     parser.add_argument("--vqvae", default="hpcai-tech/vqvae")
+    parser.add_argument("--load_model", default=None)
+    parser.add_argument("--load_optimizer", default=None)
     args = parser.parse_args()
     main(args)
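
The two new flags pair with the booster.load_model / booster.load_optimizer calls added after booster.boost, letting a run warm-start from the directories written by save_checkpoints, e.g. --load_model checkpoints/final/model --load_optimizer checkpoints/final/optimizer (paths illustrative).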