[fix] pixart sampling

This commit is contained in:
zhengzangw 2024-06-26 07:00:24 +00:00
parent 2251c4ad47
commit 4b2b47b34d
6 changed files with 20 additions and 8 deletions

View file

@@ -16,6 +16,7 @@ vae = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
scaling_factor=0.13025,
)
text_encoder = dict(
type="t5",

View file

@@ -465,7 +465,10 @@ def get_num_pixels(name):
def get_image_size(resolution, ar_ratio):
    """Resolve a (height, width) image size for a resolution and aspect ratio.

    Args:
        resolution: key into the module-level ``ASPECT_RATIOS`` table.
        ar_ratio: either a friendly aspect-ratio name present in
            ``ASPECT_RATIO_MAP`` (which is translated to its canonical key)
            or an already-canonical key used verbatim.

    Returns:
        The (height, width) pair stored for that ratio at that resolution.

    Raises:
        AssertionError: if the resolved key is not defined for ``resolution``.
    """
    # Translate known friendly names to canonical keys; fall back to using
    # the argument as-is so callers may pass canonical keys directly.
    # (No unconditional ASPECT_RATIO_MAP[ar_ratio] lookup here — that would
    # raise KeyError for unmapped ratios and defeat the fallback.)
    if ar_ratio in ASPECT_RATIO_MAP:
        ar_key = ASPECT_RATIO_MAP[ar_ratio]
    else:
        ar_key = ar_ratio
    # ASPECT_RATIOS[resolution] is (num_pixels, {ar_key: (h, w)}) — index 1
    # is the per-ratio size table. TODO confirm tuple layout against the
    # table definition elsewhere in this file.
    rs_dict = ASPECT_RATIOS[resolution][1]
    assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
    return rs_dict[ar_key]

View file

@@ -197,7 +197,7 @@ class PixArt(nn.Module):
if freeze == "text":
self.freeze_text()
def forward(self, x, timestep, y, mask=None):
def forward(self, x, timestep, y, mask=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)

View file

@@ -13,7 +13,13 @@ from opensora.utils.ckpt_utils import load_checkpoint
@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
def __init__(
self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
self,
from_pretrained=None,
micro_batch_size=None,
cache_dir=None,
local_files_only=False,
subfolder=None,
scaling_factor=0.18215,
):
super().__init__()
self.module = AutoencoderKL.from_pretrained(
@@ -25,6 +31,7 @@ class VideoAutoencoderKL(nn.Module):
self.out_channels = self.module.config.latent_channels
self.patch_size = (1, 8, 8)
self.micro_batch_size = micro_batch_size
self.scaling_factor = scaling_factor
def encode(self, x):
# x: (B, C, T, H, W)
@@ -32,14 +39,14 @@ class VideoAutoencoderKL(nn.Module):
x = rearrange(x, "B C T H W -> (B T) C H W")
if self.micro_batch_size is None:
x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
x = self.module.encode(x).latent_dist.sample().mul_(self.scaling_factor)
else:
# NOTE: cannot be used for training
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(self.scaling_factor)
x_out.append(x_bs)
x = torch.cat(x_out, dim=0)
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
@@ -50,14 +57,14 @@ class VideoAutoencoderKL(nn.Module):
B = x.shape[0]
x = rearrange(x, "B C T H W -> (B T) C H W")
if self.micro_batch_size is None:
x = self.module.decode(x / 0.18215).sample
x = self.module.decode(x / self.scaling_factor).sample
else:
# NOTE: cannot be used for training
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.decode(x_bs / 0.18215).sample
x_bs = self.module.decode(x_bs / self.scaling_factor).sample
x_out.append(x_bs)
x = torch.cat(x_out, dim=0)
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)

View file

@@ -1419,7 +1419,7 @@ class DPM_Solver:
for step in progress_fn(range(order, steps + 1)):
t = timesteps[step]
# We only use lower order for steps < 10
if lower_order_final and steps < 10:
if lower_order_final: # recommended by Shuchen Xue
step_order = min(order, steps + 1 - step)
else:
step_order = order

View file

@@ -260,6 +260,7 @@ def main():
)
# == sampling ==
torch.manual_seed(1024)
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)
samples = scheduler.sample(