From d6a16a858e191e2794ba8365faccb59fb36df503 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Sat, 27 Apr 2024 11:43:15 +0000 Subject: [PATCH] [wip] multiple resolution --- assets/texts/t2i_sigma.txt | 10 ++ assets/texts/t2v_samples.txt | 1 - configs/opensora-v1-1/inference/sample-ref.py | 1 + configs/opensora-v1-1/inference/sample.py | 1 + configs/opensora-v1-1/train/benchmark.py | 1 + configs/opensora-v1-1/train/image.py | 1 + configs/opensora-v1-1/train/stage1.py | 1 + configs/opensora-v1-1/train/stage2.py | 1 + configs/opensora-v1-1/train/stage3.py | 1 + configs/opensora-v1-1/train/video.py | 1 + .../opensora-v1-2/inference/1x2048x2048.py | 62 --------- configs/pixart/inference/1x2048MS.py | 33 +++++ opensora/datasets/aspect.py | 118 ++++++++++++++++- opensora/models/layers/blocks.py | 44 ++---- opensora/models/pixart/__init__.py | 2 +- opensora/models/pixart/pixart.py | 36 +---- opensora/models/pixart/pixart_sigma.py | 125 ++++-------------- opensora/models/text_encoder/t5.py | 7 +- opensora/utils/ckpt_utils.py | 22 ++- 19 files changed, 226 insertions(+), 242 deletions(-) create mode 100644 assets/texts/t2i_sigma.txt delete mode 100644 configs/opensora-v1-2/inference/1x2048x2048.py create mode 100644 configs/pixart/inference/1x2048MS.py diff --git a/assets/texts/t2i_sigma.txt b/assets/texts/t2i_sigma.txt new file mode 100644 index 0000000..7866969 --- /dev/null +++ b/assets/texts/t2i_sigma.txt @@ -0,0 +1,10 @@ +Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works. +A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in. +Lego model, future rocket station, intricate details, high resolution, unreal engine, UHD +One giant, sharp, metal square mirror in the center of the frame, four young people on the foreground, background sunny palm oil planation, tropical, realistic style, photography, nostalgic, green tone, mysterious, dreamy, bright color. +Modern luxury contemporary luxury home interiors house, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capture the essence of nature, gray and bronze, dynamic outdoor shots. +Over the shoulder game perspective, game screen of Diablo 4, Inside the gorgeous palace is the wet ground, The necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light. +A curvy timber house near a sea, designed by Zaha Hadid, represent the image of a cold, modern architecture, at night, white lighting, highly detailed. +Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm. +Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture. +A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures. diff --git a/assets/texts/t2v_samples.txt b/assets/texts/t2v_samples.txt index 6184067..7953f37 100644 --- a/assets/texts/t2v_samples.txt +++ b/assets/texts/t2v_samples.txt @@ -1,4 +1,3 @@ -A small cactus with a happy face in the Sahara desert. A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. A majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty. A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene. In the foreground, a few cars can be seen driving along a winding road that cuts through the mountains. The cars are small compared to the vastness of the landscape, emphasizing the grandeur of the surroundings. The overall style of the video is a mix of adventure and tranquility, with the hot air balloons adding a touch of whimsy to the otherwise serene mountain landscape. The video is likely shot during the day, as the lighting is bright and even, casting soft shadows on the snow-covered mountains. diff --git a/configs/opensora-v1-1/inference/sample-ref.py b/configs/opensora-v1-1/inference/sample-ref.py index 557bb70..c214dc8 100644 --- a/configs/opensora-v1-1/inference/sample-ref.py +++ b/configs/opensora-v1-1/inference/sample-ref.py @@ -33,6 +33,7 @@ model = dict( from_pretrained=None, input_sq_size=512, qk_norm=True, + qk_norm_legacy=True, enable_flashattn=True, enable_layernorm_kernel=True, ) diff --git a/configs/opensora-v1-1/inference/sample.py b/configs/opensora-v1-1/inference/sample.py index cec8073..5b12cd2 100644 --- a/configs/opensora-v1-1/inference/sample.py +++ b/configs/opensora-v1-1/inference/sample.py @@ -10,6 +10,7 @@ model = dict( from_pretrained=None, input_sq_size=512, qk_norm=True, + qk_norm_legacy=True, enable_flashattn=True, enable_layernorm_kernel=True, ) diff --git a/configs/opensora-v1-1/train/benchmark.py b/configs/opensora-v1-1/train/benchmark.py index 5310b43..0a6a01b 100644 --- a/configs/opensora-v1-1/train/benchmark.py +++ b/configs/opensora-v1-1/train/benchmark.py @@ -65,6 +65,7 @@ model = dict( from_pretrained=None, input_sq_size=512, # pretrained model is trained on 512x512 qk_norm=True, + qk_norm_legacy=True, enable_flashattn=True, enable_layernorm_kernel=True, ) diff --git a/configs/opensora-v1-1/train/image.py b/configs/opensora-v1-1/train/image.py index 45748b7..de104d9 100644 --- a/configs/opensora-v1-1/train/image.py +++ b/configs/opensora-v1-1/train/image.py @@ -29,6 +29,7 @@ model = dict( from_pretrained=None, input_sq_size=512, # pretrained model is trained on 512x512 qk_norm=True, + qk_norm_legacy=True, enable_flashattn=True, enable_layernorm_kernel=True, ) diff --git a/configs/opensora-v1-1/train/stage1.py b/configs/opensora-v1-1/train/stage1.py index 944b565..8bd9b35 100644 --- a/configs/opensora-v1-1/train/stage1.py +++ b/configs/opensora-v1-1/train/stage1.py @@ -41,6 +41,7 @@ model = dict( from_pretrained=None, input_sq_size=512, # pretrained model is trained on 512x512 qk_norm=True, + qk_norm_legacy=True, enable_flashattn=True, enable_layernorm_kernel=True, ) diff --git a/configs/opensora-v1-1/train/stage2.py b/configs/opensora-v1-1/train/stage2.py index fb7e6d5..66572d5 100644 --- a/configs/opensora-v1-1/train/stage2.py +++ b/configs/opensora-v1-1/train/stage2.py @@ -43,6 +43,7 @@ model = dict( from_pretrained=None, input_sq_size=512, # pretrained model is trained on 512x512 qk_norm=True, + qk_norm_legacy=True, enable_flashattn=True, enable_layernorm_kernel=True, ) diff --git a/configs/opensora-v1-1/train/stage3.py b/configs/opensora-v1-1/train/stage3.py index 8485762..db4d42b 100644 --- a/configs/opensora-v1-1/train/stage3.py +++ b/configs/opensora-v1-1/train/stage3.py @@ -43,6 +43,7 @@ model = dict( from_pretrained=None, input_sq_size=512, # pretrained model is trained on 512x512 qk_norm=True, + qk_norm_legacy=True, enable_flashattn=True, enable_layernorm_kernel=True, ) diff --git a/configs/opensora-v1-1/train/video.py b/configs/opensora-v1-1/train/video.py index ef574f2..2912a40 100644 --- a/configs/opensora-v1-1/train/video.py +++ b/configs/opensora-v1-1/train/video.py @@ -31,6 +31,7 @@ model = dict( from_pretrained=None, input_sq_size=512, # pretrained model is trained on 512x512 qk_norm=True, + qk_norm_legacy=True, enable_flashattn=True, enable_layernorm_kernel=True, ) diff --git a/configs/opensora-v1-2/inference/1x2048x2048.py b/configs/opensora-v1-2/inference/1x2048x2048.py deleted file mode 100644 index 62a2e1d..0000000 --- a/configs/opensora-v1-2/inference/1x2048x2048.py +++ /dev/null @@ -1,62 +0,0 @@ -num_frames = 1 -fps = 1 -image_size = (2048, 2048) -multi_resolution = "STDiT2" - - -# Define model -# model = dict( -# type="STDiT2-XL/2", -# from_pretrained="/home/zhouyukun/data/models/PixArt-Sigma/PixArt-Sigma-XL-2-256x256.pth", -# input_sq_size=512, -# qk_norm=True, -# enable_flashattn=True, -# enable_layernorm_kernel=True, -# ) - -model = dict( - type="PixArt-Sigma-XL/2", - space_scale=4, - no_temporal_pos_emb=True, - from_pretrained="PixArt-Sigma-XL-2-2K-MS.pth", -) - - -vae = dict( - type="VideoAutoencoderKL", - from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", - subfolder="vae" -) - -text_encoder = dict( - type="t5", - from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", - model_max_length=300, - cache_dir=None, - subfolder=True -) - - -scheduler = dict( - type="iddpm", - num_sampling_steps=250, - cfg_scale=7, - cfg_channel=3, # or None -) - -# scheduler = dict( -# type="dpm-solver", -# num_sampling_steps=50, -# cfg_scale=4.0, -# ) - -dtype = "bf16" - -# Condition -prompt_path = "./assets/texts/t2v_samples.txt" -prompt = None # prompt has higher priority than prompt_path - -# Others -batch_size = 1 -seed = 42 -save_dir = "./samples/samples/" diff --git a/configs/pixart/inference/1x2048MS.py b/configs/pixart/inference/1x2048MS.py new file mode 100644 index 0000000..41849e9 --- /dev/null +++ b/configs/pixart/inference/1x2048MS.py @@ -0,0 +1,33 @@ +num_frames = 1 +fps = 1 +image_size = (2688, 1408) +# image_size = (2048, 2048) + +model = dict( + type="PixArt-Sigma-XL/2", + space_scale=4, + no_temporal_pos_emb=True, + from_pretrained="PixArt-Sigma-XL-2-2K-MS.pth", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", + subfolder="vae", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=300, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=14, + cfg_scale=4.5, +) +dtype = "bf16" + +# Others +batch_size = 1 +seed = 42 +prompt_path = "./assets/texts/t2i_sigma.txt" +save_dir = "./samples/samples/" diff --git a/opensora/datasets/aspect.py b/opensora/datasets/aspect.py index 57e33f2..a49fbcf 100644 --- a/opensora/datasets/aspect.py +++ b/opensora/datasets/aspect.py @@ -13,7 +13,7 @@ def get_h_w(a, ts, eps=1e-4): return h, w -AR = [ +AR = ( 3 / 8, 9 / 21, 0.48, @@ -31,8 +31,8 @@ AR = [ 17 / 9, 2 / 1, 1 / 0.48, -] -ARV = [0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1.0, 1.33, 1.50, 1.78, 1.89, 2.0, 2.08] +) +ARV = (0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1, 1.33, 1.50, 1.78, 1.89, 2, 2.08) def get_aspect_ratios_dict(ts=360 * 640, ars=AR): @@ -61,6 +61,27 @@ ASPECT_RATIO_4K = { "2.08": (4156, 1994), } +# S = 3686400 +ASPECT_RATIO_2K = { + "0.38": (1176, 3136), + "0.43": (1256, 2930), + "0.48": (1330, 2770), + "0.50": (1358, 2716), + "0.53": (1398, 2640), + "0.54": (1412, 2612), + "0.56": (1440, 2560), # base + "0.62": (1518, 2428), + "0.67": (1568, 2352), + "0.75": (1662, 2216), + "1.00": (1920, 1920), + "1.33": (2218, 1664), + "1.50": (2352, 1568), + "1.78": (2560, 1440), + "1.89": (2638, 1396), + "2.00": (2716, 1358), + "2.08": (2772, 1330), +} + # S = 2073600 ASPECT_RATIO_1080P = { "0.38": (882, 2352), @@ -188,6 +209,94 @@ ASPECT_RATIO_144P = { } # from PixArt +# S = 8294400 +ASPECT_RATIO_2880 = { + "0.25": (1408, 5760), + "0.26": (1408, 5568), + "0.27": (1408, 5376), + "0.28": (1408, 5184), + "0.32": (1600, 4992), + "0.33": (1600, 4800), + "0.34": (1600, 4672), + "0.4": (1792, 4480), + "0.42": (1792, 4288), + "0.47": (1920, 4096), + "0.49": (1920, 3904), + "0.51": (1920, 3776), + "0.55": (2112, 3840), + "0.59": (2112, 3584), + "0.68": (2304, 3392), + "0.72": (2304, 3200), + "0.78": (2496, 3200), + "0.83": (2496, 3008), + "0.89": (2688, 3008), + "0.93": (2688, 2880), + "1.0": (2880, 2880), + "1.07": (2880, 2688), + "1.12": (3008, 2688), + "1.21": (3008, 2496), + "1.28": (3200, 2496), + "1.39": (3200, 2304), + "1.47": (3392, 2304), + "1.7": (3584, 2112), + "1.82": (3840, 2112), + "2.03": (3904, 1920), + "2.13": (4096, 1920), + "2.39": (4288, 1792), + "2.5": (4480, 1792), + "2.92": (4672, 1600), + "3.0": (4800, 1600), + "3.12": (4992, 1600), + "3.68": (5184, 1408), + "3.82": (5376, 1408), + "3.95": (5568, 1408), + "4.0": (5760, 1408), +} + +# S = 4194304 +ASPECT_RATIO_2048 = { + "0.25": (1024, 4096), + "0.26": (1024, 3968), + "0.27": (1024, 3840), + "0.28": (1024, 3712), + "0.32": (1152, 3584), + "0.33": (1152, 3456), + "0.35": (1152, 3328), + "0.4": (1280, 3200), + "0.42": (1280, 3072), + "0.48": (1408, 2944), + "0.5": (1408, 2816), + "0.52": (1408, 2688), + "0.57": (1536, 2688), + "0.6": (1536, 2560), + "0.68": (1664, 2432), + "0.72": (1664, 2304), + "0.78": (1792, 2304), + "0.82": (1792, 2176), + "0.88": (1920, 2176), + "0.94": (1920, 2048), + "1.0": (2048, 2048), + "1.07": (2048, 1920), + "1.13": (2176, 1920), + "1.21": (2176, 1792), + "1.29": (2304, 1792), + "1.38": (2304, 1664), + "1.46": (2432, 1664), + "1.67": (2560, 1536), + "1.75": (2688, 1536), + "2.0": (2816, 1408), + "2.09": (2944, 1408), + "2.4": (3072, 1280), + "2.5": (3200, 1280), + "2.89": (3328, 1152), + "3.0": (3456, 1152), + "3.11": (3584, 1152), + "3.62": (3712, 1024), + "3.75": (3840, 1024), + "3.88": (3968, 1024), + "4.0": (4096, 1024), +} + # S = 1048576 ASPECT_RATIO_1024 = { "0.25": (512, 2048), @@ -337,5 +446,8 @@ ASPECT_RATIOS = { "720p": (921600, ASPECT_RATIO_720P), "1024": (1048576, ASPECT_RATIO_1024), "1080p": (2073600, ASPECT_RATIO_1080P), + "2k": (3686400, ASPECT_RATIO_2K), + "2048": (4194304, ASPECT_RATIO_2048), + "2880": (8294400, ASPECT_RATIO_2880), "4k": (8294400, ASPECT_RATIO_4K), } diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 705a4aa..00fcb7c 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -229,7 +229,6 @@ class KVCompressAttention(nn.Module): proj_drop: float = 0.0, norm_layer: nn.Module = LlamaRMSNorm, enable_flashattn: bool = False, - rope=None, sampling="conv", sr_ratio=1, mem_eff_attention=False, @@ -260,11 +259,6 @@ class KVCompressAttention(nn.Module): self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - self.rope = False - if rope is not None: - self.rope = True - self.rotary_emb = rope - self.mem_eff_attention = mem_eff_attention self.attn_half = attn_half @@ -294,41 +288,28 @@ class KVCompressAttention(nn.Module): def forward(self, x: torch.Tensor, mask=None, HW=None, block_id=None, **kwargs) -> torch.Tensor: B, N, C = x.shape - new_N = N - if HW is None: - H = W = int(N**0.5) - else: - H, W = HW - + H, W = HW # flash attn is not memory efficient for small sequences, this is empirical enable_flashattn = self.enable_flashattn and (N > B) - qkv = self.qkv(x) - qkv_shape = (B, N, 3, self.num_heads, self.head_dim) - - qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - # WARNING: this may be a bug - if self.rope: - q = self.rotary_emb(q) - k = self.rotary_emb(k) - q, k = self.q_norm(q), self.k_norm(k) + qkv = self.qkv(x).reshape(B, N, 3, C) + q, k, v = qkv.unbind(2) + dtype = q.dtype + # KV compression if self.sr_ratio > 1: k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling) v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling) - q = q.reshape(B, N, self.num_heads, C // self.num_heads) - k = k.reshape(B, new_N, self.num_heads, C // self.num_heads) - v = v.reshape(B, new_N, self.num_heads, C // self.num_heads) + q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype) + k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + + q, k = self.q_norm(q), self.k_norm(k) if enable_flashattn: from flash_attn import flash_attn_func - # (B, #heads, N, #dim) -> (B, N, #heads, #dim) - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) - v = v.permute(0, 2, 1, 3) x = flash_attn_func( q, k, @@ -343,8 +324,11 @@ class KVCompressAttention(nn.Module): attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf")) x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) - else: + # (B, N, #heads, #dim) -> (B, #heads, N, #dim) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) dtype = q.dtype q = q * self.scale attn = q @ k.transpose(-2, -1) # translate attn to float32 diff --git a/opensora/models/pixart/__init__.py b/opensora/models/pixart/__init__.py index edcea17..b83e18e 100644 --- a/opensora/models/pixart/__init__.py +++ b/opensora/models/pixart/__init__.py @@ -1,2 +1,2 @@ from .pixart import PixArt, PixArt_XL_2 -from .pixart_sigma import PixArt_SigmaMS, PixArt_Sigma_XL_2 +from .pixart_sigma import PixArt_Sigma_XL_2 diff --git a/opensora/models/pixart/pixart.py b/opensora/models/pixart/pixart.py index 2a9d177..421f836 100644 --- a/opensora/models/pixart/pixart.py +++ b/opensora/models/pixart/pixart.py @@ -64,9 +64,6 @@ class PixArtBlock(nn.Module): enable_flashattn=False, enable_layernorm_kernel=False, enable_sequence_parallelism=False, - qk_norm=False, - sampling="conv", - sr_ratio=1 ): super().__init__() self.hidden_size = hidden_size @@ -86,9 +83,6 @@ class PixArtBlock(nn.Module): num_heads=num_heads, qkv_bias=True, enable_flashattn=enable_flashattn, - qk_norm=qk_norm, - sr_ratio=sr_ratio, - sampling=sampling, ) self.cross_attn = self.mha_cls(hidden_size, num_heads) self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) @@ -97,8 +91,6 @@ class PixArtBlock(nn.Module): ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) - self.sampling = sampling - self.sr_ratio = sr_ratio def forward(self, x, y, t, mask=None): B, N, C = x.shape @@ -136,13 +128,11 @@ class PixArt(nn.Module): model_max_length=120, dtype=torch.float32, freeze=None, - qk_norm=False, space_scale=1.0, time_scale=1.0, enable_flashattn=False, enable_layernorm_kernel=False, enable_sequence_parallelism=False, - kv_compress_config=None, ): super().__init__() assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in this version." @@ -182,15 +172,6 @@ class PixArt(nn.Module): self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule - - self.kv_compress_config = kv_compress_config - if kv_compress_config is None: - self.kv_compress_config = { - 'sampling': None, - 'scale_factor': 1, - 'kv_compress_layer': [], - } - self.blocks = nn.ModuleList( [ PixArtBlock( @@ -200,10 +181,6 @@ class PixArt(nn.Module): drop_path=drop_path[i], enable_flashattn=enable_flashattn, enable_layernorm_kernel=enable_layernorm_kernel, - qk_norm=qk_norm, - sr_ratio=int( - self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1, - sampling=self.kv_compress_config['sampling'], ) for i in range(depth) ] @@ -216,7 +193,7 @@ class PixArt(nn.Module): if freeze == "text": self.freeze_text() - def forward(self, x, timestep, y, mask=None, **kwargs): + def forward(self, x, timestep, y, mask=None): """ Forward pass of PixArt. x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) @@ -341,24 +318,19 @@ class PixArtMS(PixArt): self.csize_embedder = SizeEmbedder(self.hidden_size // 3) self.ar_embedder = SizeEmbedder(self.hidden_size // 3) - def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs): + def forward(self, x, timestep, y, mask=None, data_info=None): """ Forward pass of PixArt. x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ - B = x.shape[0] x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = y.to(self.dtype) - hw = torch.cat([height[:, None], width[:, None]], dim=1) - # 2. get aspect ratio - ar = ar.unsqueeze(1) - - c_size = hw - ar = ar + c_size = data_info["hw"] + ar = data_info["ar"] pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype) # embedding diff --git a/opensora/models/pixart/pixart_sigma.py b/opensora/models/pixart/pixart_sigma.py index 2f6449b..e815aa9 100644 --- a/opensora/models/pixart/pixart_sigma.py +++ b/opensora/models/pixart/pixart_sigma.py @@ -31,13 +31,10 @@ from timm.models.vision_transformer import Mlp # from .builder import MODELS from opensora.acceleration.checkpoint import auto_grad_checkpoint from opensora.models.layers.blocks import ( - Attention, CaptionEmbedder, + KVCompressAttention, MultiHeadCrossAttention, PatchEmbed3D, - SeqParallelAttention, - SeqParallelMultiHeadCrossAttention, - SizeEmbedder, T2IFinalLayer, TimestepEmbedder, approx_gelu, @@ -45,7 +42,6 @@ from opensora.models.layers.blocks import ( get_2d_sincos_pos_embed, get_layernorm, t2i_modulate, - KVCompressAttention ) from opensora.registry import MODELS from opensora.utils.ckpt_utils import load_checkpoint @@ -67,19 +63,16 @@ class PixArtBlock(nn.Module): enable_sequence_parallelism=False, qk_norm=False, sampling="conv", - sr_ratio=1 + sr_ratio=1, ): super().__init__() self.hidden_size = hidden_size self.enable_flashattn = enable_flashattn self._enable_sequence_parallelism = enable_sequence_parallelism + assert not enable_sequence_parallelism, "Sequence parallelism is not supported in this version." - if enable_sequence_parallelism: - self.attn_cls = SeqParallelAttention - self.mha_cls = SeqParallelMultiHeadCrossAttention - else: - self.attn_cls = KVCompressAttention - self.mha_cls = MultiHeadCrossAttention + self.attn_cls = KVCompressAttention + self.mha_cls = MultiHeadCrossAttention self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) self.attn = self.attn_cls( @@ -102,13 +95,15 @@ class PixArtBlock(nn.Module): self.sampling = sampling self.sr_ratio = sr_ratio - def forward(self, x, y, t, mask=None): + def forward(self, x, y, t, hw, mask=None): B, N, C = x.shape shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.scale_shift_table[None] + t.reshape(B, 6, -1) ).chunk(6, dim=1) - x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.drop_path( + gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=hw).reshape(B, N, C) + ) x = x + self.cross_attn(x, y, mask) x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) @@ -188,9 +183,9 @@ class PixArt_Sigma(nn.Module): self.kv_compress_config = kv_compress_config if kv_compress_config is None: self.kv_compress_config = { - 'sampling': None, - 'scale_factor': 1, - 'kv_compress_layer': [], + "sampling": None, + "scale_factor": 1, + "kv_compress_layer": [], } self.blocks = nn.ModuleList( @@ -203,9 +198,12 @@ class PixArt_Sigma(nn.Module): enable_flashattn=enable_flashattn, enable_layernorm_kernel=enable_layernorm_kernel, qk_norm=qk_norm, - sr_ratio=int( - self.kv_compress_config['scale_factor']) if i in self.kv_compress_config['kv_compress_layer'] else 1, - sampling=self.kv_compress_config['sampling'], + sr_ratio=( + int(self.kv_compress_config["scale_factor"]) + if i in self.kv_compress_config["kv_compress_layer"] + else 1 + ), + sampling=self.kv_compress_config["sampling"], ) for i in range(depth) ] @@ -218,7 +216,7 @@ class PixArt_Sigma(nn.Module): if freeze == "text": self.freeze_text() - def forward(self, x, timestep, y, mask=None, **kwargs): + def forward(self, x, timestep, y, mask=None): """ Forward pass of PixArt. x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) @@ -228,11 +226,13 @@ class PixArt_Sigma(nn.Module): x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = y.to(self.dtype) + pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype) + hw = (x.shape[-2] // self.patch_size[-2], x.shape[-1] // self.patch_size[-1]) # embedding x = self.x_embedder(x) # (B, N, D) x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) - x = x + self.pos_embed + x = x + pos_embed.to(x.device) if not self.no_temporal_pos_emb: x = rearrange(x, "b t s d -> b s t d") x = x + self.pos_embed_temporal @@ -255,7 +255,7 @@ class PixArt_Sigma(nn.Module): # blocks for block in self.blocks: - x = auto_grad_checkpoint(block, x, y, t0, y_lens) + x = auto_grad_checkpoint(block, x, y, t0, hw, y_lens) # final process x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) @@ -334,88 +334,9 @@ class PixArt_Sigma(nn.Module): nn.init.constant_(self.final_layer.linear.bias, 0) -@MODELS.register_module() -class PixArt_SigmaMS(PixArt_Sigma): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3" - self.csize_embedder = SizeEmbedder(self.hidden_size // 3) - self.ar_embedder = SizeEmbedder(self.hidden_size // 3) - - def forward(self, x, timestep, y, mask=None, height=None, width=None, ar=None, **kwargs): - """ - Forward pass of PixArt. - x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) - t: (N,) tensor of diffusion timesteps - y: (N, 1, 120, C) tensor of class labels - """ - B = x.shape[0] - x = x.to(self.dtype) - timestep = timestep.to(self.dtype) - y = y.to(self.dtype) - - hw = torch.cat([height[:, None], width[:, None]], dim=1) - # 2. get aspect ratio - ar = ar.unsqueeze(1) - - c_size = hw - ar = ar - pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype) - - # embedding - x = self.x_embedder(x) # (B, N, D) - x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) - x = x + pos_embed.to(x.device) - if not self.no_temporal_pos_emb: - x = rearrange(x, "b t s d -> b s t d") - x = x + self.pos_embed_temporal - x = rearrange(x, "b s t d -> b (t s) d") - else: - x = rearrange(x, "b t s d -> b (t s) d") - - t = self.t_embedder(timestep, dtype=x.dtype) # (N, D) - B = x.shape[0] - csize = self.csize_embedder(c_size, B) - ar = self.ar_embedder(ar, B) - t = t + torch.cat([csize, ar], dim=1) - - t0 = self.t_block(t) - y = self.y_embedder(y, self.training) # (N, 1, L, D) - if mask is not None: - if mask.shape[0] != y.shape[0]: - mask = mask.repeat(y.shape[0] // mask.shape[0], 1) - mask = mask.squeeze(1).squeeze(1) - y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) - y_lens = mask.sum(dim=1).tolist() - else: - y_lens = [y.shape[2]] * y.shape[0] - y = y.squeeze(1).view(1, -1, x.shape[-1]) - - # blocks - for block in self.blocks: - x = block(x, y, t0, y_lens) - - # final process - x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) - x = self.unpatchify(x) # (N, out_channels, H, W) - - # cast to float32 for better accuracy - x = x.to(torch.float32) - return x - - @MODELS.register_module("PixArt-Sigma-XL/2") def PixArt_Sigma_XL_2(from_pretrained=None, **kwargs): model = PixArt_Sigma(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) if from_pretrained is not None: load_checkpoint(model, from_pretrained) return model - - -@MODELS.register_module("PixArt-SigmaMS-XL/2") -def PixArtMS_XL_2(from_pretrained=None, **kwargs): - model = PixArt_SigmaMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) - if from_pretrained is not None: - load_checkpoint(model, from_pretrained) - return model diff --git a/opensora/models/text_encoder/t5.py b/opensora/models/text_encoder/t5.py index 6d978d2..aaf2ecf 100644 --- a/opensora/models/text_encoder/t5.py +++ b/opensora/models/text_encoder/t5.py @@ -32,7 +32,7 @@ from opensora.registry import MODELS class T5Embedder: - available_models = ["DeepFloyd/t5-v1_1-xxl", "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"] + available_models = ["DeepFloyd/t5-v1_1-xxl"] def __init__( self, @@ -47,7 +47,6 @@ class T5Embedder: use_offload_folder=None, model_max_length=120, local_files_only=False, - subfolder=None ): self.device = torch.device(device) self.torch_dtype = torch_dtype or torch.bfloat16 @@ -104,14 +103,12 @@ class T5Embedder: self.tokenizer = AutoTokenizer.from_pretrained( from_pretrained, cache_dir=cache_dir, - subfolder="tokenizer" if subfolder else None, local_files_only=local_files_only, ) self.model = T5EncoderModel.from_pretrained( from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only, - subfolder="text_encoder" if subfolder else None, **t5_model_kwargs, ).eval() self.model_max_length = model_max_length @@ -148,7 +145,6 @@ class T5Encoder: cache_dir=None, shardformer=False, local_files_only=False, - subfolder=None, ): assert from_pretrained is not None, "Please specify the path to the T5 model" @@ -159,7 +155,6 @@ class T5Encoder: cache_dir=cache_dir, model_max_length=model_max_length, local_files_only=local_files_only, - subfolder=subfolder, ) self.t5.model.to(dtype=dtype) self.y_embedder = None diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index 8d691bf..157e9cb 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -32,10 +32,13 @@ pretrained_models = { "OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth", "OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth", "OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth", - "PixArt-Sigma-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth", - "PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth", - "PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth", - "PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint+ "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth", + "PixArt-Sigma-XL-2-256x256.pth": hf_endpoint + + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth", + "PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint + + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth", + "PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint + + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth", + "PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth", } @@ -51,7 +54,16 @@ def reparameter(ckpt, name=None, model=None): ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) del ckpt["pos_embed"] del ckpt["temp_embed"] - if name in ["PixArt-XL-2-256x256.pth", "PixArt-XL-2-SAM-256x256.pth", "PixArt-XL-2-512x512.pth", "PixArt-Sigma-XL-2-256x256.pth", "PixArt-Sigma-XL-2-512-MS.pth", "PixArt-Sigma-XL-2-1024-MS.pth", "PixArt-Sigma-XL-2-2K-MS.pth"]: + if name in [ + "PixArt-XL-2-256x256.pth", + "PixArt-XL-2-SAM-256x256.pth", + "PixArt-XL-2-512x512.pth", + "PixArt-XL-2-1024-MS.pth", + "PixArt-Sigma-XL-2-256x256.pth", + "PixArt-Sigma-XL-2-512-MS.pth", + "PixArt-Sigma-XL-2-1024-MS.pth", + "PixArt-Sigma-XL-2-2K-MS.pth", + ]: ckpt = ckpt["state_dict"] ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) if "pos_embed" in ckpt: