Mirror of https://github.com/hpcaitech/Open-Sora.git (synced 2026-04-21 04:27:32 +02:00)
debug

This commit is contained in:
parent 0a5c767fc7
commit ab1970ca24
```diff
@@ -20,6 +20,12 @@ from opensora.acceleration.parallel_states import (
 from tqdm import tqdm
+from opensora.models.vae.model_utils import VEA3DLoss
+
+# DEBUG
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from opensora.acceleration.plugin import ZeroSeqParallelPlugin
 
 
 def main():
     # ======================================================
```
```diff
@@ -49,6 +55,31 @@ def main():
     set_random_seed(seed=cfg.seed)
 
+    # 2.3 DEBUG: USE BOOSTER
+    # 2.3. initialize ColossalAI booster
+    if cfg.plugin == "zero2":
+        plugin = LowLevelZeroPlugin(
+            stage=2,
+            precision=cfg.dtype,
+            initial_scale=2**16,
+            max_norm=cfg.grad_clip,
+        )
+        set_data_parallel_group(dist.group.WORLD)
+    elif cfg.plugin == "zero2-seq":
+        plugin = ZeroSeqParallelPlugin(
+            sp_size=cfg.sp_size,
+            stage=2,
+            precision=cfg.dtype,
+            initial_scale=2**16,
+            max_norm=cfg.grad_clip,
+        )
+        set_sequence_parallel_group(plugin.sp_group)
+        set_data_parallel_group(plugin.dp_group)
+    else:
+        raise ValueError(f"Unknown plugin {cfg.plugin}")
+    booster = Booster(plugin=plugin)
+
     # ======================================================
     # 3. build dataset and dataloader
     # ======================================================
```
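The `zero2` branch above follows the standard ColossalAI booster workflow: build a plugin, wrap the training objects with `booster.boost(...)`, and route the backward pass through the booster so ZeRO can partition gradients and optimizer states. For context, a minimal self-contained sketch of that workflow, separate from this diff (the tiny linear model, learning rate, and tensor shapes are placeholders, and the `launch_from_torch` signature varies across ColossalAI releases):

```python
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Assumes the script is started with torchrun on a CUDA machine, so the
# distributed environment variables are already set.
colossalai.launch_from_torch(config={})  # newer ColossalAI versions drop `config`

plugin = LowLevelZeroPlugin(stage=2, precision="fp16", initial_scale=2**16)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# boost() returns (model, optimizer, criterion, dataloader, lr_scheduler);
# objects that were not passed in come back as None.
model, optimizer, _, _, _ = booster.boost(model=model, optimizer=optimizer)

x = torch.randn(4, 16, device="cuda")
loss = model(x).square().mean()
booster.backward(loss, optimizer)  # ZeRO requires booster.backward, not loss.backward()
optimizer.step()
optimizer.zero_grad()
```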
```diff
@@ -108,6 +139,15 @@ def main():
     os.makedirs(save_dir, exist_ok=True)
 
+
+    ### TODO: DEBUG, USE booster
+    torch.set_default_dtype(dtype)
+    vae, _, _, dataloader, _ = booster.boost(
+        model=vae, dataloader=dataloader
+    )
+    # load model using booster
+    booster.load_model(vae, os.path.join(cfg.ckpt_path, "model"))
+
 
     # 4.1. batch generation
 
     # define loss function
```
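The `# define loss function` step that follows constructs `VEA3DLoss` with `kl_weight` and `perceptual_weight`, so the objective is a weighted sum of reconstruction, KL, and perceptual terms. Its implementation lives in `opensora.models.vae.model_utils` and is not part of this diff; as a rough illustration of the reconstruction + KL portion only, assuming the posterior exposes a per-sample `kl()` method as diagonal-Gaussian VAE posteriors commonly do:

```python
import torch.nn.functional as F

def vae_loss_sketch(x, reconstructions, posterior, kl_weight=1e-6):
    """Illustrative recon + KL objective; NOT the actual VEA3DLoss code."""
    # L1 reconstruction error over every voxel of the [B, C, T, H, W] clip
    rec_loss = F.l1_loss(reconstructions, x)
    # KL divergence of the diagonal-Gaussian posterior from N(0, I),
    # averaged over the batch (assumes posterior.kl() returns one value per sample)
    kl_loss = posterior.kl().mean()
    return rec_loss + kl_weight * kl_loss
```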
```diff
@@ -186,151 +186,151 @@ def main():
 
     dataloader.sampler.set_start_index(sampler_start_idx)
 
-    # define loss function
-    loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
-
-    # 6.2. training loop
-    for epoch in range(start_epoch, cfg.epochs):
-        dataloader.sampler.set_epoch(epoch)
-        dataloader_iter = iter(dataloader)
-        logger.info(f"Beginning epoch {epoch}...")
-
-        with tqdm(
-            range(start_step, num_steps_per_epoch),
-            desc=f"Epoch {epoch}",
-            disable=not coordinator.is_master(),
-            total=num_steps_per_epoch,
-            initial=start_step,
-        ) as pbar:
-            for step in pbar:
-                batch = next(dataloader_iter)
-                x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
-
-                # loss = vae.get_loss(x)
-                reconstructions, posterior = vae(x)
-                loss = loss_function(x, reconstructions, posterior)
-
-                # Backward & update
-                booster.backward(loss=loss, optimizer=optimizer)
-                optimizer.step()
-                optimizer.zero_grad()
-
-                # Log loss values:
-                all_reduce_mean(loss)
-                running_loss += loss.item()
-                global_step = epoch * num_steps_per_epoch + step
-                log_step += 1
-
-                # Log to tensorboard
-                if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
-                    avg_loss = running_loss / log_step
-                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
-                    running_loss = 0
-                    log_step = 0
-                    writer.add_scalar("loss", loss.item(), global_step)
-                    if cfg.wandb:
-                        wandb.log(
-                            {
-                                "iter": global_step,
-                                "num_samples": global_step * total_batch_size,
-                                "epoch": epoch,
-                                "loss": loss.item(),
-                                "avg_loss": avg_loss,
-                            },
-                            step=global_step,
-                        )
-
-                # Save checkpoint
-                if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
-                    save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
-                    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
-                    # TODO: save in model?
-                    booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
-                    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
-                    if lr_scheduler is not None:
-                        booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
-                    running_states = {
-                        "epoch": epoch,
-                        "step": step+1,
-                        "global_step": global_step+1,
-                        "sample_start_index": (step+1) * cfg.batch_size,
-                    }
-                    if coordinator.is_master():
-                        save_json(running_states, os.path.join(save_dir, "running_states.json"))
-                    dist.barrier()
-                    logger.info(
-                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
-                    )
-
-        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
-        dataloader.sampler.set_start_index(0)
-        start_step = 0
+    # # define loss function
+    # loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
+
+    # # 6.2. training loop
+    # for epoch in range(start_epoch, cfg.epochs):
+    #     dataloader.sampler.set_epoch(epoch)
+    #     dataloader_iter = iter(dataloader)
+    #     logger.info(f"Beginning epoch {epoch}...")
+
+    #     with tqdm(
+    #         range(start_step, num_steps_per_epoch),
+    #         desc=f"Epoch {epoch}",
+    #         disable=not coordinator.is_master(),
+    #         total=num_steps_per_epoch,
+    #         initial=start_step,
+    #     ) as pbar:
+    #         for step in pbar:
+    #             batch = next(dataloader_iter)
+    #             x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
+
+    #             # loss = vae.get_loss(x)
+    #             reconstructions, posterior = vae(x)
+    #             loss = loss_function(x, reconstructions, posterior)
+
+    #             # Backward & update
+    #             booster.backward(loss=loss, optimizer=optimizer)
+    #             optimizer.step()
+    #             optimizer.zero_grad()
+
+    #             # Log loss values:
+    #             all_reduce_mean(loss)
+    #             running_loss += loss.item()
+    #             global_step = epoch * num_steps_per_epoch + step
+    #             log_step += 1
+
+    #             # Log to tensorboard
+    #             if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
+    #                 avg_loss = running_loss / log_step
+    #                 pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
+    #                 running_loss = 0
+    #                 log_step = 0
+    #                 writer.add_scalar("loss", loss.item(), global_step)
+    #                 if cfg.wandb:
+    #                     wandb.log(
+    #                         {
+    #                             "iter": global_step,
+    #                             "num_samples": global_step * total_batch_size,
+    #                             "epoch": epoch,
+    #                             "loss": loss.item(),
+    #                             "avg_loss": avg_loss,
+    #                         },
+    #                         step=global_step,
+    #                     )
+
+    #             # Save checkpoint
+    #             if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
+    #                 save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
+    #                 os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
+    #                 # TODO: save in model?
+    #                 booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
+    #                 booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
+    #                 if lr_scheduler is not None:
+    #                     booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+    #                 running_states = {
+    #                     "epoch": epoch,
+    #                     "step": step+1,
+    #                     "global_step": global_step+1,
+    #                     "sample_start_index": (step+1) * cfg.batch_size,
+    #                 }
+    #                 if coordinator.is_master():
+    #                     save_json(running_states, os.path.join(save_dir, "running_states.json"))
+    #                 dist.barrier()
+    #                 logger.info(
+    #                     f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
+    #                 )
+
+    #     # the continue epochs are not resumed, so we need to reset the sampler start index and start step
+    #     dataloader.sampler.set_start_index(0)
+    #     start_step = 0
 
-    # # DEBUG inference
-
-    # # 4.1. batch generation
-
-    # # define loss function
-    # loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
-    # running_loss = 0.0
-    # loss_steps = 0
-
-    # from opensora.datasets import save_sample
-
-    # # get data again
-    # print("loading test data...")
-    # dataset = DatasetFromCSV(
-    #     cfg.data_path,
-    #     # TODO: change transforms
-    #     transform=(
-    #         get_transforms_video(cfg.image_size[0])
-    #         if not cfg.use_image_transform
-    #         else get_transforms_image(cfg.image_size[0])
-    #     ),
-    #     num_frames=cfg.num_frames,
-    #     frame_interval=cfg.frame_interval,
-    #     root=cfg.root,
-    # )
-
-    # dataloader = prepare_dataloader(
-    #     dataset,
-    #     batch_size=cfg.batch_size,
-    #     num_workers=cfg.num_workers,
-    #     shuffle=False,
-    #     drop_last=True,
-    #     pin_memory=True,
-    #     process_group=get_data_parallel_group(),
-    # )
-    # print(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
-
-    # total_steps = len(dataloader)
-    # dataloader_iter = iter(dataloader)
-
-    # print("total steps:", total_steps)
-
-    # with tqdm(
-    #     range(total_steps),
-    #     # desc=f"Avg Loss: {running_loss}",
-    #     disable=not coordinator.is_master(),
-    #     total=total_steps,
-    #     initial=0,
-    # ) as pbar:
-    #     for step in pbar:
-    #         batch = next(dataloader_iter)
-    #         x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
-    #         reconstructions, posterior = vae(x)
-    #         loss = loss_function(x, reconstructions, posterior)
-    #         loss_steps += 1
-    #         running_loss = loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
-
-    #         # if coordinator.is_master():
-    #         #     for idx, sample in enumerate(reconstructions):
-    #         #         pos = step * cfg.batch_size + idx
-    #         #         save_path = os.path.join("outputs/debug", f"sample_{pos}")
-    #         #         save_sample(sample, fps=8, save_path=save_path)
-
-    # print("test loss:", running_loss)
+    # DEBUG inference
+
+    # 4.1. batch generation
+
+    # define loss function
+    loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
+    running_loss = 0.0
+    loss_steps = 0
+
+    from opensora.datasets import save_sample
+
+    # get data again
+    print("loading test data...")
+    dataset = DatasetFromCSV(
+        cfg.data_path,
+        # TODO: change transforms
+        transform=(
+            get_transforms_video(cfg.image_size[0])
+            if not cfg.use_image_transform
+            else get_transforms_image(cfg.image_size[0])
+        ),
+        num_frames=cfg.num_frames,
+        frame_interval=cfg.frame_interval,
+        root=cfg.root,
+    )
+
+    dataloader = prepare_dataloader(
+        dataset,
+        batch_size=cfg.batch_size,
+        num_workers=cfg.num_workers,
+        shuffle=False,
+        drop_last=True,
+        pin_memory=True,
+        process_group=get_data_parallel_group(),
+    )
+    print(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
+
+    total_steps = len(dataloader)
+    dataloader_iter = iter(dataloader)
+
+    print("total steps:", total_steps)
+
+    with tqdm(
+        range(total_steps),
+        # desc=f"Avg Loss: {running_loss}",
+        disable=not coordinator.is_master(),
+        total=total_steps,
+        initial=0,
+    ) as pbar:
+        for step in pbar:
+            batch = next(dataloader_iter)
+            x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
+            reconstructions, posterior = vae(x)
+            loss = loss_function(x, reconstructions, posterior)
+            loss_steps += 1
+            running_loss = loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
+
+            if coordinator.is_master():
+                for idx, sample in enumerate(reconstructions):
+                    pos = step * cfg.batch_size + idx
+                    save_path = os.path.join("outputs/debug", f"sample_{pos}")
+                    save_sample(sample, fps=8, save_path=save_path)
+
+    print("test loss:", running_loss)
 
 
 if __name__ == "__main__":
     main()
```
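A note on the enabled inference loop: `running_loss = loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)` is an incremental mean, algebraically equal to averaging all per-step losses, so the final `print("test loss:", running_loss)` reports the mean loss over the whole pass without storing a history. (This path has no `all_reduce_mean`, so in a multi-GPU run the printed value is rank-local.) A quick standalone check of the identity:

```python
# The incremental update matches the plain average of all losses seen so far.
losses = [0.9, 0.4, 0.7, 0.55]

running, n = 0.0, 0
for loss in losses:
    n += 1
    running = loss / n + running * ((n - 1) / n)

assert abs(running - sum(losses) / len(losses)) < 1e-12
print(running)  # 0.6375
```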