mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-18 16:49:41 +02:00
Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2
This commit is contained in:
commit
771100ed88
68
configs/opensora-v1-2/train/train_load_batch.py
Normal file
68
configs/opensora-v1-2/train/train_load_batch.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
# Configuration for training from precomputed feature buffers (BatchDataset);
# consumed by scripts/train_load_batch.py via parse_configs().

# Dataset settings
dataset = dict(
    type="BatchDataset",
)

grad_checkpoint = True

# Acceleration settings
num_workers = 8
dtype = "bf16"
plugin = "zero2"

# Model settings
model = dict(
    type="STDiT3-XL/2",
    from_pretrained=None,
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
    freeze_y_embedder=True,
)
vae = dict(
    type="OpenSoraVAE_V1_2",
    from_pretrained="pretrained_models/vae-pipeline",
    micro_frame_size=17,
    micro_batch_size=4,
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=300,
    shardformer=True,
    local_files_only=True,
)
scheduler = dict(
    type="rflow",
    use_timestep_transform=True,
    sample_method="logit-normal",
)

# Mask settings
# NOTE(review): "intepolate" is consumed as a literal key by MaskGenerator —
# do not rename it here without updating the consumer.
mask_ratios = {
    "random": 0.2,
    "intepolate": 0.01,
    "quarter_random": 0.01,
    "quarter_head": 0.01,
    "quarter_tail": 0.01,
    "quarter_head_tail": 0.01,
    "image_random": 0.05,
    "image_head": 0.1,
    "image_tail": 0.05,
    "image_head_tail": 0.05,
}

# Log settings
seed = 42
outputs = "outputs"
wandb = False
epochs = 1
log_every = 10
ckpt_every = 500

# optimization settings
load = None
grad_clip = 1.0
lr = 2e-4
ema_decay = 0.99
adam_eps = 1e-15
|
|
@ -1,2 +1,2 @@
|
|||
from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset
|
||||
from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset, BatchDataset
|
||||
from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from torch.distributed.distributed_c10d import _get_default_group
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from .datasets import VariableVideoTextDataset, VideoTextDataset
|
||||
from .sampler import StatefulDistributedSampler, VariableVideoBatchSampler
|
||||
from .sampler import StatefulDistributedSampler, VariableVideoBatchSampler, BatchDistributedSampler
|
||||
|
||||
|
||||
# Deterministic dataloader
|
||||
|
|
@ -82,3 +82,95 @@ def prepare_dataloader(
|
|||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|
||||
|
||||
def prepare_variable_dataloader(
    dataset,
    batch_size,
    bucket_config,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group=None,
    num_bucket_build_workers=1,
    **kwargs,
):
    """Build a DataLoader driven by a bucket-aware ``VariableVideoBatchSampler``.

    Batches are formed by the sampler itself (bucketed by resolution/length),
    so ``batch_size`` is not forwarded to the DataLoader. Extra keyword
    arguments are passed through to ``torch.utils.data.DataLoader``.
    """
    extra_args = kwargs.copy()
    process_group = process_group or _get_default_group()
    video_batch_sampler = VariableVideoBatchSampler(
        dataset,
        bucket_config,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
        shuffle=shuffle,
        seed=seed,
        drop_last=drop_last,
        verbose=True,
        num_bucket_build_workers=num_bucket_build_workers,
    )

    # Seed every worker identically so data-side RNG is deterministic.
    def _init_worker(worker_id):
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=video_batch_sampler,
        worker_init_fn=_init_worker,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **extra_args,
    )
|
||||
|
||||
|
||||
def build_batch_dataloader(
    dataset,
    seed=1024,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    distributed=True,
    **kwargs,
):
    r"""
    Prepare a dataloader for distributed training. The dataloader will be wrapped by
    `torch.utils.data.DataLoader` and `BatchDistributedSampler`.

    batch_size must be 1; shuffle is not supported so far

    Args:
        dataset: a ``BatchDataset`` whose items are already full batches.
        seed: seed applied identically to every worker for determinism.
        pin_memory / num_workers: forwarded to ``DataLoader``.
        process_group: data-parallel group; defaults to the global group.
        distributed: must be True; the non-distributed path is not implemented.
        **kwargs: forwarded to ``DataLoader`` (e.g. ``collate_fn``).

    Raises:
        NotImplementedError: when ``distributed`` is False.
    """
    _kwargs = kwargs.copy()
    # Guard clause: only the distributed path exists today. The original code
    # had an unreachable `sampler = None` after this raise; it has been removed.
    if not distributed:
        raise NotImplementedError("build_batch_dataloader only supports distributed=True")

    process_group = process_group or _get_default_group()
    sampler = BatchDistributedSampler(
        dataset,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
    )

    # Deterministic dataloader: seed all worker-side RNGs identically.
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset,
        batch_size=1,  # each dataset item is already a full batch
        sampler=sampler,
        worker_init_fn=seed_worker,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **_kwargs,
    )
