update for pixart

This commit is contained in:
zhengzangw 2024-06-24 07:04:08 +00:00
parent 7115864314
commit 491403218d
3 changed files with 8 additions and 5 deletions

View file

@ -1,6 +1,6 @@
num_frames = 1
fps = 1
image_size = (2560, 1536)
# image_size = (2560, 1536)
# image_size = (2048, 2048)
model = dict(

View file

@ -204,9 +204,11 @@ class PixArt(nn.Module):
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
dtype = self.x_embedder.proj.weight.dtype
B = x.size(0)
x = x.to(dtype)
timestep = timestep.to(dtype)
y = y.to(dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)

View file

@ -24,7 +24,8 @@ class DPM_SOLVER:
mask=None,
progress=True,
):
assert mask is None, "mask is not supported in dpm-solver"
if mask is not None:
print("[WARNING] mask is not supported in dpm-solver, it will be ignored")
n = len(prompts)
model_args = text_encoder.encode(prompts)
y = model_args.pop("y")