mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
update for pixart
This commit is contained in:
parent
7115864314
commit
491403218d
|
|
@ -1,6 +1,6 @@
|
|||
num_frames = 1
|
||||
fps = 1
|
||||
image_size = (2560, 1536)
|
||||
# image_size = (2560, 1536)
|
||||
# image_size = (2048, 2048)
|
||||
|
||||
model = dict(
|
||||
|
|
|
|||
|
|
@ -204,9 +204,11 @@ class PixArt(nn.Module):
|
|||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) tensor of class labels
|
||||
"""
|
||||
x = x.to(self.dtype)
|
||||
timestep = timestep.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
dtype = self.x_embedder.proj.weight.dtype
|
||||
B = x.size(0)
|
||||
x = x.to(dtype)
|
||||
timestep = timestep.to(dtype)
|
||||
y = y.to(dtype)
|
||||
|
||||
# embedding
|
||||
x = self.x_embedder(x) # (B, N, D)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ class DPM_SOLVER:
|
|||
mask=None,
|
||||
progress=True,
|
||||
):
|
||||
assert mask is None, "mask is not supported in dpm-solver"
|
||||
if mask is not None:
|
||||
print("[WARNING] mask is not supported in dpm-solver, it will be ignored")
|
||||
n = len(prompts)
|
||||
model_args = text_encoder.encode(prompts)
|
||||
y = model_args.pop("y")
|
||||
|
|
|
|||
Loading…
Reference in a new issue