|
||||
|
|
@ -190,3 +190,50 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
# we return None here in case of errorneous data
|
||||
# the collate function will handle it
|
||||
return None
|
||||
|
||||
|
||||
@DATASETS.register_module()
class BatchDataset(torch.utils.data.Dataset):
    """
    The dataset is composed of multiple .bin files.
    Each .bin file is a list of batch data (like a buffer). All .bin files have the same length.
    In each training iteration, one batch is fetched from the current buffer.
    Once a buffer is consumed, load another one.
    Avoid loading the same .bin on two different GPUs, i.e., one .bin is assigned to one GPU only.
    """

    def __init__(self, path_list=None):
        """
        Args:
            path_list: paths of the .bin buffer files. Defaults to the original
                hard-coded WebVid shard layout for backward compatibility.
        """
        if path_list is None:
            # TODO: read shard paths from a meta/config file instead of hard-coding them.
            path_list = [f'/mnt/nfs-207/sora_data/webvid-10M/feat_text/data/{idx}.bin' for idx in range(5)]
        self.path_list = list(path_list)

        # All buffers are assumed to have the same length, so peeking at the
        # first one is enough to size the whole dataset.
        self._len_buffer = len(torch.load(self.path_list[0]))
        self._num_buffers = len(self.path_list)
        self.num_samples = self.len_buffer * len(self.path_list)

        self.cur_file_idx = -1  # index of the buffer currently in memory (-1 = none)
        self.cur_buffer = None  # populated lazily by _load_buffer()

    @property
    def num_buffers(self):
        # Number of .bin files backing this dataset.
        return self._num_buffers

    @property
    def len_buffer(self):
        # Number of batches stored in each .bin file.
        return self._len_buffer

    def _load_buffer(self, idx):
        # Map a global sample index to its owning file; reload only on change.
        file_idx = idx // self.len_buffer
        if file_idx == self.cur_file_idx:
            return
        self.cur_file_idx = file_idx
        self.cur_buffer = torch.load(self.path_list[file_idx])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        self._load_buffer(idx)

        batch = self.cur_buffer[idx % self.len_buffer]  # dict; keys are {'x', 'fps'} and text related
        return batch
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import math
|
||||
import warnings
|
||||
|
||||
from collections import OrderedDict, defaultdict
|
||||
from pprint import pformat
|
||||
from typing import Iterator, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Dataset, DistributedSampler
|
||||
|
|
@ -283,3 +287,24 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
    """Restore previously captured sampler state (used when resuming from a checkpoint)."""
    self.__dict__.update(state_dict)
