diff --git a/opensora/datasets/aspect.py b/opensora/datasets/aspect.py
index 0a541a0..89052ef 100644
--- a/opensora/datasets/aspect.py
+++ b/opensora/datasets/aspect.py
@@ -166,23 +166,23 @@ ASPECT_RATIO_240P = {
 
 # S = 36864
 ASPECT_RATIO_144P = {
-    "0.38": (118, 314),
-    "0.43": (126, 294),
-    "0.48": (134, 280),
-    "0.50": (136, 272),
-    "0.53": (140, 264),
-    "0.54": (142, 262),
+    "0.38": (117, 312),
+    "0.43": (125, 291),
+    "0.48": (133, 277),
+    "0.50": (135, 270),
+    "0.53": (139, 262),
+    "0.54": (141, 260),
     "0.56": (144, 256),  # base
-    "0.62": (152, 244),
+    "0.62": (151, 241),
     "0.67": (156, 234),
-    "0.75": (166, 222),
+    "0.75": (166, 221),
     "1.00": (192, 192),
-    "1.33": (222, 166),
-    "1.50": (236, 158),
+    "1.33": (221, 165),
+    "1.50": (235, 156),
     "1.78": (256, 144),
-    "1.89": (264, 140),
-    "2.00": (272, 136),
-    "2.08": (278, 134),
+    "1.89": (263, 139),
+    "2.00": (271, 135),
+    "2.08": (277, 132),
 }
 
 # from PixArt
diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py
index 022f9ea..b8bdf7d 100644
--- a/opensora/utils/ckpt_utils.py
+++ b/opensora/utils/ckpt_utils.py
@@ -211,13 +211,13 @@ def load(
     booster.load_model(model, os.path.join(load_dir, "model"))
     # ema is not boosted, so we don't use booster.load_model
     # ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
-    ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"), strict=False))
+    ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")), strict=False)
     booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
     if lr_scheduler is not None:
         booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
     running_states = load_json(os.path.join(load_dir, "running_states.json"))
     if sampler is not None:
         sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
     dist.barrier()
     return (
         running_states["epoch"],