mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
[exp] update
This commit is contained in:
parent
1877a14abc
commit
15f9d702ed
|
|
@ -166,23 +166,23 @@ ASPECT_RATIO_240P = {
|
|||
|
||||
# S = 36864
|
||||
ASPECT_RATIO_144P = {
|
||||
"0.38": (118, 314),
|
||||
"0.43": (126, 294),
|
||||
"0.48": (134, 280),
|
||||
"0.50": (136, 272),
|
||||
"0.53": (140, 264),
|
||||
"0.54": (142, 262),
|
||||
"0.38": (117, 312),
|
||||
"0.43": (125, 291),
|
||||
"0.48": (133, 277),
|
||||
"0.50": (135, 270),
|
||||
"0.53": (139, 262),
|
||||
"0.54": (141, 260),
|
||||
"0.56": (144, 256), # base
|
||||
"0.62": (152, 244),
|
||||
"0.62": (151, 241),
|
||||
"0.67": (156, 234),
|
||||
"0.75": (166, 222),
|
||||
"0.75": (166, 221),
|
||||
"1.00": (192, 192),
|
||||
"1.33": (222, 166),
|
||||
"1.50": (236, 158),
|
||||
"1.33": (221, 165),
|
||||
"1.50": (235, 156),
|
||||
"1.78": (256, 144),
|
||||
"1.89": (264, 140),
|
||||
"2.00": (272, 136),
|
||||
"2.08": (278, 134),
|
||||
"1.89": (263, 139),
|
||||
"2.00": (271, 135),
|
||||
"2.08": (277, 132),
|
||||
}
|
||||
|
||||
# from PixArt
|
||||
|
|
|
|||
|
|
@ -211,13 +211,13 @@ def load(
|
|||
booster.load_model(model, os.path.join(load_dir, "model"))
|
||||
# ema is not boosted, so we don't use booster.load_model
|
||||
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
|
||||
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"), strict=False))
|
||||
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
|
||||
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
|
||||
if lr_scheduler is not None:
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
|
||||
running_states = load_json(os.path.join(load_dir, "running_states.json"))
|
||||
if sampler is not None:
|
||||
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
|
||||
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")), strict=False)
|
||||
dist.barrier()
|
||||
return (
|
||||
running_states["epoch"],
|
||||
|
|
|
|||
|
|
@ -229,6 +229,7 @@ def main():
|
|||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
|
||||
print(x.shape)
|
||||
y = batch.pop("text")
|
||||
# Visual and text encoding
|
||||
with torch.no_grad():
|
||||
|
|
|
|||
Loading…
Reference in a new issue