Merge branch 'main' into dev/v1.2

This commit is contained in:
zhengzangw 2024-04-27 11:40:56 +00:00
commit f551321e49
10 changed files with 682 additions and 11 deletions

View file

@ -1,3 +1,4 @@
A small cactus with a happy face in the Sahara desert.
A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.
A majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.
A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene. In the foreground, a few cars can be seen driving along a winding road that cuts through the mountains. The cars are small compared to the vastness of the landscape, emphasizing the grandeur of the surroundings. The overall style of the video is a mix of adventure and tranquility, with the hot air balloons adding a touch of whimsy to the otherwise serene mountain landscape. The video is likely shot during the day, as the lighting is bright and even, casting soft shadows on the snow-covered mountains.

View file

@ -0,0 +1,62 @@
# Inference config: PixArt-Sigma 2K text-to-image sampling (single frame => image generation).
num_frames = 1  # one frame only: this config generates still images, not video
fps = 1
image_size = (2048, 2048)  # 2K output resolution
multi_resolution = "STDiT2"
# Define model
# model = dict(
#     type="STDiT2-XL/2",
#     from_pretrained="/home/zhouyukun/data/models/PixArt-Sigma/PixArt-Sigma-XL-2-256x256.pth",
#     input_sq_size=512,
#     qk_norm=True,
#     enable_flashattn=True,
#     enable_layernorm_kernel=True,
# )
model = dict(
    type="PixArt-Sigma-XL/2",
    space_scale=4,
    no_temporal_pos_emb=True,  # single frame, so temporal positional embedding is unnecessary
    # NOTE(review): this loads a multi-scale (MS) checkpoint into the non-MS
    # "PixArt-Sigma-XL/2" model — confirm "PixArt-SigmaMS-XL/2" was not intended.
    from_pretrained="PixArt-Sigma-XL-2-2K-MS.pth",
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    # Load AutoencoderKL weights from the "vae" subfolder of the HF repo.
    subfolder="vae"
)
text_encoder = dict(
    type="t5",
    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    model_max_length=300,
    cache_dir=None,
    # Truthy flag: T5Embedder then reads from the "tokenizer"/"text_encoder"
    # subfolders of the repo instead of its root.
    subfolder=True
)
scheduler = dict(
    type="iddpm",
    num_sampling_steps=250,
    cfg_scale=7,
    cfg_channel=3, # or None
)
# scheduler = dict(
#     type="dpm-solver",
#     num_sampling_steps=50,
#     cfg_scale=4.0,
# )
dtype = "bf16"
# Condition
prompt_path = "./assets/texts/t2v_samples.txt"
prompt = None  # prompt has higher priority than prompt_path
# Others
batch_size = 1
seed = 42
save_dir = "./samples/samples/"

View file

