Force fp16 inputs to fp32 to avoid NaN output in timestep_transform

HangXu 2024-06-21 11:15:39 +03:00 committed by GitHub
parent 033c2b3c82
commit 04d2ee0182


@@ -15,6 +15,11 @@ def timestep_transform(
     scale=1.0,
     num_timesteps=1,
 ):
+    # Force fp16 input to fp32 to avoid nan output
+    for key in ["height", "width", "num_frames"]:
+        if model_kwargs[key].dtype == torch.float16:
+            model_kwargs[key] = model_kwargs[key].float()
+
     t = t / num_timesteps
     resolution = model_kwargs["height"] * model_kwargs["width"]
     ratio_space = (resolution / base_resolution).sqrt()
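
Why fp16 produces NaN here: float16 cannot represent values above 65504, so multiplying common video dimensions (e.g. 720 x 1280 = 921600) overflows to inf, and the downstream division and sqrt then propagate inf/NaN. A minimal sketch of the failure mode and the upcast fix applied by this commit (the resolution values are illustrative, not taken from the commit):

import torch

# float16 saturates above 65504, so height * width for a common
# video resolution overflows to inf.
height = torch.tensor(720.0, dtype=torch.float16)
width = torch.tensor(1280.0, dtype=torch.float16)

resolution_fp16 = height * width
print(resolution_fp16)                    # tensor(inf, dtype=torch.float16)

# inf then propagates through later arithmetic: inf / inf is nan.
print(resolution_fp16 / resolution_fp16)  # tensor(nan, dtype=torch.float16)

# The fix: upcast to fp32 before the arithmetic, as the diff does
# with .float() on the affected model_kwargs entries.
resolution_fp32 = height.float() * width.float()
print(resolution_fp32)                    # tensor(921600.)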