|
||||
|
||||
|
||||
class BatchDistributedSampler(DistributedSampler):
    """
    Used with BatchDataset.

    Each rank is assigned a contiguous, disjoint range of sample indices so
    that a given .bin buffer is read by exactly one GPU. With
    len_buffer == 5, num_buffers == 6 and 3 GPUs:
        | buffer {i}          | buffer {i+1}
    rank 0 | 0, 1, 2, 3, 4,      | 5, 6, 7, 8, 9
    rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
    rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
    """

    def __iter__(self):
        buffers_total = self.dataset.num_buffers
        buffer_len = self.dataset.len_buffer
        # Evenly split whole buffers across ranks (remainder buffers dropped).
        buffers_per_rank = buffers_total // self.num_replicas
        samples_per_rank = buffer_len * buffers_per_rank

        first = self.rank * samples_per_rank
        return iter(range(first, first + samples_per_rank))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import re
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
|
@ -221,3 +222,25 @@ def collate_fn_ignore_none(batch):
|
|||
# None value is returned when the get_item fails for an index
|
||||
batch = [val for val in batch if val is not None]
|
||||
return torch.utils.data.default_collate(batch)
|
||||
|
||||
|
||||
def collate_fn_batch(batch):
    """
    Used only with BatchDistributedSampler.

    The dataloader runs with ``batch_size=1`` and each dataset item is already
    a full batch, so ``default_collate`` adds a spurious leading dimension of
    size 1; this function removes it again.

    Args:
        batch: a length-1 list of pre-batched items (dict / sequence / tensor).

    Returns:
        The collated result with the leading singleton dimension squeezed
        off every tensor.

    Raises:
        TypeError: if ``default_collate`` returns an unsupported type.
    """
    res = torch.utils.data.default_collate(batch)

    # squeeze the first dimension, which is due to torch.stack() in default_collate()
    if isinstance(res, collections.abc.Mapping):
        for k, v in res.items():
            if isinstance(v, torch.Tensor):
                res[k] = v.squeeze(0)
    elif isinstance(res, collections.abc.Sequence):
        res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res]
    elif isinstance(res, torch.Tensor):
        res = res.squeeze(0)
    else:
        # Previously a bare `raise TypeError` with no context.
        raise TypeError(f"collate_fn_batch cannot handle collated type: {type(res)!r}")

    return res