@ -218,6 +218,152 @@ class Attention(nn.Module):
return x
class KVCompressAttention(nn.Module):
    """Multi-head self-attention with optional key/value token compression.

    When ``sr_ratio > 1`` the key/value token grid is spatially downsampled
    (see :meth:`downsample_2d`) before attention to cut the quadratic cost.
    Three attention backends are supported: flash-attn, xformers
    memory-efficient attention, and a vanilla matmul+softmax path.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = LlamaRMSNorm,
        enable_flashattn: bool = False,
        rope=None,
        sampling="conv",
        sr_ratio=1,
        mem_eff_attention=False,
        attn_half=False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn = enable_flashattn
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.sr_ratio = sr_ratio
        self.sampling = sampling
        if sr_ratio > 1 and sampling == "conv":
            # Avg Conv Init.
            # Depthwise conv initialized to behave like average pooling:
            # all weights 1/sr^2, zero bias. self.sr/self.norm only exist for "conv".
            self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.sr.weight.data.fill_(1 / sr_ratio**2)
            self.sr.bias.data.zero_()
            self.norm = nn.LayerNorm(dim)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = False
        if rope is not None:
            self.rope = True
            self.rotary_emb = rope
        self.mem_eff_attention = mem_eff_attention
        # attn_half=True keeps attention logits in the input dtype instead of
        # upcasting to fp32 (only affects the vanilla attention path below).
        self.attn_half = attn_half

    def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
        """Downsample a (B, N, C) token sequence interpreted as an HxW grid.

        Returns ``(downsampled_tensor, new_N)``.
        NOTE(review): the early-exit branch returns only ``tensor`` (no new_N),
        yet callers always unpack two values — confirm this branch is unreachable
        from forward() (callers guard on sr_ratio > 1, but sampling may be None).
        """
        if sampling is None or scale_factor == 1:
            return tensor
        B, N, C = tensor.shape
        if sampling == "uniform_every":
            # Keep every scale_factor-th token along the sequence dim.
            return tensor[:, ::scale_factor], int(N // scale_factor)
        tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
        new_H, new_W = int(H / scale_factor), int(W / scale_factor)
        new_N = new_H * new_W
        if sampling == "ave":
            tensor = F.interpolate(tensor, scale_factor=1 / scale_factor, mode="nearest").permute(0, 2, 3, 1)
        elif sampling == "uniform":
            tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
        elif sampling == "conv":
            # Learned (avg-initialized) strided depthwise conv + LayerNorm.
            tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
            tensor = self.norm(tensor)
        else:
            raise ValueError
        return tensor.reshape(B, new_N, C).contiguous(), new_N

    def forward(self, x: torch.Tensor, mask=None, HW=None, block_id=None, **kwargs) -> torch.Tensor:
        """Attend over x of shape (B, N, C); HW optionally gives the token grid size."""
        B, N, C = x.shape
        new_N = N
        if HW is None:
            # Assume a square token grid when the spatial layout is not given.
            H = W = int(N**0.5)
        else:
            H, W = HW
        # flash attn is not memory efficient for small sequences, this is empirical
        enable_flashattn = self.enable_flashattn and (N > B)
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
        # q, k, v are each (B, num_heads, N, head_dim) after the permute above.
        q, k, v = qkv.unbind(0)
        # WARNING: this may be a bug
        # NOTE(review): rotary embedding is applied before q/k normalization;
        # confirm the intended ordering against the reference implementation.
        if self.rope:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)
        q, k = self.q_norm(q), self.k_norm(k)
        if self.sr_ratio > 1:
            # NOTE(review): downsample_2d unpacks (B, N, C), but k/v here are
            # 4-D (B, num_heads, N, head_dim) — this path looks broken for
            # sr_ratio > 1; verify against the upstream PixArt-Sigma tensor layout.
            k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
            v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
        # NOTE(review): q/k/v are (B, num_heads, N, head_dim) at this point, so
        # these reshapes (not permutes) reinterpret the head/token axes — verify.
        q = q.reshape(B, N, self.num_heads, C // self.num_heads)
        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
        if enable_flashattn:
            from flash_attn import flash_attn_func

            # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            x = flash_attn_func(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                softmax_scale=self.scale,
            )
        elif self.mem_eff_attention:
            attn_bias = None
            if mask is not None:
                # Additive bias: -inf where the mask is 0 so those keys get zero weight.
                attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
                attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf"))
            x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
        else:
            dtype = q.dtype
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)  # translate attn to float32
            if not self.attn_half:
                attn = attn.to(torch.float32)
            attn = attn.softmax(dim=-1)
            attn = attn.to(dtype)  # cast back attn to original dtype
            attn = self.attn_drop(attn)
            x = attn @ v
        x_output_shape = (B, N, C)
        if not enable_flashattn:
            x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class SeqParallelAttention(Attention):
def __init__(
self,

View file

@ -1 +1,2 @@
from .pixart import PixArt, PixArt_XL_2
from .pixart_sigma import PixArt_SigmaMS, PixArt_Sigma_XL_2

View file

@ -64,6 +64,9 @@ class PixArtBlock(nn.Module):
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
qk_norm=False,
sampling="conv",
sr_ratio=1
):
super().__init__()
self.hidden_size = hidden_size
@ -83,6 +86,9 @@ class PixArtBlock(nn.Module):
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
qk_norm=qk_norm,
sr_ratio=sr_ratio,
sampling=sampling,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
@ -91,6 +97,8 @@ class PixArtBlock(nn.Module):
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
self.sampling = sampling
self.sr_ratio = sr_ratio
def forward(self, x, y, t, mask=None):
B, N, C = x.shape
@ -128,11 +136,13 @@ class PixArt(nn.Module):
model_max_length=120,
dtype=torch.float32,
freeze=None,
qk_norm=False,
space_scale=1.0,
time_scale=1.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
kv_compress_config=None,
):
super().__init__()
assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version."
@ -172,6 +182,15 @@ class PixArt(nn.Module):
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.kv_compress_config = kv_compress_config
if kv_compress_config is None:
self.kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList(
[
PixArtBlock(
@ -181,6 +200,10 @@ class PixArt(nn.Module):
drop_path=drop_path[i],
enable_flashattn=enable_flashattn,
enable_layernorm_kernel=enable_layernorm_kernel,
qk_norm=qk_norm,
sr_ratio=int(
self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1,
sampling=self.kv_compress_config['sampling'],
)
for i in range(depth)
]
@ -193,7 +216,7 @@ class PixArt(nn.Module):
if freeze == "text":
self.freeze_text()
def forward(self, x, timestep, y, mask=None):
def forward(self, x, timestep, y, mask=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
@ -318,19 +341,24 @@ class PixArtMS(PixArt):
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
def forward(self, x, timestep, y, mask=None, data_info=None):
def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
B = x.shape[0]
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
c_size = data_info["hw"]
ar = data_info["ar"]
hw = torch.cat([height[:, None], width[:, None]], dim=1)
# 2. get aspect ratio
ar = ar.unsqueeze(1)
c_size = hw
ar = ar
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
# embedding

View file

@ -0,0 +1,421 @@
# Adapted from PixArt
#
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
# from .builder import MODELS
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
SizeEmbedder,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
KVCompressAttention
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
class PixArtBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm (adaLN-single) conditioning.

    Applies, in order: self-attention (KV-compressed unless sequence
    parallelism is requested), cross-attention to the caption tokens, and an
    MLP — each modulated/gated by six chunks derived from the timestep signal.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        drop_path=0.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        qk_norm=False,
        sampling="conv",
        sr_ratio=1
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism
        # Pick the attention implementations once, based on the parallelism flag.
        self.attn_cls = SeqParallelAttention if enable_sequence_parallelism else KVCompressAttention
        self.mha_cls = SeqParallelMultiHeadCrossAttention if enable_sequence_parallelism else MultiHeadCrossAttention
        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=enable_flashattn,
            qk_norm=qk_norm,
            sr_ratio=sr_ratio,
            sampling=sampling,
            attn_half=True,
        )
        self.cross_attn = self.mha_cls(hidden_size, num_heads)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
        self.sampling = sampling
        self.sr_ratio = sr_ratio

    def forward(self, x, y, t, mask=None):
        """One block step; x: (B, N, C) tokens, y: caption tokens, t: (B, 6*C) modulation signal."""
        B, N, C = x.shape
        # Six per-sample modulation tensors: shift/scale/gate for attention and MLP.
        modulation = self.scale_shift_table[None] + t.reshape(B, 6, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)
        attn_out = self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)
        x = x + self.drop_path(gate_msa * attn_out)
        x = x + self.cross_attn(x, y, mask)
        mlp_out = self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + self.drop_path(gate_mlp * mlp_out)
        return x
@MODELS.register_module()
class PixArt_Sigma(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    PixArt-Sigma variant: adds optional per-head QK normalization and
    key/value compression on selected blocks via ``kv_compress_config``,
    a dict with keys ``'sampling'``, ``'scale_factor'`` and
    ``'kv_compress_layer'`` (indices of blocks that compress K/V).

    Args:
        input_size: (T, H, W) latent size used to size positional embeddings.
        in_channels: input latent channels.
        patch_size: (pt, ph, pw) patchification factors.
        hidden_size: transformer width.
        depth: number of PixArtBlock layers.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden-size multiplier.
        class_dropout_prob: caption-dropout probability (for CFG training).
        pred_sigma: if True, output has 2 * in_channels channels (eps + sigma).
        drop_path: max stochastic-depth rate, linearly spaced over depth.
        no_temporal_pos_emb: skip the temporal positional embedding.
        caption_channels: channel dim of text-encoder features.
        model_max_length: max caption token length.
        dtype: compute dtype inputs are cast to in forward().
        freeze: if "text", freeze all cross-attention parameters.
        qk_norm: enable QK normalization in self-attention.
        space_scale / time_scale: scales for the sincos positional embeddings.
        kv_compress_config: see above; None disables compression.
    """
    def __init__(
        self,
        input_size=(1, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path: float = 0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        dtype=torch.float32,
        freeze=None,
        qk_norm=False,
        space_scale=1.0,
        time_scale=1.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        kv_compress_config=None,
    ):
        super().__init__()
        assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version."
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        # Total patches over (T, H, W); split into temporal x spatial below.
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        # base_size assumes a square spatial grid (used by the 2D sincos embedding).
        self.base_size = int(np.sqrt(self.num_spatial))
        self.num_heads = num_heads
        self.dtype = dtype
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.space_scale = space_scale
        self.time_scale = time_scale
        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        # t_block expands the timestep embedding into 6 modulation signals per block.
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
            uncond_prob=class_dropout_prob,
            act_layer=approx_gelu,
            token_num=model_max_length,
        )
        # Fixed (non-trainable) sincos positional embeddings, stored as buffers.
        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        self.kv_compress_config = kv_compress_config
        if kv_compress_config is None:
            # Default: no KV compression on any layer.
            self.kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList(
            [
                PixArtBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    drop_path=drop_path[i],
                    enable_flashattn=enable_flashattn,
                    enable_layernorm_kernel=enable_layernorm_kernel,
                    qk_norm=qk_norm,
                    # Only the configured layers compress K/V; others use sr_ratio=1.
                    sr_ratio=int(
                        self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1,
                    sampling=self.kv_compress_config['sampling'],
                )
                for i in range(depth)
            ]
        )
        self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
        self.initialize_weights()
        if freeze is not None:
            assert freeze in ["text"]
            if freeze == "text":
                self.freeze_text()

    def forward(self, x, timestep, y, mask=None, **kwargs):
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        mask: optional caption token mask; extra kwargs are accepted and ignored.
        Returns a float32 tensor of shape (N, out_channels, H, W).
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        # embedding
        x = self.x_embedder(x)  # (B, N, D)
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + self.pos_embed
        if not self.no_temporal_pos_emb:
            x = rearrange(x, "b t s d -> b s t d")
            x = x + self.pos_embed_temporal
            x = rearrange(x, "b s t d -> b (t s) d")
        else:
            x = rearrange(x, "b t s d -> b (t s) d")
        t = self.t_embedder(timestep, dtype=x.dtype)  # (N, D)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Pack the valid caption tokens of the whole batch into one sequence;
            # y_lens records per-sample lengths for the cross-attention.
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        # blocks
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens)
        # final process
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def unpatchify(self, x):
        """Fold (N, T, prod(patch_size)*C_out) patch tokens back into (N, C_out, T', H', W')."""
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size
        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        """Build the 2D sincos spatial positional embedding as a (1, S, D) float tensor."""
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
            base_size=self.base_size,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        """Build the 1D sincos temporal positional embedding as a (1, T, D) float tensor."""
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
            scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_text(self):
        """Freeze every parameter belonging to a cross-attention module."""
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_weights(self):
        """Initialize weights following the DiT/PixArt recipe (zero-init gates and output)."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)
        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module()
class PixArt_SigmaMS(PixArt_Sigma):
    """Multi-scale (MS) variant of PixArt_Sigma.

    Differences from the base class:
    - embeds the original image size (height/width) and aspect ratio and adds
      them to the timestep conditioning;
    - rebuilds the spatial positional embedding for each input's actual
      resolution instead of using the fixed registered buffer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Size and aspect-ratio embedders each contribute hidden_size // 3 * 2
        # features; their concatenation must match hidden_size (see forward).
        assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3"
        self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
        self.ar_embedder = SizeEmbedder(self.hidden_size // 3)

    def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs):
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        height, width: (N,) tensors with original image sizes (required)
        ar: (N,) tensor of aspect ratios (required)
        Returns a float32 tensor of shape (N, out_channels, H, W).
        """
        B = x.shape[0]
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        # Size/aspect-ratio conditioning (was: dead aliases c_size = hw; ar = ar).
        hw = torch.cat([height[:, None], width[:, None]], dim=1)  # (N, 2)
        ar = ar.unsqueeze(1)  # (N, 1)
        # Rebuild the spatial positional embedding for this input's resolution.
        pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
        # embedding
        x = self.x_embedder(x)  # (B, N, D)
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + pos_embed.to(x.device)
        if not self.no_temporal_pos_emb:
            x = rearrange(x, "b t s d -> b s t d")
            x = x + self.pos_embed_temporal
            x = rearrange(x, "b s t d -> b (t s) d")
        else:
            x = rearrange(x, "b t s d -> b (t s) d")
        t = self.t_embedder(timestep, dtype=x.dtype)  # (N, D)
        # Add size/aspect-ratio embeddings to the timestep conditioning.
        csize = self.csize_embedder(hw, B)
        ar = self.ar_embedder(ar, B)
        t = t + torch.cat([csize, ar], dim=1)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Pack valid caption tokens into one sequence; y_lens keeps per-sample lengths.
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        # blocks (checkpointed, consistent with PixArt_Sigma.forward)
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens)
        # final process
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x
@MODELS.register_module("PixArt-Sigma-XL/2")
def PixArt_Sigma_XL_2(from_pretrained=None, **kwargs):
model = PixArt_Sigma(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
@MODELS.register_module("PixArt-SigmaMS-XL/2")
def PixArtMS_XL_2(from_pretrained=None, **kwargs):
model = PixArt_SigmaMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model

View file

@ -32,7 +32,7 @@ from opensora.registry import MODELS
class T5Embedder:
available_models = ["DeepFloyd/t5-v1_1-xxl"]
available_models = ["DeepFloyd/t5-v1_1-xxl", "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"]
def __init__(
self,
@ -47,6 +47,7 @@ class T5Embedder:
use_offload_folder=None,
model_max_length=120,
local_files_only=False,
subfolder=None
):
self.device = torch.device(device)
self.torch_dtype = torch_dtype or torch.bfloat16
@ -103,12 +104,14 @@ class T5Embedder:
self.tokenizer = AutoTokenizer.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
subfolder="tokenizer" if subfolder else None,
local_files_only=local_files_only,
)
self.model = T5EncoderModel.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only,
subfolder="text_encoder" if subfolder else None,
**t5_model_kwargs,
).eval()
self.model_max_length = model_max_length
@ -145,6 +148,7 @@ class T5Encoder:
cache_dir=None,
shardformer=False,
local_files_only=False,
subfolder=None,
):
assert from_pretrained is not None, "Please specify the path to the T5 model"
@ -155,6 +159,7 @@ class T5Encoder:
cache_dir=cache_dir,
model_max_length=model_max_length,
local_files_only=local_files_only,
subfolder=subfolder,
)
self.t5.model.to(dtype=dtype)
self.y_embedder = None

View file

@ -8,10 +8,11 @@ from opensora.registry import MODELS
@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False):
def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None):
super().__init__()
self.module = AutoencoderKL.from_pretrained(
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only,
subfolder=subfolder,
)
self.out_channels = self.module.config.latent_channels
self.patch_size = (1, 8, 8)

View file

@ -70,7 +70,6 @@ class IDDPM(SpacedDiffusion):
model_args["y"] = torch.cat([model_args["y"], y_null], 0)
if additional_args is not None:
model_args.update(additional_args)
forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale, cfg_channel=self.cfg_channel)
samples = self.p_sample_loop(
forward,

View file

@ -32,10 +32,17 @@ pretrained_models = {
"OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
"OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
"OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
"PixArt-Sigma-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
"PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
"PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
}
def reparameter(ckpt, name=None, model=None):
name = os.path.basename(name)
if not dist.is_initialized() or dist.get_rank() == 0:
print("loading pretrained model:", name)
if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]:
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
@ -44,11 +51,11 @@ def reparameter(ckpt, name=None, model=None):
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
del ckpt["temp_embed"]
if name in ["PixArt-XL-2-256x256.pth", "PixArt-XL-2-SAM-256x256.pth", "PixArt-XL-2-512x512.pth"]:
if name in ["PixArt-XL-2-256x256.pth", "PixArt-XL-2-SAM-256x256.pth", "PixArt-XL-2-512x512.pth", "PixArt-Sigma-XL-2-256x256.pth", "PixArt-Sigma-XL-2-512-MS.pth", "PixArt-Sigma-XL-2-1024-MS.pth", "PixArt-Sigma-XL-2-2K-MS.pth"]:
ckpt = ckpt["state_dict"]
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
if "pos_embed" in ckpt:
del ckpt["pos_embed"]
# no need pos_embed
if "pos_embed_temporal" in ckpt:
del ckpt["pos_embed_temporal"]