[exp] update

This commit is contained in:
Zangwei Zheng 2024-04-19 17:43:44 +08:00
parent 1877a14abc
commit 15f9d702ed
3 changed files with 16 additions and 15 deletions

View file

@ -166,23 +166,23 @@ ASPECT_RATIO_240P = {
# S = 36864
ASPECT_RATIO_144P = {
"0.38": (118, 314),
"0.43": (126, 294),
"0.48": (134, 280),
"0.50": (136, 272),
"0.53": (140, 264),
"0.54": (142, 262),
"0.38": (117, 312),
"0.43": (125, 291),
"0.48": (133, 277),
"0.50": (135, 270),
"0.53": (139, 262),
"0.54": (141, 260),
"0.56": (144, 256), # base
"0.62": (152, 244),
"0.62": (151, 241),
"0.67": (156, 234),
"0.75": (166, 222),
"0.75": (166, 221),
"1.00": (192, 192),
"1.33": (222, 166),
"1.50": (236, 158),
"1.33": (221, 165),
"1.50": (235, 156),
"1.78": (256, 144),
"1.89": (264, 140),
"2.00": (272, 136),
"2.08": (278, 134),
"1.89": (263, 139),
"2.00": (271, 135),
"2.08": (277, 132),
}
# from PixArt

View file

@ -211,13 +211,13 @@ def load(
booster.load_model(model, os.path.join(load_dir, "model"))
# ema is not boosted, so we don't use booster.load_model
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"), strict=False))
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(os.path.join(load_dir, "running_states.json"))
if sampler is not None:
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")), strict=False)
dist.barrier()
return (
running_states["epoch"],

View file

@ -229,6 +229,7 @@ def main():
) as pbar:
for step, batch in pbar:
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
print(x.shape)
y = batch.pop("text")
# Visual and text encoding
with torch.no_grad():