mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
Merge branch 'main' into dev/v1.2
This commit is contained in:
commit
f551321e49
|
|
@ -1,3 +1,4 @@
|
|||
A small cactus with a happy face in the Sahara desert.
|
||||
A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.
|
||||
A majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.
|
||||
A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene. In the foreground, a few cars can be seen driving along a winding road that cuts through the mountains. The cars are small compared to the vastness of the landscape, emphasizing the grandeur of the surroundings. The overall style of the video is a mix of adventure and tranquility, with the hot air balloons adding a touch of whimsy to the otherwise serene mountain landscape. The video is likely shot during the day, as the lighting is bright and even, casting soft shadows on the snow-covered mountains.
|
||||
|
|
|
|||
62
configs/opensora-v1-2/inference/1x2048x2048.py
Normal file
62
configs/opensora-v1-2/inference/1x2048x2048.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
num_frames = 1
|
||||
fps = 1
|
||||
image_size = (2048, 2048)
|
||||
multi_resolution = "STDiT2"
|
||||
|
||||
|
||||
# Define model
|
||||
# model = dict(
|
||||
# type="STDiT2-XL/2",
|
||||
# from_pretrained="/home/zhouyukun/data/models/PixArt-Sigma/PixArt-Sigma-XL-2-256x256.pth",
|
||||
# input_sq_size=512,
|
||||
# qk_norm=True,
|
||||
# enable_flashattn=True,
|
||||
# enable_layernorm_kernel=True,
|
||||
# )
|
||||
|
||||
model = dict(
|
||||
type="PixArt-Sigma-XL/2",
|
||||
space_scale=4,
|
||||
no_temporal_pos_emb=True,
|
||||
from_pretrained="PixArt-Sigma-XL-2-2K-MS.pth",
|
||||
)
|
||||
|
||||
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae"
|
||||
)
|
||||
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
model_max_length=300,
|
||||
cache_dir=None,
|
||||
subfolder=True
|
||||
)
|
||||
|
||||
|
||||
scheduler = dict(
|
||||
type="iddpm",
|
||||
num_sampling_steps=250,
|
||||
cfg_scale=7,
|
||||
cfg_channel=3, # or None
|
||||
)
|
||||
|
||||
# scheduler = dict(
|
||||
# type="dpm-solver",
|
||||
# num_sampling_steps=50,
|
||||
# cfg_scale=4.0,
|
||||
# )
|
||||
|
||||
dtype = "bf16"
|
||||
|
||||
# Condition
|
||||
prompt_path = "./assets/texts/t2v_samples.txt"
|
||||
prompt = None # prompt has higher priority than prompt_path
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
save_dir = "./samples/samples/"
|
||||
|
|
@ -218,6 +218,152 @@ class Attention(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class KVCompressAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = LlamaRMSNorm,
|
||||
enable_flashattn: bool = False,
|
||||
rope=None,
|
||||
sampling="conv",
|
||||
sr_ratio=1,
|
||||
mem_eff_attention=False,
|
||||
attn_half=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.enable_flashattn = enable_flashattn
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
|
||||
self.sr_ratio = sr_ratio
|
||||
self.sampling = sampling
|
||||
if sr_ratio > 1 and sampling == "conv":
|
||||
# Avg Conv Init.
|
||||
self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio)
|
||||
self.sr.weight.data.fill_(1 / sr_ratio**2)
|
||||
self.sr.bias.data.zero_()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.rope = False
|
||||
if rope is not None:
|
||||
self.rope = True
|
||||
self.rotary_emb = rope
|
||||
|
||||
self.mem_eff_attention = mem_eff_attention
|
||||
self.attn_half = attn_half
|
||||
|
||||
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
|
||||
if sampling is None or scale_factor == 1:
|
||||
return tensor
|
||||
B, N, C = tensor.shape
|
||||
|
||||
if sampling == "uniform_every":
|
||||
return tensor[:, ::scale_factor], int(N // scale_factor)
|
||||
|
||||
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
|
||||
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
|
||||
new_N = new_H * new_W
|
||||
|
||||
if sampling == "ave":
|
||||
tensor = F.interpolate(tensor, scale_factor=1 / scale_factor, mode="nearest").permute(0, 2, 3, 1)
|
||||
elif sampling == "uniform":
|
||||
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
|
||||
elif sampling == "conv":
|
||||
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
|
||||
tensor = self.norm(tensor)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return tensor.reshape(B, new_N, C).contiguous(), new_N
|
||||
|
||||
def forward(self, x: torch.Tensor, mask=None, HW=None, block_id=None, **kwargs) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
|
||||
new_N = N
|
||||
if HW is None:
|
||||
H = W = int(N**0.5)
|
||||
else:
|
||||
H, W = HW
|
||||
|
||||
# flash attn is not memory efficient for small sequences, this is empirical
|
||||
enable_flashattn = self.enable_flashattn and (N > B)
|
||||
qkv = self.qkv(x)
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
|
||||
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
# WARNING: this may be a bug
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.sr_ratio > 1:
|
||||
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
|
||||
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
|
||||
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||
|
||||
if enable_flashattn:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
x = flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
|
||||
elif self.mem_eff_attention:
|
||||
attn_bias = None
|
||||
if mask is not None:
|
||||
attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
|
||||
attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf"))
|
||||
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
|
||||
|
||||
else:
|
||||
dtype = q.dtype
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1) # translate attn to float32
|
||||
if not self.attn_half:
|
||||
attn = attn.to(torch.float32)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = attn.to(dtype) # cast back attn to original dtype
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x_output_shape = (B, N, C)
|
||||
if not enable_flashattn:
|
||||
x = x.transpose(1, 2)
|
||||
x = x.reshape(x_output_shape)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class SeqParallelAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
from .pixart import PixArt, PixArt_XL_2
|
||||
from .pixart_sigma import PixArt_SigmaMS, PixArt_Sigma_XL_2
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ class PixArtBlock(nn.Module):
|
|||
enable_flashattn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
qk_norm=False,
|
||||
sampling="conv",
|
||||
sr_ratio=1
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
|
@ -83,6 +86,9 @@ class PixArtBlock(nn.Module):
|
|||
num_heads=num_heads,
|
||||
qkv_bias=True,
|
||||
enable_flashattn=enable_flashattn,
|
||||
qk_norm=qk_norm,
|
||||
sr_ratio=sr_ratio,
|
||||
sampling=sampling,
|
||||
)
|
||||
self.cross_attn = self.mha_cls(hidden_size, num_heads)
|
||||
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
|
||||
|
|
@ -91,6 +97,8 @@ class PixArtBlock(nn.Module):
|
|||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
|
||||
self.sampling = sampling
|
||||
self.sr_ratio = sr_ratio
|
||||
|
||||
def forward(self, x, y, t, mask=None):
|
||||
B, N, C = x.shape
|
||||
|
|
@ -128,11 +136,13 @@ class PixArt(nn.Module):
|
|||
model_max_length=120,
|
||||
dtype=torch.float32,
|
||||
freeze=None,
|
||||
qk_norm=False,
|
||||
space_scale=1.0,
|
||||
time_scale=1.0,
|
||||
enable_flashattn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
kv_compress_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version."
|
||||
|
|
@ -172,6 +182,15 @@ class PixArt(nn.Module):
|
|||
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
|
||||
self.kv_compress_config = kv_compress_config
|
||||
if kv_compress_config is None:
|
||||
self.kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
PixArtBlock(
|
||||
|
|
@ -181,6 +200,10 @@ class PixArt(nn.Module):
|
|||
drop_path=drop_path[i],
|
||||
enable_flashattn=enable_flashattn,
|
||||
enable_layernorm_kernel=enable_layernorm_kernel,
|
||||
qk_norm=qk_norm,
|
||||
sr_ratio=int(
|
||||
self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1,
|
||||
sampling=self.kv_compress_config['sampling'],
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
|
|
@ -193,7 +216,7 @@ class PixArt(nn.Module):
|
|||
if freeze == "text":
|
||||
self.freeze_text()
|
||||
|
||||
def forward(self, x, timestep, y, mask=None):
|
||||
def forward(self, x, timestep, y, mask=None, **kwargs):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
|
|
@ -318,19 +341,24 @@ class PixArtMS(PixArt):
|
|||
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, data_info=None):
|
||||
def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) tensor of class labels
|
||||
"""
|
||||
B = x.shape[0]
|
||||
x = x.to(self.dtype)
|
||||
timestep = timestep.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
|
||||
c_size = data_info["hw"]
|
||||
ar = data_info["ar"]
|
||||
hw = torch.cat([height[:, None], width[:, None]], dim=1)
|
||||
# 2. get aspect ratio
|
||||
ar = ar.unsqueeze(1)
|
||||
|
||||
c_size = hw
|
||||
ar = ar
|
||||
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
|
||||
|
||||
# embedding
|
||||
|
|
|
|||
421
opensora/models/pixart/pixart_sigma.py
Normal file
421
opensora/models/pixart/pixart_sigma.py
Normal file
|
|
@ -0,0 +1,421 @@
|
|||
# Adapted from PixArt
|
||||
#
|
||||
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
|
||||
# DiT: https://github.com/facebookresearch/DiT/tree/main
|
||||
# --------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from timm.models.layers import DropPath
|
||||
from timm.models.vision_transformer import Mlp
|
||||
|
||||
# from .builder import MODELS
|
||||
from opensora.acceleration.checkpoint import auto_grad_checkpoint
|
||||
from opensora.models.layers.blocks import (
|
||||
Attention,
|
||||
CaptionEmbedder,
|
||||
MultiHeadCrossAttention,
|
||||
PatchEmbed3D,
|
||||
SeqParallelAttention,
|
||||
SeqParallelMultiHeadCrossAttention,
|
||||
SizeEmbedder,
|
||||
T2IFinalLayer,
|
||||
TimestepEmbedder,
|
||||
approx_gelu,
|
||||
get_1d_sincos_pos_embed,
|
||||
get_2d_sincos_pos_embed,
|
||||
get_layernorm,
|
||||
t2i_modulate,
|
||||
KVCompressAttention
|
||||
)
|
||||
from opensora.registry import MODELS
|
||||
from opensora.utils.ckpt_utils import load_checkpoint
|
||||
|
||||
|
||||
class PixArtBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
drop_path=0.0,
|
||||
enable_flashattn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
qk_norm=False,
|
||||
sampling="conv",
|
||||
sr_ratio=1
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_flashattn = enable_flashattn
|
||||
self._enable_sequence_parallelism = enable_sequence_parallelism
|
||||
|
||||
if enable_sequence_parallelism:
|
||||
self.attn_cls = SeqParallelAttention
|
||||
self.mha_cls = SeqParallelMultiHeadCrossAttention
|
||||
else:
|
||||
self.attn_cls = KVCompressAttention
|
||||
self.mha_cls = MultiHeadCrossAttention
|
||||
|
||||
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
|
||||
self.attn = self.attn_cls(
|
||||
hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=True,
|
||||
enable_flashattn=enable_flashattn,
|
||||
qk_norm=qk_norm,
|
||||
sr_ratio=sr_ratio,
|
||||
sampling=sampling,
|
||||
attn_half=True,
|
||||
)
|
||||
self.cross_attn = self.mha_cls(hidden_size, num_heads)
|
||||
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
|
||||
self.sampling = sampling
|
||||
self.sr_ratio = sr_ratio
|
||||
|
||||
def forward(self, x, y, t, mask=None):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + t.reshape(B, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PixArt_Sigma(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=(1, 32, 32),
|
||||
in_channels=4,
|
||||
patch_size=(1, 2, 2),
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.0,
|
||||
no_temporal_pos_emb=False,
|
||||
caption_channels=4096,
|
||||
model_max_length=120,
|
||||
dtype=torch.float32,
|
||||
freeze=None,
|
||||
qk_norm=False,
|
||||
space_scale=1.0,
|
||||
time_scale=1.0,
|
||||
enable_flashattn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
kv_compress_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version."
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.patch_size = patch_size
|
||||
self.input_size = input_size
|
||||
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
|
||||
self.num_patches = num_patches
|
||||
self.num_temporal = input_size[0] // patch_size[0]
|
||||
self.num_spatial = num_patches // self.num_temporal
|
||||
self.base_size = int(np.sqrt(self.num_spatial))
|
||||
self.num_heads = num_heads
|
||||
self.dtype = dtype
|
||||
self.no_temporal_pos_emb = no_temporal_pos_emb
|
||||
self.depth = depth
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.enable_flashattn = enable_flashattn
|
||||
self.enable_layernorm_kernel = enable_layernorm_kernel
|
||||
self.space_scale = space_scale
|
||||
self.time_scale = time_scale
|
||||
|
||||
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels,
|
||||
hidden_size=hidden_size,
|
||||
uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu,
|
||||
token_num=model_max_length,
|
||||
)
|
||||
|
||||
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
|
||||
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
|
||||
self.kv_compress_config = kv_compress_config
|
||||
if kv_compress_config is None:
|
||||
self.kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
PixArtBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop_path=drop_path[i],
|
||||
enable_flashattn=enable_flashattn,
|
||||
enable_layernorm_kernel=enable_layernorm_kernel,
|
||||
qk_norm=qk_norm,
|
||||
sr_ratio=int(
|
||||
self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1,
|
||||
sampling=self.kv_compress_config['sampling'],
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
|
||||
|
||||
self.initialize_weights()
|
||||
if freeze is not None:
|
||||
assert freeze in ["text"]
|
||||
if freeze == "text":
|
||||
self.freeze_text()
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, **kwargs):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) tensor of class labels
|
||||
"""
|
||||
x = x.to(self.dtype)
|
||||
timestep = timestep.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
|
||||
# embedding
|
||||
x = self.x_embedder(x) # (B, N, D)
|
||||
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
|
||||
x = x + self.pos_embed
|
||||
if not self.no_temporal_pos_emb:
|
||||
x = rearrange(x, "b t s d -> b s t d")
|
||||
x = x + self.pos_embed_temporal
|
||||
x = rearrange(x, "b s t d -> b (t s) d")
|
||||
else:
|
||||
x = rearrange(x, "b t s d -> b (t s) d")
|
||||
|
||||
t = self.t_embedder(timestep, dtype=x.dtype) # (N, D)
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, 1, L, D)
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = [y.shape[2]] * y.shape[0]
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
|
||||
# blocks
|
||||
for block in self.blocks:
|
||||
x = auto_grad_checkpoint(block, x, y, t0, y_lens)
|
||||
|
||||
# final process
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x) # (N, out_channels, H, W)
|
||||
|
||||
# cast to float32 for better accuracy
|
||||
x = x.to(torch.float32)
|
||||
return x
|
||||
|
||||
def unpatchify(self, x):
|
||||
c = self.out_channels
|
||||
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
|
||||
pt, ph, pw = self.patch_size
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
|
||||
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
|
||||
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
||||
return imgs
|
||||
|
||||
def get_spatial_pos_embed(self, grid_size=None):
|
||||
if grid_size is None:
|
||||
grid_size = self.input_size[1:]
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
self.hidden_size,
|
||||
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
|
||||
scale=self.space_scale,
|
||||
base_size=self.base_size,
|
||||
)
|
||||
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
|
||||
return pos_embed
|
||||
|
||||
def get_temporal_pos_embed(self):
|
||||
pos_embed = get_1d_sincos_pos_embed(
|
||||
self.hidden_size,
|
||||
self.input_size[0] // self.patch_size[0],
|
||||
scale=self.time_scale,
|
||||
)
|
||||
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
|
||||
return pos_embed
|
||||
|
||||
def freeze_text(self):
|
||||
for n, p in self.named_parameters():
|
||||
if "cross_attn" in n:
|
||||
p.requires_grad = False
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
nn.init.normal_(self.t_block[1].weight, std=0.02)
|
||||
|
||||
# Initialize caption embedding MLP:
|
||||
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
|
||||
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
|
||||
|
||||
# Zero-out adaLN modulation layers in PixArt blocks:
|
||||
for block in self.blocks:
|
||||
nn.init.constant_(block.cross_attn.proj.weight, 0)
|
||||
nn.init.constant_(block.cross_attn.proj.bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PixArt_SigmaMS(PixArt_Sigma):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3"
|
||||
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) tensor of class labels
|
||||
"""
|
||||
B = x.shape[0]
|
||||
x = x.to(self.dtype)
|
||||
timestep = timestep.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
|
||||
hw = torch.cat([height[:, None], width[:, None]], dim=1)
|
||||
# 2. get aspect ratio
|
||||
ar = ar.unsqueeze(1)
|
||||
|
||||
c_size = hw
|
||||
ar = ar
|
||||
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
|
||||
|
||||
# embedding
|
||||
x = self.x_embedder(x) # (B, N, D)
|
||||
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
|
||||
x = x + pos_embed.to(x.device)
|
||||
if not self.no_temporal_pos_emb:
|
||||
x = rearrange(x, "b t s d -> b s t d")
|
||||
x = x + self.pos_embed_temporal
|
||||
x = rearrange(x, "b s t d -> b (t s) d")
|
||||
else:
|
||||
x = rearrange(x, "b t s d -> b (t s) d")
|
||||
|
||||
t = self.t_embedder(timestep, dtype=x.dtype) # (N, D)
|
||||
B = x.shape[0]
|
||||
csize = self.csize_embedder(c_size, B)
|
||||
ar = self.ar_embedder(ar, B)
|
||||
t = t + torch.cat([csize, ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, 1, L, D)
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = [y.shape[2]] * y.shape[0]
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
|
||||
# blocks
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens)
|
||||
|
||||
# final process
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x) # (N, out_channels, H, W)
|
||||
|
||||
# cast to float32 for better accuracy
|
||||
x = x.to(torch.float32)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module("PixArt-Sigma-XL/2")
|
||||
def PixArt_Sigma_XL_2(from_pretrained=None, **kwargs):
|
||||
model = PixArt_Sigma(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
@MODELS.register_module("PixArt-SigmaMS-XL/2")
|
||||
def PixArtMS_XL_2(from_pretrained=None, **kwargs):
|
||||
model = PixArt_SigmaMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
|
|
@ -32,7 +32,7 @@ from opensora.registry import MODELS
|
|||
|
||||
|
||||
class T5Embedder:
|
||||
available_models = ["DeepFloyd/t5-v1_1-xxl"]
|
||||
available_models = ["DeepFloyd/t5-v1_1-xxl", "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -47,6 +47,7 @@ class T5Embedder:
|
|||
use_offload_folder=None,
|
||||
model_max_length=120,
|
||||
local_files_only=False,
|
||||
subfolder=None
|
||||
):
|
||||
self.device = torch.device(device)
|
||||
self.torch_dtype = torch_dtype or torch.bfloat16
|
||||
|
|
@ -103,12 +104,14 @@ class T5Embedder:
|
|||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
subfolder="tokenizer" if subfolder else None,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
self.model = T5EncoderModel.from_pretrained(
|
||||
from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
subfolder="text_encoder" if subfolder else None,
|
||||
**t5_model_kwargs,
|
||||
).eval()
|
||||
self.model_max_length = model_max_length
|
||||
|
|
@ -145,6 +148,7 @@ class T5Encoder:
|
|||
cache_dir=None,
|
||||
shardformer=False,
|
||||
local_files_only=False,
|
||||
subfolder=None,
|
||||
):
|
||||
assert from_pretrained is not None, "Please specify the path to the T5 model"
|
||||
|
||||
|
|
@ -155,6 +159,7 @@ class T5Encoder:
|
|||
cache_dir=cache_dir,
|
||||
model_max_length=model_max_length,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
self.t5.model.to(dtype=dtype)
|
||||
self.y_embedder = None
|
||||
|
|
|
|||
|
|
@ -8,10 +8,11 @@ from opensora.registry import MODELS
|
|||
|
||||
@MODELS.register_module()
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False):
|
||||
def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None):
|
||||
super().__init__()
|
||||
self.module = AutoencoderKL.from_pretrained(
|
||||
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
|
||||
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
self.out_channels = self.module.config.latent_channels
|
||||
self.patch_size = (1, 8, 8)
|
||||
|
|
|
|||
|
|
@ -70,7 +70,6 @@ class IDDPM(SpacedDiffusion):
|
|||
model_args["y"] = torch.cat([model_args["y"], y_null], 0)
|
||||
if additional_args is not None:
|
||||
model_args.update(additional_args)
|
||||
|
||||
forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale, cfg_channel=self.cfg_channel)
|
||||
samples = self.p_sample_loop(
|
||||
forward,
|
||||
|
|
|
|||
|
|
@ -32,10 +32,17 @@ pretrained_models = {
|
|||
"OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
|
||||
"OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
|
||||
"OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
|
||||
"PixArt-Sigma-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
|
||||
"PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
|
||||
"PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
|
||||
"PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
|
||||
}
|
||||
|
||||
|
||||
def reparameter(ckpt, name=None, model=None):
|
||||
name = os.path.basename(name)
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print("loading pretrained model:", name)
|
||||
if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]:
|
||||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
del ckpt["pos_embed"]
|
||||
|
|
@ -44,11 +51,11 @@ def reparameter(ckpt, name=None, model=None):
|
|||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
del ckpt["pos_embed"]
|
||||
del ckpt["temp_embed"]
|
||||
if name in ["PixArt-XL-2-256x256.pth", "PixArt-XL-2-SAM-256x256.pth", "PixArt-XL-2-512x512.pth"]:
|
||||
if name in ["PixArt-XL-2-256x256.pth", "PixArt-XL-2-SAM-256x256.pth", "PixArt-XL-2-512x512.pth", "PixArt-Sigma-XL-2-256x256.pth", "PixArt-Sigma-XL-2-512-MS.pth", "PixArt-Sigma-XL-2-1024-MS.pth", "PixArt-Sigma-XL-2-2K-MS.pth"]:
|
||||
ckpt = ckpt["state_dict"]
|
||||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
del ckpt["pos_embed"]
|
||||
|
||||
if "pos_embed" in ckpt:
|
||||
del ckpt["pos_embed"]
|
||||
# no need pos_embed
|
||||
if "pos_embed_temporal" in ckpt:
|
||||
del ckpt["pos_embed_temporal"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue