mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
[feat] update inference config
This commit is contained in:
parent
3a01cb440d
commit
e1e379d898
|
|
@ -1,4 +1,5 @@
|
|||
image_size = (240, 426)
|
||||
resolution = "240p"
|
||||
aspect_ratio = "9:16"
|
||||
num_frames = 51
|
||||
fps = 24
|
||||
frame_interval = 1
|
||||
|
|
@ -30,22 +31,10 @@ model = dict(
|
|||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
type="OpenSoraVAE_V1_2",
|
||||
from_pretrained="pretrained_models/vae-pipeline",
|
||||
shift=(-0.10, 0.34, 0.27, 0.98),
|
||||
scale=(3.85, 2.32, 2.33, 3.06),
|
||||
micro_frame_size=17,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
micro_batch_size=4,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
|
|
|
|||
|
|
@ -36,22 +36,10 @@ model = dict(
|
|||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
type="OpenSoraVAE_V1_2",
|
||||
from_pretrained="pretrained_models/vae-pipeline",
|
||||
shift=(-0.10, 0.34, 0.27, 0.98),
|
||||
scale=(3.85, 2.32, 2.33, 3.06),
|
||||
micro_frame_size=17,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
micro_batch_size=4,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
|
|
|
|||
|
|
@ -1,88 +0,0 @@
|
|||
# Dataset settings
|
||||
dataset = dict(
|
||||
type="VariableVideoTextDataset",
|
||||
transform_name="resize_crop",
|
||||
frame_interval=1,
|
||||
)
|
||||
|
||||
bucket_config = { # 20s/it
|
||||
"1024": {1: (1.0, 1)},
|
||||
}
|
||||
|
||||
grad_checkpoint = True
|
||||
batch_size = None
|
||||
|
||||
# Acceleration settings
|
||||
num_workers = 8
|
||||
num_bucket_build_workers = 16
|
||||
dtype = "bf16"
|
||||
plugin = "zero2"
|
||||
|
||||
# Model settings
|
||||
model = dict(
|
||||
type="STDiT3-XL/2",
|
||||
from_pretrained=None,
|
||||
qk_norm=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
from_pretrained="pretrained_models/vae-pipeline",
|
||||
micro_frame_size=17,
|
||||
shift=(-0.10, 0.34, 0.27, 0.98),
|
||||
scale=(3.85, 2.32, 2.33, 3.06),
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="DeepFloyd/t5-v1_1-xxl",
|
||||
model_max_length=300,
|
||||
shardformer=True,
|
||||
local_files_only=True,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="rflow",
|
||||
use_discrete_timesteps=False,
|
||||
use_timestep_transform=False,
|
||||
# sample_method="logit-normal",
|
||||
)
|
||||
|
||||
# Mask settings
|
||||
# mask_ratios = {
|
||||
# "random": 0.1,
|
||||
# "intepolate": 0.01,
|
||||
# "quarter_random": 0.01,
|
||||
# "quarter_head": 0.01,
|
||||
# "quarter_tail": 0.01,
|
||||
# "quarter_head_tail": 0.01,
|
||||
# "image_random": 0.05,
|
||||
# "image_head": 0.1,
|
||||
# "image_tail": 0.05,
|
||||
# "image_head_tail": 0.05,
|
||||
# }
|
||||
|
||||
# Log settings
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 500
|
||||
|
||||
# optimization settings
|
||||
load = None
|
||||
grad_clip = 1.0
|
||||
lr = 2e-4
|
||||
ema_decay = 0.99
|
||||
adam_eps = 1e-15
|
||||
|
|
@ -61,22 +61,10 @@ model = dict(
|
|||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
type="OpenSoraVAE_V1_2",
|
||||
from_pretrained="pretrained_models/vae-pipeline",
|
||||
shift=(-0.10, 0.34, 0.27, 0.98),
|
||||
scale=(3.85, 2.32, 2.33, 3.06),
|
||||
micro_frame_size=17,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
micro_batch_size=4,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
|
|
|
|||
|
|
@ -1,44 +1,32 @@
|
|||
num_frames = 1
|
||||
frame_interval = 1
|
||||
fps = 24
|
||||
image_size = (256, 256)
|
||||
num_frames = 1
|
||||
|
||||
dtype = "bf16"
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
save_dir = "samples/vae_video"
|
||||
cal_stats = True
|
||||
log_stats_every = 100
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
num_samples = 100
|
||||
num_workers = 4
|
||||
max_test_samples = None
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
freeze_vae_2d=True,
|
||||
type="OpenSoraVAE_V1_2",
|
||||
from_pretrained="pretrained_models/vae-pipeline",
|
||||
micro_frame_size=None,
|
||||
micro_batch_size=4,
|
||||
cal_loss=True,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
)
|
||||
dtype = "bf16"
|
||||
|
||||
# loss weights
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
kl_loss_weight = 1e-6
|
||||
calc_std = True
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
save_dir = "samples/vae_image"
|
||||
|
|
|
|||
|
|
@ -1,45 +1,32 @@
|
|||
num_frames = 17
|
||||
frame_interval = 1
|
||||
fps = 24
|
||||
image_size = (256, 256)
|
||||
num_frames = 17
|
||||
|
||||
dtype = "bf16"
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
save_dir = "samples/vae_video"
|
||||
cal_stats = True
|
||||
log_stats_every = 100
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
num_samples = 100
|
||||
num_workers = 4
|
||||
max_test_samples = None
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
from_pretrained=None,
|
||||
type="OpenSoraVAE_V1_2",
|
||||
from_pretrained="pretrained_models/vae-pipeline",
|
||||
micro_frame_size=None,
|
||||
micro_batch_size=4,
|
||||
cal_loss=True,
|
||||
micro_frame_size=16,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
)
|
||||
dtype = "bf16"
|
||||
|
||||
# loss weights
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
kl_loss_weight = 1e-6
|
||||
calc_std = True
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
save_dir = "samples/vae_video"
|
||||
|
|
|
|||
0
notebooks/inference.ipynb
Normal file
0
notebooks/inference.ipynb
Normal file
|
|
@ -1,28 +1,7 @@
|
|||
import math
|
||||
|
||||
|
||||
# computation
|
||||
AR = (
|
||||
3 / 8,
|
||||
9 / 21,
|
||||
0.48,
|
||||
1 / 2,
|
||||
9 / 17,
|
||||
1 / 1.85,
|
||||
9 / 16,
|
||||
5 / 8,
|
||||
2 / 3,
|
||||
3 / 4,
|
||||
1 / 1,
|
||||
4 / 3,
|
||||
3 / 2,
|
||||
16 / 9,
|
||||
17 / 9,
|
||||
2 / 1,
|
||||
1 / 0.48,
|
||||
)
|
||||
AR_fraction = (0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1, 1.33, 1.50, 1.78, 1.89, 2, 2.08)
|
||||
|
||||
|
||||
def get_h_w(a, ts, eps=1e-4):
|
||||
h = (ts * a) ** 0.5
|
||||
h = h + eps
|
||||
|
|
@ -33,11 +12,39 @@ def get_h_w(a, ts, eps=1e-4):
|
|||
return h, w
|
||||
|
||||
|
||||
def get_aspect_ratios_dict(ts=360 * 640, ars=AR):
|
||||
def get_aspect_ratios_dict(ars, ts=360 * 640):
    """Build a table mapping aspect-ratio labels to ``get_h_w`` results.

    Args:
        ars: Iterable of numeric aspect ratios.
        ts: Value forwarded as the second argument of ``get_h_w``; defaults
            to ``360 * 640`` (presumably a target pixel count -- confirm
            against ``get_h_w``).

    Returns:
        A dict keyed by each ratio formatted with two decimals
        (e.g. ``"0.56"``), whose values are ``get_h_w(ratio, ts)``.
    """
    table = {}
    for ratio in ars:
        # Ratios that collide after two-decimal rounding keep the last value.
        label = f"{ratio:.2f}"
        table[label] = get_h_w(ratio, ts)
    return table
|
||||
|
||||
|
||||
def get_ar(ratio):
    """Convert a colon-separated ratio string into a float.

    Args:
        ratio: A string of the form ``"A:B"`` with integer parts,
            e.g. ``"9:16"``.

    Returns:
        ``A / B`` as a float.

    Raises:
        ValueError: If the string does not split into exactly two parts or a
            part is not an integer literal.
        ZeroDivisionError: If the second part is zero.
    """
    first, second = ratio.split(":")
    numerator = int(first)
    denominator = int(second)
    return numerator / denominator
|
||||
|
||||
|
||||
ASPECT_RATIO_MAP = {
|
||||
"3:8": "0.38",
|
||||
"9:21": "0.43",
|
||||
"12:25": "0.48",
|
||||
"1:2": "0.50",
|
||||
"9:17": "0.53",
|
||||
"27:50": "0.54",
|
||||
"9:16": "0.56",
|
||||
"5:8": "0.62",
|
||||
"2:3": "0.67",
|
||||
"3:4": "0.75",
|
||||
"1:1": "1.00",
|
||||
"4:3": "1.33",
|
||||
"3:2": "1.50",
|
||||
"16:9": "1.78",
|
||||
"17:9": "1.89",
|
||||
"2:1": "2.00",
|
||||
"50:27": "2.08",
|
||||
}
|
||||
|
||||
|
||||
AR = [get_ar(ratio) for ratio in ASPECT_RATIO_MAP.keys()]
|
||||
|
||||
# computed from above code
|
||||
# S = 8294400
|
||||
ASPECT_RATIO_4K = {
|
||||
|
|
@ -454,3 +461,10 @@ ASPECT_RATIOS = {
|
|||
|
||||
def get_num_pixels(name):
    """Return the first element of the ``ASPECT_RATIOS`` entry for ``name``.

    Args:
        name: Key into ``ASPECT_RATIOS`` (presumably a resolution name such
            as ``"240p"`` -- confirm against the table definition).

    Raises:
        KeyError: If ``name`` is not a key of ``ASPECT_RATIOS``.
    """
    entry = ASPECT_RATIOS[name]
    return entry[0]
|
||||
|
||||
|
||||
def get_image_size(resolution, ar_ratio):
    """Look up the image size registered for a resolution and aspect ratio.

    Args:
        resolution: Key into ``ASPECT_RATIOS`` (e.g. ``"240p"``).
        ar_ratio: Either a ratio name present in ``ASPECT_RATIO_MAP``
            (e.g. ``"9:16"``) or a numeric / decimal-string ratio
            (e.g. ``0.56`` or ``"0.56"``). Numeric values are formatted with
            two decimals, the same key format used by the values of
            ``ASPECT_RATIO_MAP``.

    Returns:
        The entry stored for that ratio in ``ASPECT_RATIOS[resolution][1]``
        (presumably a ``(height, width)`` pair -- confirm against the table).

    Raises:
        KeyError: If ``ar_ratio`` is neither a known ratio name nor a number,
            or if ``resolution`` is unknown.
        AssertionError: If the ratio is not available for ``resolution``.
    """
    if ar_ratio in ASPECT_RATIO_MAP:
        ar_key = ASPECT_RATIO_MAP[ar_ratio]
    else:
        # The command line exposes --aspect-ratio as a float, so a value such
        # as 0.56 must also resolve; map it to the "0.56"-style table key.
        try:
            ar_key = f"{float(ar_ratio):.2f}"
        except (TypeError, ValueError):
            raise KeyError(
                f"Unknown aspect ratio {ar_ratio!r}; expected one of {list(ASPECT_RATIO_MAP)} or a number"
            ) from None

    if resolution not in ASPECT_RATIOS:
        raise KeyError(f"Unknown resolution {resolution!r}; expected one of {list(ASPECT_RATIOS)}")

    rs_dict = ASPECT_RATIOS[resolution][1]
    assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
    return rs_dict[ar_key]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ def prepare_dataloader(
|
|||
pin_memory=False,
|
||||
num_workers=0,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
distributed=True,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
|
@ -43,13 +44,16 @@ def prepare_dataloader(
|
|||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(
|
||||
dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
)
|
||||
if distributed:
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(
|
||||
dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
|
|
|
|||
|
|
@ -220,3 +220,39 @@ class VideoAutoencoderPipeline(nn.Module):
|
|||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
@MODELS.register_module()
class OpenSoraVAE_V1_2(VideoAutoencoderPipeline):
    """``VideoAutoencoderPipeline`` with its sub-model configs baked in.

    The 2D VAE config, the temporal VAE config and the ``shift`` / ``scale``
    constants are fixed inside ``__init__``, so a config only needs
    ``type="OpenSoraVAE_V1_2"`` plus the options listed below.
    """

    def __init__(
        self,
        micro_batch_size=4,
        micro_frame_size=17,
        from_pretrained=None,
        local_files_only=True,
        freeze_vae_2d=False,
        cal_loss=False,
    ):
        """Assemble the fixed sub-configs and delegate to the parent pipeline.

        Args:
            micro_batch_size: Written into the 2D VAE config.
            micro_frame_size: Forwarded to the parent constructor.
            from_pretrained: Forwarded to the parent constructor as its third
                positional argument.
            local_files_only: Written into the 2D VAE config.
            freeze_vae_2d: Forwarded to the parent constructor.
            cal_loss: Forwarded to the parent constructor.
        """
        spatial_cfg = {
            "type": "VideoAutoencoderKL",
            "from_pretrained": "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
            "subfolder": "vae",
            "micro_batch_size": micro_batch_size,
            "local_files_only": local_files_only,
        }
        temporal_cfg = {
            "type": "VAE_Temporal_SD",
            "from_pretrained": None,
        }
        # Positional order is kept exactly as the parent constructor is called
        # in the original: 2D config, temporal config, checkpoint path.
        super().__init__(
            spatial_cfg,
            temporal_cfg,
            from_pretrained,
            freeze_vae_2d=freeze_vae_2d,
            cal_loss=cal_loss,
            micro_frame_size=micro_frame_size,
            shift=(-0.10, 0.34, 0.27, 0.98),
            scale=(3.85, 2.32, 2.33, 3.06),
        )
|
||||
|
|
|
|||
|
|
@ -24,10 +24,8 @@ def parse_args(training=False):
|
|||
)
|
||||
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
|
||||
parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
|
||||
parser.add_argument("--flash-attn", default=None, action=argparse.BooleanOptionalAction, help="enable flash attn")
|
||||
parser.add_argument(
|
||||
"--layernorm-kernel", default=None, action=argparse.BooleanOptionalAction, help="enable layernorm kernel"
|
||||
)
|
||||
parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention")
|
||||
parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel")
|
||||
|
||||
# ======================================================
|
||||
# Inference
|
||||
|
|
@ -51,6 +49,9 @@ def parse_args(training=False):
|
|||
parser.add_argument("--num-frames", default=None, type=int, help="number of frames")
|
||||
parser.add_argument("--fps", default=None, type=int, help="fps")
|
||||
parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size")
|
||||
parser.add_argument("--frame-interval", default=None, type=int, help="frame interval")
|
||||
parser.add_argument("--resolution", default=None, type=str, help="multi resolution")
|
||||
parser.add_argument("--aspect-ratio", default=None, type=float, help="aspect ratio")
|
||||
|
||||
# hyperparameters
|
||||
parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps")
|
||||
|
|
@ -143,3 +144,14 @@ def define_experiment_workspace(cfg, get_last_workspace=False):
|
|||
def save_training_config(cfg, experiment_dir):
|
||||
with open(f"{experiment_dir}/config.txt", "w") as f:
|
||||
json.dump(cfg, f, indent=4)
|
||||
|
||||
|
||||
def str2bool(v):
    """Parse a command-line style boolean value.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: A ``bool`` (returned unchanged) or any value whose string form is
            one of ``yes/true/t/y/1`` or ``no/false/f/n/0``. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean
            spelling.
    """
    if isinstance(v, bool):
        return v
    # Normalise once; going through str() lets non-string inputs such as the
    # integers 1 / 0 be parsed instead of failing on a missing .lower().
    token = str(v).strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
|
|
|||
|
|
@ -132,7 +132,9 @@ def find_nearest_point(value, point, max_value):
|
|||
|
||||
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
|
||||
masks = []
|
||||
no_mask = True
|
||||
for i, mask_strategy in enumerate(mask_strategys):
|
||||
no_mask = False
|
||||
mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
|
||||
mask_strategy = parse_mask_strategy(mask_strategy)
|
||||
for mst in mask_strategy:
|
||||
|
|
@ -154,6 +156,8 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
|
|||
z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
|
||||
mask[m_target_start : m_target_start + m_length] = edit_ratio
|
||||
masks.append(mask)
|
||||
if no_mask:
|
||||
return None
|
||||
masks = torch.stack(masks)
|
||||
return masks
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,13 @@ def is_main_process():
|
|||
return not is_distributed() or dist.get_rank() == 0
|
||||
|
||||
|
||||
def get_world_size():
    """Return ``dist.get_world_size()`` when distributed, otherwise ``1``.

    Whether the run is distributed is decided by ``is_distributed()``, so this
    is safe to call from single-process scripts.
    """
    if not is_distributed():
        return 1
    return dist.get_world_size()
|
||||
|
||||
|
||||
def create_logger(logging_dir=None):
|
||||
"""
|
||||
Create a logger that writes to a log file and stdout.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from tqdm import tqdm
|
|||
|
||||
from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.datasets.aspect import get_image_size
|
||||
from opensora.models.text_encoder.t5 import text_preprocessing
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
|
|
@ -52,7 +53,7 @@ def main():
|
|||
else:
|
||||
coordinator = None
|
||||
enable_sequence_parallelism = False
|
||||
set_random_seed(seed=cfg.seed)
|
||||
set_random_seed(seed=cfg.get("seed", 1024))
|
||||
|
||||
# == init logger ==
|
||||
logger = create_logger()
|
||||
|
|
@ -68,8 +69,19 @@ def main():
|
|||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
|
||||
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
|
||||
|
||||
# == prepare video size ==
|
||||
image_size = cfg.get("image_size", None)
|
||||
if image_size is None:
|
||||
resolution = cfg.get("resolution", None)
|
||||
aspect_ratio = cfg.get("aspect_ratio", None)
|
||||
assert (
|
||||
resolution is not None and aspect_ratio is not None
|
||||
), "resolution and aspect_ratio must be provided if image_size is not provided"
|
||||
image_size = get_image_size(resolution, aspect_ratio)
|
||||
num_frames = cfg.num_frames
|
||||
|
||||
# == build diffusion model ==
|
||||
input_size = (cfg.get("num_frames", None), *cfg.get("image_size", (None, None)))
|
||||
input_size = (num_frames, *image_size)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
model = (
|
||||
build_module(
|
||||
|
|
@ -106,10 +118,8 @@ def main():
|
|||
assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts"
|
||||
|
||||
# == prepare arguments ==
|
||||
image_size = cfg.image_size
|
||||
num_frames = cfg.num_frames
|
||||
fps = cfg.fps
|
||||
save_fps = cfg.fps // cfg.get("frame_interval", 1)
|
||||
save_fps = fps // cfg.get("frame_interval", 1)
|
||||
multi_resolution = cfg.get("multi_resolution", None)
|
||||
batch_size = cfg.get("batch_size", 1)
|
||||
num_sample = cfg.get("num_sample", 1)
|
||||
|
|
|
|||
|
|
@ -1,144 +1,145 @@
|
|||
import os
|
||||
from pprint import pformat
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from mmengine.runner import set_random_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader
|
||||
from opensora.datasets import prepare_dataloader, save_sample
|
||||
from opensora.models.vae.losses import VAELoss
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
from opensora.utils.misc import create_logger, get_world_size, is_distributed, is_main_process, to_torch_dtype
|
||||
|
||||
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. cfg and init distributed env
|
||||
# ======================================================
|
||||
cfg = parse_configs(training=False)
|
||||
print(cfg)
|
||||
|
||||
# init distributed
|
||||
if os.environ.get("WORLD_SIZE", None):
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
else:
|
||||
pass
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables
|
||||
# ======================================================
|
||||
torch.set_grad_enabled(False)
|
||||
# ======================================================
|
||||
# configs & runtime variables
|
||||
# ======================================================
|
||||
# == parse configs ==
|
||||
cfg = parse_configs(training=False)
|
||||
|
||||
# == device and dtype ==
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
cfg_dtype = cfg.get("dtype", "fp32")
|
||||
assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
|
||||
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = to_torch_dtype(cfg.dtype)
|
||||
set_random_seed(seed=cfg.seed)
|
||||
|
||||
# == init distributed env ==
|
||||
if is_distributed():
|
||||
colossalai.launch_from_torch({})
|
||||
set_random_seed(seed=cfg.get("seed", 1024))
|
||||
|
||||
# == init logger ==
|
||||
logger = create_logger()
|
||||
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
|
||||
verbose = cfg.get("verbose", 1)
|
||||
|
||||
# ======================================================
|
||||
# 3. build dataset and dataloader
|
||||
# build dataset and dataloader
|
||||
# ======================================================
|
||||
logger.info("Building reconstruction dataset...")
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
batch_size = cfg.get("batch_size", 1)
|
||||
dataloader = prepare_dataloader(
|
||||
dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=batch_size,
|
||||
num_workers=cfg.get("num_workers", 4),
|
||||
shuffle=False,
|
||||
drop_last=True,
|
||||
drop_last=False,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
distributed=is_distributed(),
|
||||
)
|
||||
print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size()
|
||||
print(f"Total batch size: {total_batch_size}")
|
||||
logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset))
|
||||
total_batch_size = batch_size * get_world_size()
|
||||
logger.info("Total batch size: %s", total_batch_size)
|
||||
|
||||
total_steps = len(dataloader)
|
||||
if cfg.get("num_samples", None) is not None:
|
||||
total_steps = min(int(cfg.num_samples // cfg.batch_size), total_steps)
|
||||
logger.info("limiting test dataset to %s", int(cfg.num_samples // cfg.batch_size) * cfg.batch_size)
|
||||
dataiter = iter(dataloader)
|
||||
|
||||
# ======================================================
|
||||
# 4. build model & load weights
|
||||
# build model & loss
|
||||
# ======================================================
|
||||
# 4.1. build model
|
||||
model = build_module(cfg.model, MODELS)
|
||||
model.to(device, dtype).eval()
|
||||
|
||||
# ======================================================
|
||||
# 5. inference
|
||||
# ======================================================
|
||||
cfg.save_dir
|
||||
|
||||
# define loss function
|
||||
logger.info("Building models...")
|
||||
model = build_module(cfg.model, MODELS).to(device, dtype).eval()
|
||||
vae_loss_fn = VAELoss(
|
||||
logvar_init=cfg.get("logvar_init", 0.0),
|
||||
perceptual_loss_weight=cfg.perceptual_loss_weight,
|
||||
kl_loss_weight=cfg.kl_loss_weight,
|
||||
perceptual_loss_weight=cfg.get("perceptual_loss_weight", 0.1),
|
||||
kl_loss_weight=cfg.get("kl_loss_weight", 1e-6),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# get total number of steps
|
||||
total_steps = len(dataloader)
|
||||
if cfg.max_test_samples is not None:
|
||||
total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
|
||||
print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
# ======================================================
|
||||
# inference
|
||||
# ======================================================
|
||||
# == global variables ==
|
||||
running_loss = running_nll = running_nll_z = 0.0
|
||||
loss_steps = 0
|
||||
|
||||
calc_std = cfg.get("calc_std", False)
|
||||
if calc_std:
|
||||
running_sum = 0.0
|
||||
running_sum_c = torch.zeros(model.out_channels, dtype=torch.float, device=device)
|
||||
running_var = 0.0
|
||||
running_var_c = torch.zeros(model.out_channels, dtype=torch.float, device=device)
|
||||
cal_stats = cfg.get("cal_stats", False)
|
||||
if cal_stats:
|
||||
num_samples = 0
|
||||
running_sum = running_var = 0.0
|
||||
running_sum_c = torch.zeros(model.out_channels, dtype=torch.float, device=device)
|
||||
running_var_c = torch.zeros(model.out_channels, dtype=torch.float, device=device)
|
||||
|
||||
# prepare arguments
|
||||
save_fps = cfg.get("fps", 24) // cfg.get("frame_interval", 1)
|
||||
|
||||
# Iter over the dataset
|
||||
with tqdm(
|
||||
range(total_steps),
|
||||
disable=not coordinator.is_master(),
|
||||
disable=not is_main_process() or verbose < 1,
|
||||
total=total_steps,
|
||||
initial=0,
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
batch = next(dataloader_iter)
|
||||
batch = next(dataiter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
|
||||
# == vae encoding & decoding ===
|
||||
z, posterior, x_z = model.encode(x)
|
||||
x_rec, x_z_rec = model.decode(z, num_frames=x.size(2))
|
||||
x_ref = model.spatial_vae.decode(x_z)
|
||||
|
||||
# == check z shape ==
|
||||
input_size = x.shape[2:]
|
||||
latent_size = model.get_latent_size(input_size)
|
||||
assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}"
|
||||
|
||||
# ===== VAE =====
|
||||
z, posterior, x_z = model.encode(x)
|
||||
|
||||
# calc std
|
||||
if calc_std:
|
||||
# == calculate stats ==
|
||||
if cal_stats:
|
||||
num_samples += 1
|
||||
running_sum += z.mean().item()
|
||||
running_var += (z - running_sum / num_samples).pow(2).mean().item()
|
||||
|
||||
running_sum_c += z.mean(dim=(0, 2, 3, 4)).float()
|
||||
running_var_c += (
|
||||
(z - running_sum_c[None, :, None, None, None] / num_samples).pow(2).mean(dim=(0, 2, 3, 4)).float()
|
||||
)
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"mean": running_sum / num_samples,
|
||||
"std": (running_var / num_samples) ** 0.5,
|
||||
}
|
||||
)
|
||||
if num_samples % 100 == 0:
|
||||
print(
|
||||
" mean_c ",
|
||||
if verbose >= 1:
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"mean": running_sum / num_samples,
|
||||
"std": (running_var / num_samples) ** 0.5,
|
||||
}
|
||||
)
|
||||
if num_samples % cfg.get("log_stats_every", 100) == 0:
|
||||
logger.info(
|
||||
"VAE feature per channel stats: mean %s, var %s",
|
||||
(running_sum_c / num_samples).cpu().tolist(),
|
||||
"std_c ",
|
||||
(running_var_c / num_samples).sqrt().cpu().tolist(),
|
||||
)
|
||||
|
||||
assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}"
|
||||
x_rec, x_z_rec = model.decode(z, num_frames=x.size(2))
|
||||
model.spatial_vae.decode(x_z)
|
||||
|
||||
# loss calculation
|
||||
# == loss calculation ==
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
|
||||
nll_loss_z, _, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True)
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss
|
||||
|
|
@ -147,24 +148,24 @@ def main():
|
|||
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
|
||||
running_nll_z = nll_loss_z.item() / loss_steps + running_nll_z * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
# if not use_dist or coordinator.is_master():
|
||||
# ori_dir = f"{save_dir}_ori"
|
||||
# rec_dir = f"{save_dir}_rec"
|
||||
# ref_dir = f"{save_dir}_ref"
|
||||
# os.makedirs(ori_dir, exist_ok=True)
|
||||
# os.makedirs(rec_dir, exist_ok=True)
|
||||
# os.makedirs(ref_dir, exist_ok=True)
|
||||
# for idx, vid in enumerate(x):
|
||||
# pos = step * cfg.batch_size + idx
|
||||
# save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}")
|
||||
# save_sample(x_rec[idx], fps=cfg.fps, save_path=f"{rec_dir}/{pos:03d}")
|
||||
# save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}")
|
||||
# == save samples ==
|
||||
save_dir = cfg.get("save_dir", None)
|
||||
if is_main_process() and save_dir is not None:
|
||||
ori_dir = f"{save_dir}_ori"
|
||||
rec_dir = f"{save_dir}_rec"
|
||||
ref_dir = f"{save_dir}_spatial"
|
||||
os.makedirs(ori_dir, exist_ok=True)
|
||||
os.makedirs(rec_dir, exist_ok=True)
|
||||
os.makedirs(ref_dir, exist_ok=True)
|
||||
for idx, vid in enumerate(x):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_sample(vid, fps=save_fps, save_path=f"{ori_dir}/{pos:03d}", verbose=verbose >= 2)
|
||||
save_sample(x_rec[idx], fps=save_fps, save_path=f"{rec_dir}/{pos:03d}", verbose=verbose >= 2)
|
||||
save_sample(x_ref[idx], fps=save_fps, save_path=f"{ref_dir}/{pos:03d}", verbose=verbose >= 2)
|
||||
|
||||
print("test vae loss:", running_loss)
|
||||
print("test nll loss:", running_nll)
|
||||
print("test nll_z loss:", running_nll_z)
|
||||
if calc_std:
|
||||
print("z std:", running_std_sum / num_samples)
|
||||
logger.info("VAE loss: %s", running_loss)
|
||||
logger.info("VAE nll loss: %s", running_nll)
|
||||
logger.info("VAE nll_z loss: %s", running_nll_z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -230,12 +230,12 @@ def main():
|
|||
for epoch in range(start_epoch, cfg_epochs):
|
||||
# == set dataloader to new epoch ==
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
dataiter = iter(dataloader)
|
||||
logger.info("Beginning epoch %s...", epoch)
|
||||
|
||||
# == training loop in an epoch ==
|
||||
with tqdm(
|
||||
enumerate(dataloader_iter, start=start_step),
|
||||
enumerate(dataiter, start=start_step),
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not coordinator.is_master(),
|
||||
total=num_steps_per_epoch,
|
||||
|
|
@ -253,7 +253,7 @@ def main():
|
|||
length = random.randint(1, x.size(2))
|
||||
x = x[:, :, :length, :, :]
|
||||
|
||||
# == vae encoding ===
|
||||
# == vae encoding & decoding ===
|
||||
x_rec, x_z_rec, z, posterior, x_z = model(x)
|
||||
|
||||
# == loss initialization ==
|
||||
|
|
|
|||
Loading…
Reference in a new issue