diff --git a/configs/pixart/inference/1x2048MS.py b/configs/pixart/inference/1x2048MS.py index a0daca4..0f48824 100644 --- a/configs/pixart/inference/1x2048MS.py +++ b/configs/pixart/inference/1x2048MS.py @@ -1,6 +1,6 @@ num_frames = 1 fps = 1 -image_size = (2560, 1536) +# image_size = (2560, 1536) # image_size = (2048, 2048) model = dict( diff --git a/opensora/models/pixart/pixart.py b/opensora/models/pixart/pixart.py index 02f8b67..9544fcb 100644 --- a/opensora/models/pixart/pixart.py +++ b/opensora/models/pixart/pixart.py @@ -204,9 +204,11 @@ class PixArt(nn.Module): t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ - x = x.to(self.dtype) - timestep = timestep.to(self.dtype) - y = y.to(self.dtype) + dtype = self.x_embedder.proj.weight.dtype + B = x.size(0) + x = x.to(dtype) + timestep = timestep.to(dtype) + y = y.to(dtype) # embedding x = self.x_embedder(x) # (B, N, D) diff --git a/opensora/schedulers/dpms/__init__.py b/opensora/schedulers/dpms/__init__.py index df10477..111e97b 100644 --- a/opensora/schedulers/dpms/__init__.py +++ b/opensora/schedulers/dpms/__init__.py @@ -24,7 +24,8 @@ class DPM_SOLVER: mask=None, progress=True, ): - assert mask is None, "mask is not supported in dpm-solver" + if mask is not None: + print("[WARNING] mask is not supported in dpm-solver, it will be ignored") n = len(prompts) model_args = text_encoder.encode(prompts) y = model_args.pop("y")