Mirror of https://github.com/hpcaitech/Open-Sora.git (synced 2026-02-22 21:43:19 +01:00).
Commit 4b2b47b34d — "[fix] pixart sampling" (parent commit: 2251c4ad47).
This commit is contained in:
|
|
@ -16,6 +16,7 @@ vae = dict(
|
|||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
scaling_factor=0.13025,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
|
|
|
|||
|
|
@ -465,7 +465,10 @@ def get_num_pixels(name):
|
|||
|
||||
|
||||
def get_image_size(resolution, ar_ratio):
|
||||
ar_key = ASPECT_RATIO_MAP[ar_ratio]
|
||||
if ar_ratio in ASPECT_RATIO_MAP:
|
||||
ar_key = ASPECT_RATIO_MAP[ar_ratio]
|
||||
else:
|
||||
ar_key = ar_ratio
|
||||
rs_dict = ASPECT_RATIOS[resolution][1]
|
||||
assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
|
||||
return rs_dict[ar_key]
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ class PixArt(nn.Module):
|
|||
if freeze == "text":
|
||||
self.freeze_text()
|
||||
|
||||
def forward(self, x, timestep, y, mask=None):
|
||||
def forward(self, x, timestep, y, mask=None, **kwargs):
|
||||
"""
|
||||
Forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,13 @@ from opensora.utils.ckpt_utils import load_checkpoint
|
|||
@MODELS.register_module()
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
def __init__(
|
||||
self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
|
||||
self,
|
||||
from_pretrained=None,
|
||||
micro_batch_size=None,
|
||||
cache_dir=None,
|
||||
local_files_only=False,
|
||||
subfolder=None,
|
||||
scaling_factor=0.18215,
|
||||
):
|
||||
super().__init__()
|
||||
self.module = AutoencoderKL.from_pretrained(
|
||||
|
|
@ -25,6 +31,7 @@ class VideoAutoencoderKL(nn.Module):
|
|||
self.out_channels = self.module.config.latent_channels
|
||||
self.patch_size = (1, 8, 8)
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.scaling_factor = scaling_factor
|
||||
|
||||
def encode(self, x):
|
||||
# x: (B, C, T, H, W)
|
||||
|
|
@ -32,14 +39,14 @@ class VideoAutoencoderKL(nn.Module):
|
|||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
|
||||
if self.micro_batch_size is None:
|
||||
x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
|
||||
x = self.module.encode(x).latent_dist.sample().mul_(self.scaling_factor)
|
||||
else:
|
||||
# NOTE: cannot be used for training
|
||||
bs = self.micro_batch_size
|
||||
x_out = []
|
||||
for i in range(0, x.shape[0], bs):
|
||||
x_bs = x[i : i + bs]
|
||||
x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
|
||||
x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(self.scaling_factor)
|
||||
x_out.append(x_bs)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
|
||||
|
|
@ -50,14 +57,14 @@ class VideoAutoencoderKL(nn.Module):
|
|||
B = x.shape[0]
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
if self.micro_batch_size is None:
|
||||
x = self.module.decode(x / 0.18215).sample
|
||||
x = self.module.decode(x / self.scaling_factor).sample
|
||||
else:
|
||||
# NOTE: cannot be used for training
|
||||
bs = self.micro_batch_size
|
||||
x_out = []
|
||||
for i in range(0, x.shape[0], bs):
|
||||
x_bs = x[i : i + bs]
|
||||
x_bs = self.module.decode(x_bs / 0.18215).sample
|
||||
x_bs = self.module.decode(x_bs / self.scaling_factor).sample
|
||||
x_out.append(x_bs)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
|
||||
|
|
|
|||
|
|
@ -1419,7 +1419,7 @@ class DPM_Solver:
|
|||
for step in progress_fn(range(order, steps + 1)):
|
||||
t = timesteps[step]
|
||||
# We only use lower order for steps < 10
|
||||
if lower_order_final and steps < 10:
|
||||
if lower_order_final: # recommended by Shuchen Xue
|
||||
step_order = min(order, steps + 1 - step)
|
||||
else:
|
||||
step_order = order
|
||||
|
|
|
|||
|
|
@ -260,6 +260,7 @@ def main():
|
|||
)
|
||||
|
||||
# == sampling ==
|
||||
torch.manual_seed(1024)
|
||||
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
|
||||
masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)
|
||||
samples = scheduler.sample(
|
||||
|
|
|
|||
Loading…
Reference in a new issue