Merge pull request #523 from BurkeHulk/hotfix/fp16_nan_output

Force fp16 input to fp32 to avoid nan output in timestep_transform
This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-06-21 18:01:17 +08:00 committed by GitHub
commit 9b668e1c4e

View file

@@ -15,6 +15,11 @@ def timestep_transform(
scale=1.0,
num_timesteps=1,
):
# Force fp16 input to fp32 to avoid nan output
for key in ["height", "width", "num_frames"]:
if model_kwargs[key].dtype == torch.float16:
model_kwargs[key] = model_kwargs[key].float()
t = t / num_timesteps
resolution = model_kwargs["height"] * model_kwargs["width"]
ratio_space = (resolution / base_resolution).sqrt()