mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
[wip] multiple resolution
This commit is contained in:
parent
f551321e49
commit
d6a16a858e
10
assets/texts/t2i_sigma.txt
Normal file
10
assets/texts/t2i_sigma.txt
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works.
|
||||
A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
|
||||
Lego model, future rocket station, intricate details, high resolution, unreal engine, UHD
|
||||
One giant, sharp, metal square mirror in the center of the frame, four young people on the foreground, background sunny palm oil planation, tropical, realistic style, photography, nostalgic, green tone, mysterious, dreamy, bright color.
|
||||
Modern luxury contemporary luxury home interiors house, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capture the essence of nature, gray and bronze, dynamic outdoor shots.
|
||||
Over the shoulder game perspective, game screen of Diablo 4, Inside the gorgeous palace is the wet ground, The necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light.
|
||||
A curvy timber house near a sea, designed by Zaha Hadid, represent the image of a cold, modern architecture, at night, white lighting, highly detailed.
|
||||
Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm.
|
||||
Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture.
|
||||
A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
A small cactus with a happy face in the Sahara desert.
|
||||
A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.
|
||||
A majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.
|
||||
A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene. In the foreground, a few cars can be seen driving along a winding road that cuts through the mountains. The cars are small compared to the vastness of the landscape, emphasizing the grandeur of the surroundings. The overall style of the video is a mix of adventure and tranquility, with the hot air balloons adding a touch of whimsy to the otherwise serene mountain landscape. The video is likely shot during the day, as the lighting is bright and even, casting soft shadows on the snow-covered mountains.
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ model = dict(
|
|||
from_pretrained=None,
|
||||
input_sq_size=512,
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ model = dict(
|
|||
from_pretrained=None,
|
||||
input_sq_size=512,
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ model = dict(
|
|||
from_pretrained=None,
|
||||
input_sq_size=512, # pretrained model is trained on 512x512
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ model = dict(
|
|||
from_pretrained=None,
|
||||
input_sq_size=512, # pretrained model is trained on 512x512
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ model = dict(
|
|||
from_pretrained=None,
|
||||
input_sq_size=512, # pretrained model is trained on 512x512
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ model = dict(
|
|||
from_pretrained=None,
|
||||
input_sq_size=512, # pretrained model is trained on 512x512
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ model = dict(
|
|||
from_pretrained=None,
|
||||
input_sq_size=512, # pretrained model is trained on 512x512
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ model = dict(
|
|||
from_pretrained=None,
|
||||
input_sq_size=512, # pretrained model is trained on 512x512
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
num_frames = 1
|
||||
fps = 1
|
||||
image_size = (2048, 2048)
|
||||
multi_resolution = "STDiT2"
|
||||
|
||||
|
||||
# Define model
|
||||
# model = dict(
|
||||
# type="STDiT2-XL/2",
|
||||
# from_pretrained="/home/zhouyukun/data/models/PixArt-Sigma/PixArt-Sigma-XL-2-256x256.pth",
|
||||
# input_sq_size=512,
|
||||
# qk_norm=True,
|
||||
# enable_flashattn=True,
|
||||
# enable_layernorm_kernel=True,
|
||||
# )
|
||||
|
||||
model = dict(
|
||||
type="PixArt-Sigma-XL/2",
|
||||
space_scale=4,
|
||||
no_temporal_pos_emb=True,
|
||||
from_pretrained="PixArt-Sigma-XL-2-2K-MS.pth",
|
||||
)
|
||||
|
||||
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae"
|
||||
)
|
||||
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
model_max_length=300,
|
||||
cache_dir=None,
|
||||
subfolder=True
|
||||
)
|
||||
|
||||
|
||||
scheduler = dict(
|
||||
type="iddpm",
|
||||
num_sampling_steps=250,
|
||||
cfg_scale=7,
|
||||
cfg_channel=3, # or None
|
||||
)
|
||||
|
||||
# scheduler = dict(
|
||||
# type="dpm-solver",
|
||||
# num_sampling_steps=50,
|
||||
# cfg_scale=4.0,
|
||||
# )
|
||||
|
||||
dtype = "bf16"
|
||||
|
||||
# Condition
|
||||
prompt_path = "./assets/texts/t2v_samples.txt"
|
||||
prompt = None # prompt has higher priority than prompt_path
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
save_dir = "./samples/samples/"
|
||||
33
configs/pixart/inference/1x2048MS.py
Normal file
33
configs/pixart/inference/1x2048MS.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
num_frames = 1
|
||||
fps = 1
|
||||
image_size = (2688, 1408)
|
||||
# image_size = (2048, 2048)
|
||||
|
||||
model = dict(
|
||||
type="PixArt-Sigma-XL/2",
|
||||
space_scale=4,
|
||||
no_temporal_pos_emb=True,
|
||||
from_pretrained="PixArt-Sigma-XL-2-2K-MS.pth",
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="DeepFloyd/t5-v1_1-xxl",
|
||||
model_max_length=300,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="dpm-solver",
|
||||
num_sampling_steps=14,
|
||||
cfg_scale=4.5,
|
||||
)
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
prompt_path = "./assets/texts/t2i_sigma.txt"
|
||||
save_dir = "./samples/samples/"
|
||||
|
|
@ -13,7 +13,7 @@ def get_h_w(a, ts, eps=1e-4):
|
|||
return h, w
|
||||
|
||||
|
||||
AR = [
|
||||
AR = (
|
||||
3 / 8,
|
||||
9 / 21,
|
||||
0.48,
|
||||
|
|
@ -31,8 +31,8 @@ AR = [
|
|||
17 / 9,
|
||||
2 / 1,
|
||||
1 / 0.48,
|
||||
]
|
||||
ARV = [0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1.0, 1.33, 1.50, 1.78, 1.89, 2.0, 2.08]
|
||||
)
|
||||
ARV = (0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1, 1.33, 1.50, 1.78, 1.89, 2, 2.08)
|
||||
|
||||
|
||||
def get_aspect_ratios_dict(ts=360 * 640, ars=AR):
|
||||
|
|
@ -61,6 +61,27 @@ ASPECT_RATIO_4K = {
|
|||
"2.08": (4156, 1994),
|
||||
}
|
||||
|
||||
# S = 3686400
|
||||
ASPECT_RATIO_2K = {
|
||||
"0.38": (1176, 3136),
|
||||
"0.43": (1256, 2930),
|
||||
"0.48": (1330, 2770),
|
||||
"0.50": (1358, 2716),
|
||||
"0.53": (1398, 2640),
|
||||
"0.54": (1412, 2612),
|
||||
"0.56": (1440, 2560), # base
|
||||
"0.62": (1518, 2428),
|
||||
"0.67": (1568, 2352),
|
||||
"0.75": (1662, 2216),
|
||||
"1.00": (1920, 1920),
|
||||
"1.33": (2218, 1664),
|
||||
"1.50": (2352, 1568),
|
||||
"1.78": (2560, 1440),
|
||||
"1.89": (2638, 1396),
|
||||
"2.00": (2716, 1358),
|
||||
"2.08": (2772, 1330),
|
||||
}
|
||||
|
||||
# S = 2073600
|
||||
ASPECT_RATIO_1080P = {
|
||||
"0.38": (882, 2352),
|
||||
|
|
@ -188,6 +209,94 @@ ASPECT_RATIO_144P = {
|
|||
}
|
||||
|
||||
# from PixArt
|
||||
# S = 8294400
|
||||
ASPECT_RATIO_2880 = {
|
||||
"0.25": (1408, 5760),
|
||||
"0.26": (1408, 5568),
|
||||
"0.27": (1408, 5376),
|
||||
"0.28": (1408, 5184),
|
||||
"0.32": (1600, 4992),
|
||||
"0.33": (1600, 4800),
|
||||
"0.34": (1600, 4672),
|
||||
"0.4": (1792, 4480),
|
||||
"0.42": (1792, 4288),
|
||||
"0.47": (1920, 4096),
|
||||
"0.49": (1920, 3904),
|
||||
"0.51": (1920, 3776),
|
||||
"0.55": (2112, 3840),
|
||||
"0.59": (2112, 3584),
|
||||
"0.68": (2304, 3392),
|
||||
"0.72": (2304, 3200),
|
||||
"0.78": (2496, 3200),
|
||||
"0.83": (2496, 3008),
|
||||
"0.89": (2688, 3008),
|
||||
"0.93": (2688, 2880),
|
||||
"1.0": (2880, 2880),
|
||||
"1.07": (2880, 2688),
|
||||
"1.12": (3008, 2688),
|
||||
"1.21": (3008, 2496),
|
||||
"1.28": (3200, 2496),
|
||||
"1.39": (3200, 2304),
|
||||
"1.47": (3392, 2304),
|
||||
"1.7": (3584, 2112),
|
||||
"1.82": (3840, 2112),
|
||||
"2.03": (3904, 1920),
|
||||
"2.13": (4096, 1920),
|
||||
"2.39": (4288, 1792),
|
||||
"2.5": (4480, 1792),
|
||||
"2.92": (4672, 1600),
|
||||
"3.0": (4800, 1600),
|
||||
"3.12": (4992, 1600),
|
||||
"3.68": (5184, 1408),
|
||||
"3.82": (5376, 1408),
|
||||
"3.95": (5568, 1408),
|
||||
"4.0": (5760, 1408),
|
||||
}
|
||||
|
||||
# S = 4194304
|
||||
ASPECT_RATIO_2048 = {
|
||||
"0.25": (1024, 4096),
|
||||
"0.26": (1024, 3968),
|
||||
"0.27": (1024, 3840),
|
||||
"0.28": (1024, 3712),
|
||||
"0.32": (1152, 3584),
|
||||
"0.33": (1152, 3456),
|
||||
"0.35": (1152, 3328),
|
||||
"0.4": (1280, 3200),
|
||||
"0.42": (1280, 3072),
|
||||
"0.48": (1408, 2944),
|
||||
"0.5": (1408, 2816),
|
||||
"0.52": (1408, 2688),
|
||||
"0.57": (1536, 2688),
|
||||
"0.6": (1536, 2560),
|
||||
"0.68": (1664, 2432),
|
||||
"0.72": (1664, 2304),
|
||||
"0.78": (1792, 2304),
|
||||
"0.82": (1792, 2176),
|
||||
"0.88": (1920, 2176),
|
||||
"0.94": (1920, 2048),
|
||||
"1.0": (2048, 2048),
|
||||
"1.07": (2048, 1920),
|
||||
"1.13": (2176, 1920),
|
||||
"1.21": (2176, 1792),
|
||||
"1.29": (2304, 1792),
|
||||
"1.38": (2304, 1664),
|
||||
"1.46": (2432, 1664),
|
||||
"1.67": (2560, 1536),
|
||||
"1.75": (2688, 1536),
|
||||
"2.0": (2816, 1408),
|
||||
"2.09": (2944, 1408),
|
||||
"2.4": (3072, 1280),
|
||||
"2.5": (3200, 1280),
|
||||
"2.89": (3328, 1152),
|
||||
"3.0": (3456, 1152),
|
||||
"3.11": (3584, 1152),
|
||||
"3.62": (3712, 1024),
|
||||
"3.75": (3840, 1024),
|
||||
"3.88": (3968, 1024),
|
||||
"4.0": (4096, 1024),
|
||||
}
|
||||
|
||||
# S = 1048576
|
||||
ASPECT_RATIO_1024 = {
|
||||
"0.25": (512, 2048),
|
||||
|
|
@ -337,5 +446,8 @@ ASPECT_RATIOS = {
|
|||
"720p": (921600, ASPECT_RATIO_720P),
|
||||
"1024": (1048576, ASPECT_RATIO_1024),
|
||||
"1080p": (2073600, ASPECT_RATIO_1080P),
|
||||
"2k": (3686400, ASPECT_RATIO_2K),
|
||||
"2048": (4194304, ASPECT_RATIO_2048),
|
||||
"2880": (8294400, ASPECT_RATIO_2880),
|
||||
"4k": (8294400, ASPECT_RATIO_4K),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -229,7 +229,6 @@ class KVCompressAttention(nn.Module):
|
|||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = LlamaRMSNorm,
|
||||
enable_flashattn: bool = False,
|
||||
rope=None,
|
||||
sampling="conv",
|
||||
sr_ratio=1,
|
||||
mem_eff_attention=False,
|
||||
|
|
@ -260,11 +259,6 @@ class KVCompressAttention(nn.Module):
|
|||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.rope = False
|
||||
if rope is not None:
|
||||
self.rope = True
|
||||
self.rotary_emb = rope
|
||||
|
||||
self.mem_eff_attention = mem_eff_attention
|
||||
self.attn_half = attn_half
|
||||
|
||||
|
|
@ -294,41 +288,28 @@ class KVCompressAttention(nn.Module):
|
|||
|
||||
def forward(self, x: torch.Tensor, mask=None, HW=None, block_id=None, **kwargs) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
|
||||
new_N = N
|
||||
if HW is None:
|
||||
H = W = int(N**0.5)
|
||||
else:
|
||||
H, W = HW
|
||||
|
||||
H, W = HW
|
||||
# flash attn is not memory efficient for small sequences, this is empirical
|
||||
enable_flashattn = self.enable_flashattn and (N > B)
|
||||
qkv = self.qkv(x)
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
|
||||
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
# WARNING: this may be a bug
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||
q, k, v = qkv.unbind(2)
|
||||
dtype = q.dtype
|
||||
# KV compression
|
||||
if self.sr_ratio > 1:
|
||||
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
|
||||
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
|
||||
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||
q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
|
||||
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
|
||||
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
|
||||
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if enable_flashattn:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
x = flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
|
|
@ -343,8 +324,11 @@ class KVCompressAttention(nn.Module):
|
|||
attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
|
||||
attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf"))
|
||||
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
|
||||
|
||||
else:
|
||||
# (B, N, #heads, #dim) -> (B, #heads, N, #dim)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
dtype = q.dtype
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1) # translate attn to float32
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from .pixart import PixArt, PixArt_XL_2
|
||||
from .pixart_sigma import PixArt_SigmaMS, PixArt_Sigma_XL_2
|
||||
from .pixart_sigma import PixArt_Sigma_XL_2
|
||||
|
|
|
|||
|
|
@ -64,9 +64,6 @@ class PixArtBlock(nn.Module):
|
|||
enable_flashattn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
qk_norm=False,
|
||||
sampling="conv",
|
||||
sr_ratio=1
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
|
@ -86,9 +83,6 @@ class PixArtBlock(nn.Module):
|
|||
num_heads=num_heads,
|
||||
qkv_bias=True,
|
||||
enable_flashattn=enable_flashattn,
|
||||
qk_norm=qk_norm,
|
||||
sr_ratio=sr_ratio,
|
||||
sampling=sampling,
|
||||
)
|
||||
self.cross_attn = self.mha_cls(hidden_size, num_heads)
|
||||
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
|
||||
|
|
@ -97,8 +91,6 @@ class PixArtBlock(nn.Module):
|
|||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
|
||||
self.sampling = sampling
|
||||
self.sr_ratio = sr_ratio
|
||||
|
||||
def forward(self, x, y, t, mask=None):
|
||||
B, N, C = x.shape
|
||||
|
|
@ -136,13 +128,11 @@ class PixArt(nn.Module):
|
|||
model_max_length=120,
|
||||
dtype=torch.float32,
|
||||
freeze=None,
|
||||
qk_norm=False,
|
||||
space_scale=1.0,
|
||||
time_scale=1.0,
|
||||
enable_flashattn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
kv_compress_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version."
|
||||
|
|
@ -182,15 +172,6 @@ class PixArt(nn.Module):
|
|||
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
|
||||
self.kv_compress_config = kv_compress_config
|
||||
if kv_compress_config is None:
|
||||
self.kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
PixArtBlock(
|
||||
|
|
@ -200,10 +181,6 @@ class PixArt(nn.Module):
|
|||
drop_path=drop_path[i],
|
||||
enable_flashattn=enable_flashattn,
|
||||
enable_layernorm_kernel=enable_layernorm_kernel,
|
||||
qk_norm=qk_norm,
|
||||
sr_ratio=int(
|
||||
self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1,
|
||||
sampling=self.kv_compress_config['sampling'],
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
|
|
@ -216,7 +193,7 @@ class PixArt(nn.Module):
|
|||
if freeze == "text":
|
||||
self.freeze_text()
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, **kwargs):
|
||||
def forward(self, x, timestep, y, mask=None):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
|
|
@ -341,24 +318,19 @@ class PixArtMS(PixArt):
|
|||
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs):
|
||||
def forward(self, x, timestep, y, mask=None, data_info=None):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) tensor of class labels
|
||||
"""
|
||||
B = x.shape[0]
|
||||
x = x.to(self.dtype)
|
||||
timestep = timestep.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
|
||||
hw = torch.cat([height[:, None], width[:, None]], dim=1)
|
||||
# 2. get aspect ratio
|
||||
ar = ar.unsqueeze(1)
|
||||
|
||||
c_size = hw
|
||||
ar = ar
|
||||
c_size = data_info["hw"]
|
||||
ar = data_info["ar"]
|
||||
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
|
||||
|
||||
# embedding
|
||||
|
|
|
|||
|
|
@ -31,13 +31,10 @@ from timm.models.vision_transformer import Mlp
|
|||
# from .builder import MODELS
|
||||
from opensora.acceleration.checkpoint import auto_grad_checkpoint
|
||||
from opensora.models.layers.blocks import (
|
||||
Attention,
|
||||
CaptionEmbedder,
|
||||
KVCompressAttention,
|
||||
MultiHeadCrossAttention,
|
||||
PatchEmbed3D,
|
||||
SeqParallelAttention,
|
||||
SeqParallelMultiHeadCrossAttention,
|
||||
SizeEmbedder,
|
||||
T2IFinalLayer,
|
||||
TimestepEmbedder,
|
||||
approx_gelu,
|
||||
|
|
@ -45,7 +42,6 @@ from opensora.models.layers.blocks import (
|
|||
get_2d_sincos_pos_embed,
|
||||
get_layernorm,
|
||||
t2i_modulate,
|
||||
KVCompressAttention
|
||||
)
|
||||
from opensora.registry import MODELS
|
||||
from opensora.utils.ckpt_utils import load_checkpoint
|
||||
|
|
@ -67,19 +63,16 @@ class PixArtBlock(nn.Module):
|
|||
enable_sequence_parallelism=False,
|
||||
qk_norm=False,
|
||||
sampling="conv",
|
||||
sr_ratio=1
|
||||
sr_ratio=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_flashattn = enable_flashattn
|
||||
self._enable_sequence_parallelism = enable_sequence_parallelism
|
||||
assert not enable_sequence_parallelism, "Sequence parallelism is not supported in this version."
|
||||
|
||||
if enable_sequence_parallelism:
|
||||
self.attn_cls = SeqParallelAttention
|
||||
self.mha_cls = SeqParallelMultiHeadCrossAttention
|
||||
else:
|
||||
self.attn_cls = KVCompressAttention
|
||||
self.mha_cls = MultiHeadCrossAttention
|
||||
self.attn_cls = KVCompressAttention
|
||||
self.mha_cls = MultiHeadCrossAttention
|
||||
|
||||
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
|
||||
self.attn = self.attn_cls(
|
||||
|
|
@ -102,13 +95,15 @@ class PixArtBlock(nn.Module):
|
|||
self.sampling = sampling
|
||||
self.sr_ratio = sr_ratio
|
||||
|
||||
def forward(self, x, y, t, mask=None):
|
||||
def forward(self, x, y, t, hw, mask=None):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + t.reshape(B, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
|
||||
x = x + self.drop_path(
|
||||
gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=hw).reshape(B, N, C)
|
||||
)
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
|
|
@ -188,9 +183,9 @@ class PixArt_Sigma(nn.Module):
|
|||
self.kv_compress_config = kv_compress_config
|
||||
if kv_compress_config is None:
|
||||
self.kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
"sampling": None,
|
||||
"scale_factor": 1,
|
||||
"kv_compress_layer": [],
|
||||
}
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
|
|
@ -203,9 +198,12 @@ class PixArt_Sigma(nn.Module):
|
|||
enable_flashattn=enable_flashattn,
|
||||
enable_layernorm_kernel=enable_layernorm_kernel,
|
||||
qk_norm=qk_norm,
|
||||
sr_ratio=int(
|
||||
self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1,
|
||||
sampling=self.kv_compress_config['sampling'],
|
||||
sr_ratio=(
|
||||
int(self.kv_compress_config["scale_factor"])
|
||||
if i in self.kv_compress_config["kv_compress_layer"]
|
||||
else 1
|
||||
),
|
||||
sampling=self.kv_compress_config["sampling"],
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
|
|
@ -218,7 +216,7 @@ class PixArt_Sigma(nn.Module):
|
|||
if freeze == "text":
|
||||
self.freeze_text()
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, **kwargs):
|
||||
def forward(self, x, timestep, y, mask=None):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
|
|
@ -228,11 +226,13 @@ class PixArt_Sigma(nn.Module):
|
|||
x = x.to(self.dtype)
|
||||
timestep = timestep.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
|
||||
hw = (x.shape[-2] // self.patch_size[-2], x.shape[-1] // self.patch_size[-1])
|
||||
|
||||
# embedding
|
||||
x = self.x_embedder(x) # (B, N, D)
|
||||
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
|
||||
x = x + self.pos_embed
|
||||
x = x + pos_embed.to(x.device)
|
||||
if not self.no_temporal_pos_emb:
|
||||
x = rearrange(x, "b t s d -> b s t d")
|
||||
x = x + self.pos_embed_temporal
|
||||
|
|
@ -255,7 +255,7 @@ class PixArt_Sigma(nn.Module):
|
|||
|
||||
# blocks
|
||||
for block in self.blocks:
|
||||
x = auto_grad_checkpoint(block, x, y, t0, y_lens)
|
||||
x = auto_grad_checkpoint(block, x, y, t0, hw, y_lens)
|
||||
|
||||
# final process
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
|
@ -334,88 +334,9 @@ class PixArt_Sigma(nn.Module):
|
|||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PixArt_SigmaMS(PixArt_Sigma):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3"
|
||||
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) tensor of class labels
|
||||
"""
|
||||
B = x.shape[0]
|
||||
x = x.to(self.dtype)
|
||||
timestep = timestep.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
|
||||
hw = torch.cat([height[:, None], width[:, None]], dim=1)
|
||||
# 2. get aspect ratio
|
||||
ar = ar.unsqueeze(1)
|
||||
|
||||
c_size = hw
|
||||
ar = ar
|
||||
pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype)
|
||||
|
||||
# embedding
|
||||
x = self.x_embedder(x) # (B, N, D)
|
||||
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
|
||||
x = x + pos_embed.to(x.device)
|
||||
if not self.no_temporal_pos_emb:
|
||||
x = rearrange(x, "b t s d -> b s t d")
|
||||
x = x + self.pos_embed_temporal
|
||||
x = rearrange(x, "b s t d -> b (t s) d")
|
||||
else:
|
||||
x = rearrange(x, "b t s d -> b (t s) d")
|
||||
|
||||
t = self.t_embedder(timestep, dtype=x.dtype) # (N, D)
|
||||
B = x.shape[0]
|
||||
csize = self.csize_embedder(c_size, B)
|
||||
ar = self.ar_embedder(ar, B)
|
||||
t = t + torch.cat([csize, ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, 1, L, D)
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = [y.shape[2]] * y.shape[0]
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
|
||||
# blocks
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens)
|
||||
|
||||
# final process
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x) # (N, out_channels, H, W)
|
||||
|
||||
# cast to float32 for better accuracy
|
||||
x = x.to(torch.float32)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module("PixArt-Sigma-XL/2")
|
||||
def PixArt_Sigma_XL_2(from_pretrained=None, **kwargs):
|
||||
model = PixArt_Sigma(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
@MODELS.register_module("PixArt-SigmaMS-XL/2")
|
||||
def PixArtMS_XL_2(from_pretrained=None, **kwargs):
|
||||
model = PixArt_SigmaMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from opensora.registry import MODELS
|
|||
|
||||
|
||||
class T5Embedder:
|
||||
available_models = ["DeepFloyd/t5-v1_1-xxl", "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"]
|
||||
available_models = ["DeepFloyd/t5-v1_1-xxl"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -47,7 +47,6 @@ class T5Embedder:
|
|||
use_offload_folder=None,
|
||||
model_max_length=120,
|
||||
local_files_only=False,
|
||||
subfolder=None
|
||||
):
|
||||
self.device = torch.device(device)
|
||||
self.torch_dtype = torch_dtype or torch.bfloat16
|
||||
|
|
@ -104,14 +103,12 @@ class T5Embedder:
|
|||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
subfolder="tokenizer" if subfolder else None,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
self.model = T5EncoderModel.from_pretrained(
|
||||
from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
subfolder="text_encoder" if subfolder else None,
|
||||
**t5_model_kwargs,
|
||||
).eval()
|
||||
self.model_max_length = model_max_length
|
||||
|
|
@ -148,7 +145,6 @@ class T5Encoder:
|
|||
cache_dir=None,
|
||||
shardformer=False,
|
||||
local_files_only=False,
|
||||
subfolder=None,
|
||||
):
|
||||
assert from_pretrained is not None, "Please specify the path to the T5 model"
|
||||
|
||||
|
|
@ -159,7 +155,6 @@ class T5Encoder:
|
|||
cache_dir=cache_dir,
|
||||
model_max_length=model_max_length,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
self.t5.model.to(dtype=dtype)
|
||||
self.y_embedder = None
|
||||
|
|
|
|||
|
|
@ -32,10 +32,13 @@ pretrained_models = {
|
|||
"OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
|
||||
"OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
|
||||
"OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
|
||||
"PixArt-Sigma-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
|
||||
"PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
|
||||
"PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
|
||||
"PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
|
||||
"PixArt-Sigma-XL-2-256x256.pth": hf_endpoint
|
||||
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
|
||||
"PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint
|
||||
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
|
||||
"PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint
|
||||
+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
|
||||
"PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -51,7 +54,16 @@ def reparameter(ckpt, name=None, model=None):
|
|||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
del ckpt["pos_embed"]
|
||||
del ckpt["temp_embed"]
|
||||
if name in ["PixArt-XL-2-256x256.pth", "PixArt-XL-2-SAM-256x256.pth", "PixArt-XL-2-512x512.pth", "PixArt-Sigma-XL-2-256x256.pth", "PixArt-Sigma-XL-2-512-MS.pth", "PixArt-Sigma-XL-2-1024-MS.pth", "PixArt-Sigma-XL-2-2K-MS.pth"]:
|
||||
if name in [
|
||||
"PixArt-XL-2-256x256.pth",
|
||||
"PixArt-XL-2-SAM-256x256.pth",
|
||||
"PixArt-XL-2-512x512.pth",
|
||||
"PixArt-XL-2-1024-MS.pth",
|
||||
"PixArt-Sigma-XL-2-256x256.pth",
|
||||
"PixArt-Sigma-XL-2-512-MS.pth",
|
||||
"PixArt-Sigma-XL-2-1024-MS.pth",
|
||||
"PixArt-Sigma-XL-2-2K-MS.pth",
|
||||
]:
|
||||
ckpt = ckpt["state_dict"]
|
||||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
if "pos_embed" in ckpt:
|
||||
|
|
|
|||
Loading…
Reference in a new issue