# Open-Sora/opensora/utils/ckpt_utils.py

import functools
import json
import operator
import os
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.checkpoint_io import GeneralCheckpointIO
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.utils import download_url

from .misc import get_logger

hf_endpoint = os.environ.get("HF_ENDPOINT")
if hf_endpoint is None:
    hf_endpoint = "https://huggingface.co"

pretrained_models = {
"DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
"DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
"Latte-XL-2-256x256-ucf101.pt": hf_endpoint + "/maxin-cn/Latte/resolve/main/ucf101.pt",
"PixArt-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth",
"PixArt-XL-2-SAM-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth",
"PixArt-XL-2-512x512.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth",
"PixArt-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth",
"OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
"OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
"OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
2024-04-27 13:43:15 +02:00
"PixArt-Sigma-XL-2-256x256.pth": hf_endpoint
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
"PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
"PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
2024-03-15 15:06:36 +01:00
}
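# NOTE: the Hugging Face entries above are built from HF_ENDPOINT, so the
# registry can be pointed at a mirror by exporting the variable before this
# module is imported (mirror URL and command below are illustrative only):
#
#   HF_ENDPOINT=https://hf-mirror.com python train.py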


def reparameter(ckpt, name=None, model=None):
    """Adapt a released checkpoint's state dict to the current model's layout."""
    model_name = name
    name = os.path.basename(name)
    if not dist.is_initialized() or dist.get_rank() == 0:
        get_logger().info("loading pretrained model: %s", model_name)
    if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]:
        # inflate the 2D patch-embedding conv weight to 3D (insert a temporal dim)
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
    if name in ["Latte-XL-2-256x256-ucf101.pt"]:
        ckpt = ckpt["ema"]
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
        del ckpt["temp_embed"]
    if name in [
        "PixArt-XL-2-256x256.pth",
        "PixArt-XL-2-SAM-256x256.pth",
        "PixArt-XL-2-512x512.pth",
        "PixArt-XL-2-1024-MS.pth",
        "PixArt-Sigma-XL-2-256x256.pth",
        "PixArt-Sigma-XL-2-512-MS.pth",
        "PixArt-Sigma-XL-2-1024-MS.pth",
        "PixArt-Sigma-XL-2-2K-MS.pth",
    ]:
        ckpt = ckpt["state_dict"]
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        if "pos_embed" in ckpt:
            del ckpt["pos_embed"]
    if name in [
        "PixArt-1B-2.pth",
    ]:
        ckpt = ckpt["state_dict"]
        if "pos_embed" in ckpt:
            del ckpt["pos_embed"]
    # position embeddings are rebuilt by the model, so drop any cached ones
    if "pos_embed_temporal" in ckpt:
        del ckpt["pos_embed_temporal"]
    if "pos_embed" in ckpt:
        del ckpt["pos_embed"]
    # different text length: pad or truncate the null-text embedding to match the model
    if "y_embedder.y_embedding" in ckpt:
        if ckpt["y_embedder.y_embedding"].shape[0] < model.y_embedder.y_embedding.shape[0]:
            get_logger().info(
                "Extend y_embedding from %s to %s",
                ckpt["y_embedder.y_embedding"].shape[0],
                model.y_embedder.y_embedding.shape[0],
            )
            additional_length = model.y_embedder.y_embedding.shape[0] - ckpt["y_embedder.y_embedding"].shape[0]
            # repeat the last embedding row to fill the extra slots
            new_y_embedding = torch.zeros(additional_length, model.y_embedder.y_embedding.shape[1])
            new_y_embedding[:] = ckpt["y_embedder.y_embedding"][-1]
            ckpt["y_embedder.y_embedding"] = torch.cat([ckpt["y_embedder.y_embedding"], new_y_embedding], dim=0)
        elif ckpt["y_embedder.y_embedding"].shape[0] > model.y_embedder.y_embedding.shape[0]:
            get_logger().info(
                "Shrink y_embedding from %s to %s",
                ckpt["y_embedder.y_embedding"].shape[0],
                model.y_embedder.y_embedding.shape[0],
            )
            ckpt["y_embedder.y_embedding"] = ckpt["y_embedder.y_embedding"][: model.y_embedder.y_embedding.shape[0]]
    # STDiT3 special case: PixArt-Sigma transformer blocks initialize the spatial blocks
    if type(model).__name__ == "STDiT3" and "PixArt-Sigma" in name:
        ckpt_keys = list(ckpt.keys())
        for key in ckpt_keys:
            if "blocks." in key:
                ckpt[key.replace("blocks.", "spatial_blocks.")] = ckpt[key]
                del ckpt[key]

    return ckpt


def find_model(model_name, model=None):
    """
    Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
    """
    if model_name in pretrained_models:  # Find/download our pre-trained DiT checkpoints
        model_ckpt = download_model(model_name)
        model_ckpt = reparameter(model_ckpt, model_name, model=model)
    else:  # Load a custom DiT checkpoint
        assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
        model_ckpt = torch.load(model_name, map_location=lambda storage, loc: storage)
        model_ckpt = reparameter(model_ckpt, model_name, model=model)
    return model_ckpt
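
# Usage sketch (paths are illustrative): a name from the registry is downloaded
# and adapted; anything else is treated as a local checkpoint file:
#
#   ckpt = find_model("OpenSora-v1-16x256x256.pth", model=model)
#   ckpt = find_model("outputs/custom_ckpt.pth", model=model)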


def download_model(model_name=None, local_path=None, url=None):
    """
    Downloads a pre-trained DiT model from the web.
    """
    if model_name is not None:
        assert model_name in pretrained_models
        local_path = f"pretrained_models/{model_name}"
        web_path = pretrained_models[model_name]
    else:
        assert local_path is not None
        assert url is not None
        web_path = url
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        dir_name = os.path.dirname(local_path)
        file_name = os.path.basename(local_path)
        download_url(web_path, dir_name, file_name)
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=False):
    # load a ColossalAI sharded checkpoint directory into the model
    ckpt_io = GeneralCheckpointIO()
    ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict)


def model_sharding(model: torch.nn.Module):
    # flatten each parameter, pad it so it divides evenly by world_size,
    # then keep only this rank's slice (ZeRO-style flat sharding)
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    for _, param in model.named_parameters():
        padding_size = (world_size - param.numel() % world_size) % world_size
        if padding_size > 0:
            padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
        else:
            padding_param = param.data.view(-1)
        split_params = padding_param.split(padding_param.numel() // world_size)
        param.data = split_params[global_rank]


def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
    # inverse of model_sharding: all-gather every rank's slice, then on rank 0
    # strip the padding and restore the original shape
    global_rank = dist.get_rank()
    global_size = dist.get_world_size()
    for name, param in model.named_parameters():
        all_params = [torch.empty_like(param.data) for _ in range(global_size)]
        dist.all_gather(all_params, param.data, group=dist.group.WORLD)
        if int(global_rank) == 0:
            all_params = torch.cat(all_params)
            param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
    dist.barrier()


def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
    # keep only the first prod(original_shape) elements of the flat tensor
    return tensor[: functools.reduce(operator.mul, original_shape)]


def record_model_param_shape(model: torch.nn.Module) -> dict:
    param_shape = {}
    for name, param in model.named_parameters():
        param_shape[name] = param.shape
    return param_shape
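
# The four helpers above implement a simple flat parameter sharding scheme
# (used, e.g., to spread an EMA copy of the model across ranks). A minimal
# usage sketch, assuming torch.distributed is initialized and every rank
# makes these calls collectively:
#
#   shapes = record_model_param_shape(ema)  # record full shapes before sharding
#   model_sharding(ema)                     # each rank keeps a 1/world_size slice
#   ...                                     # train, updating the EMA shards
#   model_gathering(ema, shapes)            # rank 0 reassembles the full tensors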


def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", strict=False):
    if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
        state_dict = find_model(ckpt_path, model=model)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
        get_logger().info("Missing keys: %s", missing_keys)
        get_logger().info("Unexpected keys: %s", unexpected_keys)
    elif os.path.isdir(ckpt_path):
        load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict)
        get_logger().info("Model checkpoint loaded from %s", ckpt_path)
        if save_as_pt:
            save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
            torch.save(model.state_dict(), save_path)
            get_logger().info("Model checkpoint saved to %s", save_path)
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
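
# Usage sketch (paths are illustrative): load_checkpoint accepts either a
# single .pt/.pth file or a sharded-checkpoint directory written by save():
#
#   load_checkpoint(model, "pretrained_models/OpenSora-v1-16x256x256.pth")
#   load_checkpoint(model, "outputs/epoch0-global_step1000", model_name="ema")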


def load_json(file_path: str):
    with open(file_path, "r") as f:
        return json.load(f)


def save_json(data, file_path: str):
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)


# save and load for training
def save(
    booster: Booster,
    save_dir: str,
    model: nn.Module = None,
    ema: nn.Module = None,
    optimizer: Optimizer = None,
    lr_scheduler: _LRScheduler = None,
    sampler=None,
    epoch: int = None,
    step: int = None,
    global_step: int = None,
    batch_size: int = None,
):
    save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
    if model is not None:
        booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
    if optimizer is not None:
        booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
    if lr_scheduler is not None:
        booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    if dist.get_rank() == 0:
        running_states = {
            "epoch": epoch,
            "step": step,
            "global_step": global_step,
            "batch_size": batch_size,
        }
        save_json(running_states, os.path.join(save_dir, "running_states.json"))
        if ema is not None:
            torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
        if sampler is not None:
            # only for VariableVideoBatchSampler
            torch.save(sampler.state_dict(step), os.path.join(save_dir, "sampler"))
    dist.barrier()
    return save_dir


def load(
    booster: Booster,
    load_dir: str,
    model: nn.Module = None,
    ema: nn.Module = None,
    optimizer: Optimizer = None,
    lr_scheduler: _LRScheduler = None,
    sampler=None,
) -> Tuple[int, int]:
    assert os.path.exists(load_dir), f"Checkpoint directory {load_dir} does not exist"
    assert os.path.exists(os.path.join(load_dir, "running_states.json")), "running_states.json does not exist"
    running_states = load_json(os.path.join(load_dir, "running_states.json"))
    if model is not None:
        booster.load_model(model, os.path.join(load_dir, "model"))
    if ema is not None:
        # ema is not boosted, so we don't use booster.load_model
        ema.load_state_dict(
            torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")),
            strict=False,
        )
    if optimizer is not None:
        booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
    if lr_scheduler is not None:
        booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
    if sampler is not None:
        sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
    dist.barrier()
    return (
        running_states["epoch"],
        running_states["step"],
    )
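

# Resume sketch (argument values are illustrative): save() writes the sharded
# model and optimizer state plus running_states.json; load() restores them and
# returns the (epoch, step) pair to resume from:
#
#   ckpt_dir = save(booster, "outputs", model=model, optimizer=optimizer,
#                   lr_scheduler=lr_scheduler, epoch=epoch, step=step,
#                   global_step=global_step, batch_size=batch_size)
#   start_epoch, start_step = load(booster, ckpt_dir, model=model,
#                                  optimizer=optimizer, lr_scheduler=lr_scheduler)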