[wip] multiple resolution

This commit is contained in:
zhengzangw 2024-04-27 11:43:15 +00:00
parent f551321e49
commit d6a16a858e
19 changed files with 226 additions and 242 deletions

View file

@ -0,0 +1,10 @@
Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works.
A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
Lego model, future rocket station, intricate details, high resolution, unreal engine, UHD
One giant, sharp, metal square mirror in the center of the frame, four young people in the foreground, background of a sunny palm oil plantation, tropical, realistic style, photography, nostalgic, green tone, mysterious, dreamy, bright color.
Modern contemporary luxury home interiors, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capturing the essence of nature, gray and bronze, dynamic outdoor shots.
Over-the-shoulder game perspective, game screen of Diablo 4. Inside the gorgeous palace is the wet ground; the necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light.
A curvy timber house near a sea, designed by Zaha Hadid, representing the image of cold, modern architecture, at night, white lighting, highly detailed.
Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm.
The Eiffel Tower, made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building; Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds into the building's classic architecture.
A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.

View file

@ -1,4 +1,3 @@
A small cactus with a happy face in the Sahara desert.
Soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.
The majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.
A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene. In the foreground, a few cars can be seen driving along a winding road that cuts through the mountains. The cars are small compared to the vastness of the landscape, emphasizing the grandeur of the surroundings. The overall style of the video is a mix of adventure and tranquility, with the hot air balloons adding a touch of whimsy to the otherwise serene mountain landscape. The video is likely shot during the day, as the lighting is bright and even, casting soft shadows on the snow-covered mountains.

View file

@ -33,6 +33,7 @@ model = dict(
    from_pretrained=None,
    input_sq_size=512,
    qk_norm=True,
    qk_norm_legacy=True,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)

View file

@ -10,6 +10,7 @@ model = dict(
    from_pretrained=None,
    input_sq_size=512,
    qk_norm=True,
    qk_norm_legacy=True,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)

View file

@ -65,6 +65,7 @@ model = dict(
    from_pretrained=None,
    input_sq_size=512,  # pretrained model is trained on 512x512
    qk_norm=True,
    qk_norm_legacy=True,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)

View file

@ -29,6 +29,7 @@ model = dict(
    from_pretrained=None,
    input_sq_size=512,  # pretrained model is trained on 512x512
    qk_norm=True,
    qk_norm_legacy=True,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)

View file

@ -41,6 +41,7 @@ model = dict(
    from_pretrained=None,
    input_sq_size=512,  # pretrained model is trained on 512x512
    qk_norm=True,
    qk_norm_legacy=True,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)

View file

@ -43,6 +43,7 @@ model = dict(
    from_pretrained=None,
    input_sq_size=512,  # pretrained model is trained on 512x512
    qk_norm=True,
    qk_norm_legacy=True,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)

View file

@ -43,6 +43,7 @@ model = dict(
    from_pretrained=None,
    input_sq_size=512,  # pretrained model is trained on 512x512
    qk_norm=True,
    qk_norm_legacy=True,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)

View file

@ -31,6 +31,7 @@ model = dict(
    from_pretrained=None,
    input_sq_size=512,  # pretrained model is trained on 512x512
    qk_norm=True,
    qk_norm_legacy=True,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)

View file

@ -1,62 +0,0 @@
num_frames = 1
fps = 1
image_size = (2048, 2048)
multi_resolution = "STDiT2"

# Define model
# model = dict(
#     type="STDiT2-XL/2",
#     from_pretrained="/home/zhouyukun/data/models/PixArt-Sigma/PixArt-Sigma-XL-2-256x256.pth",
#     input_sq_size=512,
#     qk_norm=True,
#     enable_flashattn=True,
#     enable_layernorm_kernel=True,
# )
model = dict(
    type="PixArt-Sigma-XL/2",
    space_scale=4,
    no_temporal_pos_emb=True,
    from_pretrained="PixArt-Sigma-XL-2-2K-MS.pth",
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    subfolder="vae",
)
text_encoder = dict(
    type="t5",
    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    model_max_length=300,
    cache_dir=None,
    subfolder=True,
)
scheduler = dict(
    type="iddpm",
    num_sampling_steps=250,
    cfg_scale=7,
    cfg_channel=3,  # or None
)
# scheduler = dict(
#     type="dpm-solver",
#     num_sampling_steps=50,
#     cfg_scale=4.0,
# )
dtype = "bf16"

# Condition
prompt_path = "./assets/texts/t2v_samples.txt"
prompt = None  # prompt has higher priority than prompt_path

# Others
batch_size = 1
seed = 42
save_dir = "./samples/samples/"

View file

@ -0,0 +1,33 @@
num_frames = 1
fps = 1
image_size = (2688, 1408)
# image_size = (2048, 2048)

model = dict(
    type="PixArt-Sigma-XL/2",
    space_scale=4,
    no_temporal_pos_emb=True,
    from_pretrained="PixArt-Sigma-XL-2-2K-MS.pth",
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    subfolder="vae",
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=300,
)
scheduler = dict(
    type="dpm-solver",
    num_sampling_steps=14,
    cfg_scale=4.5,
)
dtype = "bf16"

# Others
batch_size = 1
seed = 42
prompt_path = "./assets/texts/t2i_sigma.txt"
save_dir = "./samples/samples/"

View file

@ -13,7 +13,7 @@ def get_h_w(a, ts, eps=1e-4):
return h, w
AR = [
AR = (
3 / 8,
9 / 21,
0.48,
@ -31,8 +31,8 @@ AR = [
17 / 9,
2 / 1,
1 / 0.48,
]
ARV = [0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1.0, 1.33, 1.50, 1.78, 1.89, 2.0, 2.08]
)
ARV = (0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1, 1.33, 1.50, 1.78, 1.89, 2, 2.08)
def get_aspect_ratios_dict(ts=360 * 640, ars=AR):
@ -61,6 +61,27 @@ ASPECT_RATIO_4K = {
"2.08": (4156, 1994),
}
# S = 3686400
ASPECT_RATIO_2K = {
"0.38": (1176, 3136),
"0.43": (1256, 2930),
"0.48": (1330, 2770),
"0.50": (1358, 2716),
"0.53": (1398, 2640),
"0.54": (1412, 2612),
"0.56": (1440, 2560), # base
"0.62": (1518, 2428),
"0.67": (1568, 2352),
"0.75": (1662, 2216),
"1.00": (1920, 1920),
"1.33": (2218, 1664),
"1.50": (2352, 1568),
"1.78": (2560, 1440),
"1.89": (2638, 1396),
"2.00": (2716, 1358),
"2.08": (2772, 1330),
}
# S = 2073600
ASPECT_RATIO_1080P = {
"0.38": (882, 2352),
@ -188,6 +209,94 @@ ASPECT_RATIO_144P = {
}
# from PixArt
# S = 8294400
ASPECT_RATIO_2880 = {
"0.25": (1408, 5760),
"0.26": (1408, 5568),
"0.27": (1408, 5376),
"0.28": (1408, 5184),
"0.32": (1600, 4992),
"0.33": (1600, 4800),
"0.34": (1600, 4672),
"0.4": (1792, 4480),
"0.42": (1792, 4288),
"0.47": (1920, 4096),
"0.49": (1920, 3904),
"0.51": (1920, 3776),
"0.55": (2112, 3840),
"0.59": (2112, 3584),
"0.68": (2304, 3392),
"0.72": (2304, 3200),
"0.78": (2496, 3200),
"0.83": (2496, 3008),
"0.89": (2688, 3008),
"0.93": (2688, 2880),
"1.0": (2880, 2880),
"1.07": (2880, 2688),
"1.12": (3008, 2688),
"1.21": (3008, 2496),
"1.28": (3200, 2496),
"1.39": (3200, 2304),
"1.47": (3392, 2304),
"1.7": (3584, 2112),
"1.82": (3840, 2112),
"2.03": (3904, 1920),
"2.13": (4096, 1920),
"2.39": (4288, 1792),
"2.5": (4480, 1792),
"2.92": (4672, 1600),
"3.0": (4800, 1600),
"3.12": (4992, 1600),
"3.68": (5184, 1408),
"3.82": (5376, 1408),
"3.95": (5568, 1408),
"4.0": (5760, 1408),
}
# S = 4194304
ASPECT_RATIO_2048 = {
"0.25": (1024, 4096),
"0.26": (1024, 3968),
"0.27": (1024, 3840),
"0.28": (1024, 3712),
"0.32": (1152, 3584),
"0.33": (1152, 3456),
"0.35": (1152, 3328),
"0.4": (1280, 3200),
"0.42": (1280, 3072),
"0.48": (1408, 2944),
"0.5": (1408, 2816),
"0.52": (1408, 2688),
"0.57": (1536, 2688),
"0.6": (1536, 2560),
"0.68": (1664, 2432),
"0.72": (1664, 2304),
"0.78": (1792, 2304),
"0.82": (1792, 2176),
"0.88": (1920, 2176),
"0.94": (1920, 2048),
"1.0": (2048, 2048),
"1.07": (2048, 1920),
"1.13": (2176, 1920),
"1.21": (2176, 1792),
"1.29": (2304, 1792),
"1.38": (2304, 1664),
"1.46": (2432, 1664),
"1.67": (2560, 1536),
"1.75": (2688, 1536),
"2.0": (2816, 1408),
"2.09": (2944, 1408),
"2.4": (3072, 1280),
"2.5": (3200, 1280),
"2.89": (3328, 1152),
"3.0": (3456, 1152),
"3.11": (3584, 1152),
"3.62": (3712, 1024),
"3.75": (3840, 1024),
"3.88": (3968, 1024),
"4.0": (4096, 1024),
}
# S = 1048576
ASPECT_RATIO_1024 = {
"0.25": (512, 2048),
@ -337,5 +446,8 @@ ASPECT_RATIOS = {
"720p": (921600, ASPECT_RATIO_720P),
"1024": (1048576, ASPECT_RATIO_1024),
"1080p": (2073600, ASPECT_RATIO_1080P),
"2k": (3686400, ASPECT_RATIO_2K),
"2048": (4194304, ASPECT_RATIO_2048),
"2880": (8294400, ASPECT_RATIO_2880),
"4k": (8294400, ASPECT_RATIO_4K),
}
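The new bucket tables (`ASPECT_RATIO_2K`, `ASPECT_RATIO_2048`, `ASPECT_RATIO_2880`) all follow the same recipe as `get_h_w`: pick (h, w) so that h * w stays near a fixed pixel budget S while h / w matches each target ratio. A sketch of how such a table could be generated (the rounding multiple is illustrative; the committed tables use their own snapping):

```python
import math

def make_ar_buckets(total_pixels, ratios, multiple=2):
    """Build an {aspect-ratio-string: (height, width)} bucket table.

    For each target ratio a = h / w, solve h * w ~= total_pixels, so
    h = sqrt(total_pixels * a); then snap both sides to a multiple
    (real tables use patch-size-friendly multiples; 2 is illustrative).
    """
    buckets = {}
    for a in ratios:
        h = round(math.sqrt(total_pixels * a) / multiple) * multiple
        w = round(h / a / multiple) * multiple
        buckets[f"{a:.2f}"] = (h, w)
    return buckets

# a 2K budget (S = 3686400 = 2560 * 1440) over a few ratios
print(make_ar_buckets(3686400, [0.5625, 1.0, 1.7778]))
```

With these inputs, the 16:9 entry lands on the table's base resolution (1440, 2560) and its transpose.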

View file

@ -229,7 +229,6 @@ class KVCompressAttention(nn.Module):
proj_drop: float = 0.0,
norm_layer: nn.Module = LlamaRMSNorm,
enable_flashattn: bool = False,
rope=None,
sampling="conv",
sr_ratio=1,
mem_eff_attention=False,
@ -260,11 +259,6 @@ class KVCompressAttention(nn.Module):
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = False
if rope is not None:
self.rope = True
self.rotary_emb = rope
self.mem_eff_attention = mem_eff_attention
self.attn_half = attn_half
@ -294,41 +288,28 @@ class KVCompressAttention(nn.Module):
def forward(self, x: torch.Tensor, mask=None, HW=None, block_id=None, **kwargs) -> torch.Tensor:
B, N, C = x.shape
new_N = N
if HW is None:
H = W = int(N**0.5)
else:
H, W = HW
H, W = HW
# flash attn is not memory efficient for small sequences, this is empirical
enable_flashattn = self.enable_flashattn and (N > B)
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
# WARNING: this may be a bug
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
qkv = self.qkv(x).reshape(B, N, 3, C)
q, k, v = qkv.unbind(2)
dtype = q.dtype
# KV compression
if self.sr_ratio > 1:
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
q, k = self.q_norm(q), self.k_norm(k)
if enable_flashattn:
from flash_attn import flash_attn_func
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = flash_attn_func(
q,
k,
@ -343,8 +324,11 @@ class KVCompressAttention(nn.Module):
attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf"))
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
else:
# (B, N, #heads, #dim) -> (B, #heads, N, #dim)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
dtype = q.dtype
q = q * self.scale
attn = q @ k.transpose(-2, -1) # translate attn to float32
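The refactored forward above now requires `HW` and shrinks keys/values by `sr_ratio` before attention. A rough sketch of that compression step, using average pooling in place of the module's learned conv (`sampling="conv"`); shapes are illustrative:

```python
import torch
import torch.nn.functional as F

def downsample_kv(k, H, W, sr_ratio):
    """Spatially downsample key/value tokens (the KV-compression idea).

    k: (B, N, C) with N == H * W. Average-pool the token grid by
    sr_ratio, shrinking the attention's key/value length to
    N / sr_ratio**2; the actual module also supports learned sampling.
    """
    B, N, C = k.shape
    k = k.transpose(1, 2).reshape(B, C, H, W)  # (B, C, H, W)
    k = F.avg_pool2d(k, kernel_size=sr_ratio)  # (B, C, H/sr, W/sr)
    new_N = k.shape[-2] * k.shape[-1]
    return k.reshape(B, C, new_N).transpose(1, 2), new_N

k = torch.randn(2, 64, 16)  # an 8x8 token grid, 16 channels
k_small, new_N = downsample_kv(k, 8, 8, sr_ratio=2)
print(k_small.shape, new_N)
```

Queries keep full length N, so attention cost drops from N*N to N*(N / sr_ratio**2) per head.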

View file

@ -1,2 +1,2 @@
from .pixart import PixArt, PixArt_XL_2
from .pixart_sigma import PixArt_SigmaMS, PixArt_Sigma_XL_2
from .pixart_sigma import PixArt_Sigma_XL_2

View file

@ -64,9 +64,6 @@ class PixArtBlock(nn.Module):
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
qk_norm=False,
sampling="conv",
sr_ratio=1
):
super().__init__()
self.hidden_size = hidden_size
@ -86,9 +83,6 @@ class PixArtBlock(nn.Module):
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
qk_norm=qk_norm,
sr_ratio=sr_ratio,
sampling=sampling,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
@ -97,8 +91,6 @@ class PixArtBlock(nn.Module):
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
self.sampling = sampling
self.sr_ratio = sr_ratio
def forward(self, x, y, t, mask=None):
B, N, C = x.shape
@ -136,13 +128,11 @@ class PixArt(nn.Module):
model_max_length=120,
dtype=torch.float32,
freeze=None,
qk_norm=False,
space_scale=1.0,
time_scale=1.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
kv_compress_config=None,
):
super().__init__()
assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version."
@ -182,15 +172,6 @@ class PixArt(nn.Module):
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.kv_compress_config = kv_compress_config
if kv_compress_config is None:
self.kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList(
[
PixArtBlock(
@ -200,10 +181,6 @@ class PixArt(nn.Module):
drop_path=drop_path[i],
enable_flashattn=enable_flashattn,
enable_layernorm_kernel=enable_layernorm_kernel,
qk_norm=qk_norm,
sr_ratio=int(
self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1,
sampling=self.kv_compress_config['sampling'],
)
for i in range(depth)
]
@ -216,7 +193,7 @@ class PixArt(nn.Module):
if freeze == "text":
self.freeze_text()
def forward(self, x, timestep, y, mask=None, **kwargs):
def forward(self, x, timestep, y, mask=None):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
@ -341,24 +318,19 @@ class PixArtMS(PixArt):
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs):
def forward(self, x, timestep, y, mask=None, data_info=None):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
B = x.shape[0]
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
hw = torch.cat([height[:, None], width[:, None]], dim=1)
# 2. get aspect ratio
ar = ar.unsqueeze(1)
c_size = hw
ar = ar
c_size = data_info["hw"]
ar = data_info["ar"]
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
# embedding

View file

@ -31,13 +31,10 @@ from timm.models.vision_transformer import Mlp
# from .builder import MODELS
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
KVCompressAttention,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
SizeEmbedder,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
@ -45,7 +42,6 @@ from opensora.models.layers.blocks import (
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
KVCompressAttention
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
@ -67,19 +63,16 @@ class PixArtBlock(nn.Module):
enable_sequence_parallelism=False,
qk_norm=False,
sampling="conv",
sr_ratio=1
sr_ratio=1,
):
super().__init__()
self.hidden_size = hidden_size
self.enable_flashattn = enable_flashattn
self._enable_sequence_parallelism = enable_sequence_parallelism
assert not enable_sequence_parallelism, "Sequence parallelism is not supported in this version."
if enable_sequence_parallelism:
self.attn_cls = SeqParallelAttention
self.mha_cls = SeqParallelMultiHeadCrossAttention
else:
self.attn_cls = KVCompressAttention
self.mha_cls = MultiHeadCrossAttention
self.attn_cls = KVCompressAttention
self.mha_cls = MultiHeadCrossAttention
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = self.attn_cls(
@ -102,13 +95,15 @@ class PixArtBlock(nn.Module):
self.sampling = sampling
self.sr_ratio = sr_ratio
def forward(self, x, y, t, mask=None):
def forward(self, x, y, t, hw, mask=None):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.drop_path(
gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=hw).reshape(B, N, C)
)
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
@ -188,9 +183,9 @@ class PixArt_Sigma(nn.Module):
self.kv_compress_config = kv_compress_config
if kv_compress_config is None:
self.kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
"sampling": None,
"scale_factor": 1,
"kv_compress_layer": [],
}
self.blocks = nn.ModuleList(
@ -203,9 +198,12 @@ class PixArt_Sigma(nn.Module):
enable_flashattn=enable_flashattn,
enable_layernorm_kernel=enable_layernorm_kernel,
qk_norm=qk_norm,
sr_ratio=int(
self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1,
sampling=self.kv_compress_config['sampling'],
sr_ratio=(
int(self.kv_compress_config["scale_factor"])
if i in self.kv_compress_config["kv_compress_layer"]
else 1
),
sampling=self.kv_compress_config["sampling"],
)
for i in range(depth)
]
@ -218,7 +216,7 @@ class PixArt_Sigma(nn.Module):
if freeze == "text":
self.freeze_text()
def forward(self, x, timestep, y, mask=None, **kwargs):
def forward(self, x, timestep, y, mask=None):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
@ -228,11 +226,13 @@ class PixArt_Sigma(nn.Module):
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
hw = (x.shape[-2] // self.patch_size[-2], x.shape[-1] // self.patch_size[-1])
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
x = x + self.pos_embed
x = x + pos_embed.to(x.device)
if not self.no_temporal_pos_emb:
x = rearrange(x, "b t s d -> b s t d")
x = x + self.pos_embed_temporal
@ -255,7 +255,7 @@ class PixArt_Sigma(nn.Module):
# blocks
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, y_lens)
x = auto_grad_checkpoint(block, x, y, t0, hw, y_lens)
# final process
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
@ -334,88 +334,9 @@ class PixArt_Sigma(nn.Module):
nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module()
class PixArt_SigmaMS(PixArt_Sigma):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3"
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
B = x.shape[0]
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
hw = torch.cat([height[:, None], width[:, None]], dim=1)
# 2. get aspect ratio
ar = ar.unsqueeze(1)
c_size = hw
ar = ar
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
x = x + pos_embed.to(x.device)
if not self.no_temporal_pos_emb:
x = rearrange(x, "b t s d -> b s t d")
x = x + self.pos_embed_temporal
x = rearrange(x, "b s t d -> b (t s) d")
else:
x = rearrange(x, "b t s d -> b (t s) d")
t = self.t_embedder(timestep, dtype=x.dtype) # (N, D)
B = x.shape[0]
csize = self.csize_embedder(c_size, B)
ar = self.ar_embedder(ar, B)
t = t + torch.cat([csize, ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, 1, L, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for block in self.blocks:
x = block(x, y, t0, y_lens)
# final process
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
@MODELS.register_module("PixArt-Sigma-XL/2")
def PixArt_Sigma_XL_2(from_pretrained=None, **kwargs):
model = PixArt_Sigma(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
@MODELS.register_module("PixArt-SigmaMS-XL/2")
def PixArtMS_XL_2(from_pretrained=None, **kwargs):
model = PixArt_SigmaMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
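The multi-size variant folded in above conditions the timestep embedding on image size and aspect ratio (`t = t + torch.cat([csize, ar], dim=1)`). A minimal sketch of that conditioning path, with plain `nn.Linear` layers standing in for the sinusoidal `SizeEmbedder` and illustrative shapes:

```python
import torch
import torch.nn as nn

hidden = 1152  # PixArt-XL hidden size

# stand-ins for SizeEmbedder(hidden // 3):
# (h, w) -> 2 * (hidden // 3) dims, aspect ratio -> hidden // 3 dims
csize_embedder = nn.Linear(2, 2 * (hidden // 3))
ar_embedder = nn.Linear(1, hidden // 3)

height = torch.tensor([1440.0])
width = torch.tensor([2560.0])
hw = torch.cat([height[:, None], width[:, None]], dim=1)  # (B, 2), as in the diff
ar = (height / width)[:, None]                            # (B, 1)

t = torch.randn(1, hidden)  # timestep embedding (N, D)
# size + aspect-ratio embeddings concatenate back to hidden and add onto t
t = t + torch.cat([csize_embedder(hw), ar_embedder(ar)], dim=1)
print(t.shape)
```

The two embedding widths sum to `hidden`, which is why the model asserts `hidden_size % 3 == 0`.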

View file

@ -32,7 +32,7 @@ from opensora.registry import MODELS
class T5Embedder:
available_models = ["DeepFloyd/t5-v1_1-xxl", "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"]
available_models = ["DeepFloyd/t5-v1_1-xxl"]
def __init__(
self,
@ -47,7 +47,6 @@ class T5Embedder:
use_offload_folder=None,
model_max_length=120,
local_files_only=False,
subfolder=None
):
self.device = torch.device(device)
self.torch_dtype = torch_dtype or torch.bfloat16
@ -104,14 +103,12 @@ class T5Embedder:
self.tokenizer = AutoTokenizer.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
subfolder="tokenizer" if subfolder else None,
local_files_only=local_files_only,
)
self.model = T5EncoderModel.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only,
subfolder="text_encoder" if subfolder else None,
**t5_model_kwargs,
).eval()
self.model_max_length = model_max_length
@ -148,7 +145,6 @@ class T5Encoder:
cache_dir=None,
shardformer=False,
local_files_only=False,
subfolder=None,
):
assert from_pretrained is not None, "Please specify the path to the T5 model"
@ -159,7 +155,6 @@ class T5Encoder:
cache_dir=cache_dir,
model_max_length=model_max_length,
local_files_only=local_files_only,
subfolder=subfolder,
)
self.t5.model.to(dtype=dtype)
self.y_embedder = None

View file

@ -32,10 +32,13 @@ pretrained_models = {
"OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
"OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
"OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
"PixArt-Sigma-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
"PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
"PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
"PixArt-Sigma-XL-2-256x256.pth": hf_endpoint
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
"PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
"PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
}
@ -51,7 +54,16 @@ def reparameter(ckpt, name=None, model=None):
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
del ckpt["temp_embed"]
if name in ["PixArt-XL-2-256x256.pth", "PixArt-XL-2-SAM-256x256.pth", "PixArt-XL-2-512x512.pth", "PixArt-Sigma-XL-2-256x256.pth", "PixArt-Sigma-XL-2-512-MS.pth", "PixArt-Sigma-XL-2-1024-MS.pth", "PixArt-Sigma-XL-2-2K-MS.pth"]:
if name in [
"PixArt-XL-2-256x256.pth",
"PixArt-XL-2-SAM-256x256.pth",
"PixArt-XL-2-512x512.pth",
"PixArt-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-256x256.pth",
"PixArt-Sigma-XL-2-512-MS.pth",
"PixArt-Sigma-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-2K-MS.pth",
]:
ckpt = ckpt["state_dict"]
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
if "pos_embed" in ckpt:
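The `unsqueeze(2)` applied to `x_embedder.proj.weight` in `reparameter` inflates a 2D patch-embed convolution weight into a 3D one by inserting a unit temporal axis. A minimal illustration, with shapes assumed from PixArt-XL (`hidden_size=1152`, `patch_size=2`, 4 latent channels):

```python
import torch

# 2D patch-embed conv weight: (out_ch, in_ch, p, p)
w2d = torch.randn(1152, 4, 2, 2)

# insert a unit time axis -> usable as a Conv3d weight (out_ch, in_ch, 1, p, p)
w3d = w2d.unsqueeze(2)
print(w3d.shape)
```

This lets an image checkpoint initialize the video model's 3D patch embedding without retraining.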