This commit is contained in:
Shen-Chenhui 2024-04-08 16:05:15 +08:00
parent 0a5c767fc7
commit ab1970ca24
2 changed files with 175 additions and 135 deletions

View file

@ -20,6 +20,12 @@ from opensora.acceleration.parallel_states import (
from tqdm import tqdm
from opensora.models.vae.model_utils import VEA3DLoss
# DEBUG
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
def main():
# ======================================================
@ -49,6 +55,31 @@ def main():
set_random_seed(seed=cfg.seed)
# 2.3 DEBUG: USE BOOSTER
# 2.3. initialize ColossalAI booster
if cfg.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_data_parallel_group(dist.group.WORLD)
elif cfg.plugin == "zero2-seq":
plugin = ZeroSeqParallelPlugin(
sp_size=cfg.sp_size,
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_sequence_parallel_group(plugin.sp_group)
set_data_parallel_group(plugin.dp_group)
else:
raise ValueError(f"Unknown plugin {cfg.plugin}")
booster = Booster(plugin=plugin)
# ======================================================
# 3. build dataset and dataloader
# ======================================================
@ -108,6 +139,15 @@ def main():
os.makedirs(save_dir, exist_ok=True)
### TODO: DEBUG, USE booster
torch.set_default_dtype(dtype)
vae, _, _, dataloader, _ = booster.boost(
model=vae, dataloader=dataloader
)
# load model using booster
booster.load_model(vae, os.path.join(cfg.ckpt_path, "model"))
# 4.1. batch generation
# define loss function

View file

@ -186,151 +186,151 @@ def main():
dataloader.sampler.set_start_index(sampler_start_idx)
# # define loss function
# loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
# # 6.2. training loop
# for epoch in range(start_epoch, cfg.epochs):
# dataloader.sampler.set_epoch(epoch)
# dataloader_iter = iter(dataloader)
# logger.info(f"Beginning epoch {epoch}...")
# with tqdm(
# range(start_step, num_steps_per_epoch),
# desc=f"Epoch {epoch}",
# disable=not coordinator.is_master(),
# total=num_steps_per_epoch,
# initial=start_step,
# ) as pbar:
# for step in pbar:
# batch = next(dataloader_iter)
# x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# # loss = vae.get_loss(x)
# reconstructions, posterior = vae(x)
# loss = loss_function(x, reconstructions, posterior)
# # Backward & update
# booster.backward(loss=loss, optimizer=optimizer)
# optimizer.step()
# optimizer.zero_grad()
# # Log loss values:
# all_reduce_mean(loss)
# running_loss += loss.item()
# global_step = epoch * num_steps_per_epoch + step
# log_step += 1
# # Log to tensorboard
# if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
# avg_loss = running_loss / log_step
# pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
# running_loss = 0
# log_step = 0
# writer.add_scalar("loss", loss.item(), global_step)
# if cfg.wandb:
# wandb.log(
# {
# "iter": global_step,
# "num_samples": global_step * total_batch_size,
# "epoch": epoch,
# "loss": loss.item(),
# "avg_loss": avg_loss,
# },
# step=global_step,
# )
# # Save checkpoint
# if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
# save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
# os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
# # TODO: save in model?
# booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
# booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
# if lr_scheduler is not None:
# booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
# running_states = {
# "epoch": epoch,
# "step": step+1,
# "global_step": global_step+1,
# "sample_start_index": (step+1) * cfg.batch_size,
# }
# if coordinator.is_master():
# save_json(running_states, os.path.join(save_dir, "running_states.json"))
# dist.barrier()
# logger.info(
# f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
# )
# # the continue epochs are not resumed, so we need to reset the sampler start index and start step
# dataloader.sampler.set_start_index(0)
# start_step = 0
# DEBUG inference
# 4.1. batch generation
# define loss function
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
running_loss = 0.0
loss_steps = 0
from opensora.datasets import save_sample
# get data again
print("loading test data...")
dataset = DatasetFromCSV(
cfg.data_path,
# TODO: change transforms
transform=(
get_transforms_video(cfg.image_size[0])
if not cfg.use_image_transform
else get_transforms_image(cfg.image_size[0])
),
num_frames=cfg.num_frames,
frame_interval=cfg.frame_interval,
root=cfg.root,
)
# 6.2. training loop
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info(f"Beginning epoch {epoch}...")
dataloader = prepare_dataloader(
dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
shuffle=False,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
)
print(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
with tqdm(
range(start_step, num_steps_per_epoch),
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
initial=start_step,
) as pbar:
for step in pbar:
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
total_steps = len(dataloader)
dataloader_iter = iter(dataloader)
# loss = vae.get_loss(x)
reconstructions, posterior = vae(x)
loss = loss_function(x, reconstructions, posterior)
print("total steps:", total_steps)
# Backward & update
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
optimizer.zero_grad()
with tqdm(
range(total_steps),
# desc=f"Avg Loss: {running_loss}",
disable=not coordinator.is_master(),
total=total_steps,
initial=0,
) as pbar:
for step in pbar:
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
reconstructions, posterior = vae(x)
loss = loss_function(x, reconstructions, posterior)
loss_steps += 1
running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
# Log loss values:
all_reduce_mean(loss)
running_loss += loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
if coordinator.is_master():
for idx, sample in enumerate(reconstructions):
pos = step * cfg.batch_size + idx
save_path = os.path.join("outputs/debug", f"sample_{pos}")
save_sample(sample, fps=8, save_path=save_path)
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
running_loss = 0
log_step = 0
writer.add_scalar("loss", loss.item(), global_step)
if cfg.wandb:
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": loss.item(),
"avg_loss": avg_loss,
},
step=global_step,
)
print("test loss:", running_loss)
# Save checkpoint
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
# TODO: save in model?
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step+1,
"global_step": global_step+1,
"sample_start_index": (step+1) * cfg.batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
dist.barrier()
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
)
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(0)
start_step = 0
# # DEBUG inference
# # 4.1. batch generation
# # define loss function
# loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
# running_loss = 0.0
# loss_steps = 0
# from opensora.datasets import save_sample
# # get data again
# print("loading test data...")
# dataset = DatasetFromCSV(
# cfg.data_path,
# # TODO: change transforms
# transform=(
# get_transforms_video(cfg.image_size[0])
# if not cfg.use_image_transform
# else get_transforms_image(cfg.image_size[0])
# ),
# num_frames=cfg.num_frames,
# frame_interval=cfg.frame_interval,
# root=cfg.root,
# )
# dataloader = prepare_dataloader(
# dataset,
# batch_size=cfg.batch_size,
# num_workers=cfg.num_workers,
# shuffle=False,
# drop_last=True,
# pin_memory=True,
# process_group=get_data_parallel_group(),
# )
# print(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
# total_steps = len(dataloader)
# dataloader_iter = iter(dataloader)
# print("total steps:", total_steps)
# with tqdm(
# range(total_steps),
# # desc=f"Avg Loss: {running_loss}",
# disable=not coordinator.is_master(),
# total=total_steps,
# initial=0,
# ) as pbar:
# for step in pbar:
# batch = next(dataloader_iter)
# x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# reconstructions, posterior = vae(x)
# loss = loss_function(x, reconstructions, posterior)
# loss_steps += 1
# running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
# # if coordinator.is_master():
# # for idx, sample in enumerate(reconstructions):
# # pos = step * cfg.batch_size + idx
# # save_path = os.path.join("outputs/debug", f"sample_{pos}")
# # save_sample(sample, fps=8, save_path=save_path)
# print("test loss:", running_loss)
if __name__ == "__main__":
main()