|
||||
|
||||
|
|
|
|||
360
scripts/train_load_batch.py
Normal file
360
scripts/train_load_batch.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device, set_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import build_batch_dataloader
|
||||
from opensora.datasets.utils import collate_fn_batch
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
|
||||
from opensora.utils.misc import (
|
||||
all_reduce_mean,
|
||||
create_logger,
|
||||
create_tensorboard_writer,
|
||||
format_numel_str,
|
||||
get_model_numel,
|
||||
requires_grad,
|
||||
to_torch_dtype,
|
||||
)
|
||||
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
|
||||
|
||||
DEFAULT_DATASET_NAME = "VideoTextDataset"
|
||||
|
||||
|
||||
def main():
    """Training entry point for the load-batch experiment.

    Builds a ``BatchDataset`` dataloader of precomputed features, the STDiT
    diffusion model plus EMA copy, optimizer and scheduler, boosts them with
    ColossalAI, and runs a (single-epoch) training loop.

    NOTE(review): the inner loop currently ``continue``s right after reading
    ``batch['x']``, so everything below it (loss, backward, EMA, logging,
    checkpointing) is dead code kept as a template — confirm before relying
    on this script for real training.
    """
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=True)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))

    # == colossalai init distributed training ==
    # NOTE: A very large timeout is set to avoid some processes exit early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()

    # == init exp_dir ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()

    # == init logger, tensorboard & wandb ==
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if cfg.get("wandb", False):
            wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")

    # == init ColossalAI booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    # modify here
    dataloader_args = dict(
        dataset=dataset,
        # batch_size=cfg.get("batch_size", 1),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        # shuffle=True,
        # drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        collate_fn=collate_fn_batch,
    )
    dataloader = build_batch_dataloader(**dataloader_args)
    # NOTE(review): assumes dataset length is divisible by world size;
    # leftover samples are silently ignored by this step count.
    num_steps_per_epoch = len(dataset) // dist.get_world_size()
    sampler_to_io = None

    '''
    TODO:
    - prefetch
    - collate fn
    - resume
    - sampler_to_io ?
    - remove text_encoder & caption_embedder
    - currently only support 1 epoch; every epoch is the same
    '''

    # if cfg.dataset.type == DEFAULT_DATASET_NAME:
    #     dataloader = prepare_dataloader(**dataloader_args)
    #     total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
    #     logger.info("Total batch size: %s", total_batch_size)
    #     num_steps_per_epoch = len(dataloader)
    #     sampler_to_io = None
    # else:
    #     dataloader = prepare_variable_dataloader(
    #         bucket_config=cfg.get("bucket_config", None),
    #         num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
    #         **dataloader_args,
    #     )
    #     num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
    #     sampler_to_io = None if cfg.get("start_from_scratch ", False) else dataloader.batch_sampler

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    # modify here
    # input_size = (dataset.num_frames, *dataset.image_size)
    # latent_size = vae.get_latent_size(input_size)
    # NOTE(review): latent size is left undetermined because inputs are
    # precomputed features of varying shape — confirm the model accepts this.
    latent_size = None, None, None
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build ema for diffusion model ==
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)
    ema.eval()
    update_ema(ema, model, decay=0, sharded=False)

    # == setup loss function, build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # == setup optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0),
        eps=cfg.get("adam_eps", 1e-8),
    )
    lr_scheduler = None

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
    if cfg.get("mask_ratios", None) is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")

    # == global variables ==
    # modify here
    cfg_epochs = cfg.get("epochs", 1)
    # Only a single epoch is supported for this load-batch experiment.
    assert cfg_epochs == 1
    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
    running_loss = 0.0
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == resume ==
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        ret = load(
            booster,
            cfg.load,
            model=model,
            ema=ema,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            sampler=sampler_to_io,
        )
        # NOTE(review): the key "start_from_scratch " has a trailing space —
        # verify this matches what writers of the config actually set.
        if not cfg.get("start_from_scratch ", False):
            start_epoch, start_step, sampler_start_idx = ret
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_start_index(sampler_start_idx)

    model_sharding(ema)

    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(start_epoch, cfg_epochs):
        # == set dataloader to new epoch ==
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == training loop in an epoch ==
        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                # modify here
                x = batch['x'].to(device, dtype)  # feat of vae encoder
                print(step, dist.get_rank(), batch['x'].shape)
                # NOTE(review): debug short-circuit — everything below in this
                # loop body is currently unreachable.
                continue

                # x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                # y = batch.pop("text")

                # == visual and text encoding ==
                # with torch.no_grad():
                #     # Prepare visual inputs
                #     x = vae.encode(x)  # [B, C, T, H/P, W/P]
                #     # Prepare text inputs
                #     model_args = text_encoder.encode(y)

                model_args = {}
                # == mask ==
                mask = None
                if cfg.get("mask_ratios", None) is not None:
                    mask = mask_generator.get_masks(x)
                    model_args["x_mask"] = mask

                # == video meta info ==
                for k, v in batch.items():
                    model_args[k] = v.to(device, dtype)

                # == diffusion loss computation ==
                loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)

                # == backward & update ==
                loss = loss_dict["loss"].mean()
                booster.backward(loss=loss, optimizer=optimizer)
                optimizer.step()
                optimizer.zero_grad()

                # == update EMA ==
                update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))

                # == update log info ==
                all_reduce_mean(loss)
                running_loss += loss.item()
                global_step = epoch * num_steps_per_epoch + step
                log_step += 1
                acc_step += 1

                # == logging ==
                if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
                    avg_loss = running_loss / log_step
                    # progress bar
                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                    # tensorboard
                    tb_writer.add_scalar("loss", loss.item(), global_step)
                    # wandb
                    if cfg.get("wandb", False):
                        wandb.log(
                            {
                                "iter": global_step,
                                "epoch": epoch,
                                "loss": loss.item(),
                                "avg_loss": avg_loss,
                                "acc_step": acc_step,
                            },
                            step=global_step,
                        )

                    running_loss = 0.0
                    log_step = 0

                # == checkpoint saving ==
                ckpt_every = cfg.get("ckpt_every", 0)
                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                    model_gathering(ema, ema_shape_dict)
                    save(
                        booster,
                        exp_dir,
                        model=model,
                        ema=ema,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        sampler=sampler_to_io,
                        epoch=epoch,
                        step=step + 1,
                        global_step=global_step + 1,
                        batch_size=cfg.get("batch_size", None),
                    )
                    if dist.get_rank() == 0:
                        model_sharding(ema)
                    logger.info(
                        "Saved checkpoint at epoch %s step %s global_step %s to %s",
                        epoch,
                        step + 1,
                        global_step + 1,
                        exp_dir,
                    )

        # NOTE: the continue epochs are not resumed, so we need to reset the sampler start index and start step
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_start_index(0)
        else:
            dataloader.batch_sampler.set_epoch(epoch + 1)
            logger.info("Epoch done, recomputing batch sampler")
        start_step = 0
|
||||
|
||||
|
||||
# Script entry point: run training when executed directly (not on import).
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in a new issue