mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-18 16:49:41 +02:00
[hotfix] fix sample and update training script (#15)
* [hotfix] fix sample * [hotfix] fix sample
This commit is contained in:
parent
14db4566e1
commit
e02f6286de
26
sample.py
26
sample.py
|
|
@ -19,7 +19,7 @@ from transformers import AutoModel, AutoTokenizer, CLIPTextModel
|
|||
|
||||
from open_sora.diffusion import create_diffusion
|
||||
from open_sora.modeling import DiT_models
|
||||
from open_sora.utils.data import col2video
|
||||
from open_sora.utils.data import col2video, unnormalize_video
|
||||
|
||||
|
||||
def main(args):
|
||||
|
|
@ -34,16 +34,20 @@ def main(args):
|
|||
.eval()
|
||||
)
|
||||
in_channels = vqvae.embedding_dim
|
||||
w_h_factor = 4
|
||||
t_factor = 2
|
||||
else:
|
||||
# disable VQ-VAE if not provided, just use raw video frames
|
||||
vqvae = None
|
||||
in_channels = 3
|
||||
w_h_factor = 1
|
||||
t_factor = 1
|
||||
text_model = CLIPTextModel.from_pretrained(args.text_model).to(device).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
|
||||
|
||||
model = DiT_models[args.model](in_channels=in_channels).to(device).eval()
|
||||
patch_size = model.patch_size
|
||||
# model.load_state_dict(torch.load(args.ckpt))
|
||||
model.load_state_dict(torch.load(args.ckpt))
|
||||
diffusion = create_diffusion(str(args.num_sampling_steps))
|
||||
|
||||
# Create sampling noise:
|
||||
|
|
@ -54,9 +58,9 @@ def main(args):
|
|||
num_frames = args.fps * args.sec
|
||||
z = torch.randn(
|
||||
1,
|
||||
(args.height // patch_size // 4)
|
||||
* (args.width // patch_size // 4)
|
||||
* (num_frames // 2),
|
||||
(args.height // patch_size // w_h_factor)
|
||||
* (args.width // patch_size // w_h_factor)
|
||||
* (num_frames // t_factor),
|
||||
in_channels,
|
||||
patch_size,
|
||||
patch_size,
|
||||
|
|
@ -87,7 +91,12 @@ def main(args):
|
|||
samples, _ = samples.chunk(2, dim=0) # Remove null class samples
|
||||
samples = col2video(
|
||||
samples.squeeze(),
|
||||
(num_frames // 2, in_channels, args.height // 4, args.width // 4),
|
||||
(
|
||||
num_frames // t_factor,
|
||||
in_channels,
|
||||
args.height // w_h_factor,
|
||||
args.width // w_h_factor,
|
||||
),
|
||||
)
|
||||
if vqvae is not None:
|
||||
# [T, C, H, W] -> [B, C, T, H, W]
|
||||
|
|
@ -98,6 +107,7 @@ def main(args):
|
|||
else:
|
||||
# [T, C, H, W] -> [T, H, W, C]
|
||||
samples = samples.permute(0, 2, 3, 1)
|
||||
samples = unnormalize_video(samples).to(torch.uint8)
|
||||
|
||||
write_video("sample.mp4", samples.cpu(), args.fps)
|
||||
|
||||
|
|
@ -105,12 +115,12 @@ def main(args):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
|
||||
"-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
default="two ladies laughing by seeing some thing another lady throw dresses and keep it back by reverse motion",
|
||||
default="a cartoon animals runs through an ice cave in a video game",
|
||||
)
|
||||
parser.add_argument("--cfg-scale", type=float, default=4.0)
|
||||
parser.add_argument("--num-sampling-steps", type=int, default=250)
|
||||
|
|
|
|||
64
train.py
64
train.py
|
|
@ -64,6 +64,18 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
|||
return tensor
|
||||
|
||||
|
||||
def save_checkpoints(booster, model, optimizer, ema, save_path, coordinator):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
booster.save_model(model, os.path.join(save_path, "model"), shard=False)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_path, "optimizer"), shard=True)
|
||||
if coordinator.is_master():
|
||||
ema_state_dict = ema.state_dict()
|
||||
for k, v in ema_state_dict.items():
|
||||
ema_state_dict[k] = v.cpu()
|
||||
torch.save(ema_state_dict, os.path.join(save_path, "ema.pt"))
|
||||
dist.barrier()
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Training Loop #
|
||||
#################################################################################
|
||||
|
|
@ -89,7 +101,11 @@ def main(args):
|
|||
|
||||
# Step 3: Create VQ-VAE
|
||||
if len(args.vqvae) > 0:
|
||||
vqvae = AutoModel.from_pretrained(args.vqvae, trust_remote_code=True).to(get_current_device()).eval()
|
||||
vqvae = (
|
||||
AutoModel.from_pretrained(args.vqvae, trust_remote_code=True)
|
||||
.to(get_current_device())
|
||||
.eval()
|
||||
)
|
||||
model_kwargs = {"in_channels": vqvae.embedding_dim}
|
||||
else:
|
||||
# disable VQ-VAE if not provided, just use raw video frames
|
||||
|
|
@ -110,7 +126,9 @@ def main(args):
|
|||
model.enable_gradient_checkpointing()
|
||||
|
||||
# Step 5: create diffusion pipeline
|
||||
diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
|
||||
diffusion = create_diffusion(
|
||||
timestep_respacing=""
|
||||
) # default: 1000 steps, linear noise schedule
|
||||
|
||||
# Step 6: setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
|
||||
opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0)
|
||||
|
|
@ -128,6 +146,10 @@ def main(args):
|
|||
|
||||
# Step 8: setup booster
|
||||
model, opt, _, dataloader, _ = booster.boost(model, opt, dataloader=dataloader)
|
||||
if args.load_model is not None:
|
||||
booster.load_model(model, args.load_model)
|
||||
if args.load_optimizer is not None:
|
||||
booster.load_optimizer(opt, args.load_optimizer)
|
||||
logger.info(
|
||||
f"Booster init max device memory: {get_accelerator().max_memory_allocated() / 1024 ** 2:.2f} MB",
|
||||
ranks=[0],
|
||||
|
|
@ -154,7 +176,9 @@ def main(args):
|
|||
(video_inputs.shape[0],),
|
||||
device=video_inputs.device,
|
||||
)
|
||||
loss_dict = diffusion.training_losses(model, video_inputs, t, batch, mask=mask)
|
||||
loss_dict = diffusion.training_losses(
|
||||
model, video_inputs, t, batch, mask=mask
|
||||
)
|
||||
loss = loss_dict["loss"].mean() / args.accumulation_steps
|
||||
total_loss.add_(loss.data)
|
||||
booster.backward(loss, opt)
|
||||
|
|
@ -167,7 +191,9 @@ def main(args):
|
|||
all_reduce_mean(total_loss)
|
||||
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
||||
if coordinator.is_master():
|
||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||
global_step = (epoch * num_steps_per_epoch) + (
|
||||
step + 1
|
||||
) // args.accumulation_steps
|
||||
writer.add_scalar(
|
||||
tag="Loss",
|
||||
scalar_value=total_loss.item(),
|
||||
|
|
@ -177,22 +203,20 @@ def main(args):
|
|||
total_loss.zero_()
|
||||
|
||||
# Save DiT checkpoint:
|
||||
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
||||
step + 1
|
||||
) == len(dataloader):
|
||||
save_path = os.path.join(args.checkpoint_dir, f"epoch-{epoch}-step-{step}")
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
booster.save_model(model, os.path.join(save_path, "model"), shard=True)
|
||||
booster.save_optimizer(opt, os.path.join(save_path, "optimizer"), shard=True)
|
||||
if coordinator.is_master():
|
||||
ema_state_dict = ema.state_dict()
|
||||
for k, v in ema_state_dict.items():
|
||||
ema_state_dict[k] = v.cpu()
|
||||
torch.save(ema_state_dict, os.path.join(save_path, "ema.pt"))
|
||||
dist.barrier()
|
||||
if args.save_interval > 0 and (
|
||||
(step + 1) % (args.save_interval * args.accumulation_steps) == 0
|
||||
or (step + 1) == len(dataloader)
|
||||
):
|
||||
save_path = os.path.join(
|
||||
args.checkpoint_dir, f"epoch-{epoch}-step-{step}"
|
||||
)
|
||||
save_checkpoints(booster, model, opt, ema, save_path, coordinator)
|
||||
logger.info(f"Saved checkpoint to {save_path}", ranks=[0])
|
||||
|
||||
get_accelerator().empty_cache()
|
||||
final_save_path = os.path.join(args.checkpoint_dir, "final")
|
||||
save_checkpoints(booster, model, opt, ema, final_save_path, coordinator)
|
||||
logger.info(f"Saved checkpoint to {final_save_path}", ranks=[0])
|
||||
logger.info(
|
||||
f"Training complete, max device memory: {get_accelerator().max_memory_allocated() / 1024 ** 2:.2f} MB",
|
||||
ranks=[0],
|
||||
|
|
@ -202,7 +226,9 @@ def main(args):
|
|||
if __name__ == "__main__":
|
||||
# Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8")
|
||||
parser.add_argument(
|
||||
"-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
|
||||
)
|
||||
parser.add_argument("-d", "--dataset", nargs="+", default=[])
|
||||
parser.add_argument("-v", "--video_dir", type=str, required=True)
|
||||
parser.add_argument("-e", "--epochs", type=int, default=10)
|
||||
|
|
@ -214,5 +240,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
|
||||
parser.add_argument("--tensorboard_dir", type=str, default="runs")
|
||||
parser.add_argument("--vqvae", default="hpcai-tech/vqvae")
|
||||
parser.add_argument("--load_model", default=None)
|
||||
parser.add_argument("--load_optimizer", default=None)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
|
|
|||
Loading…
Reference in a new issue