mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
Merge pull request #523 from BurkeHulk/hotfix/fp16_nan_output
Force fp16 input to fp32 to avoid nan output in timestep_transform
This commit is contained in:
commit
9b668e1c4e
|
|
@ -15,6 +15,11 @@ def timestep_transform(
|
|||
scale=1.0,
|
||||
num_timesteps=1,
|
||||
):
|
||||
# Force fp16 input to fp32 to avoid nan output
|
||||
for key in ["height", "width", "num_frames"]:
|
||||
if model_kwargs[key].dtype == torch.float16:
|
||||
model_kwargs[key] = model_kwargs[key].float()
|
||||
|
||||
t = t / num_timesteps
|
||||
resolution = model_kwargs["height"] * model_kwargs["width"]
|
||||
ratio_space = (resolution / base_resolution).sqrt()
|
||||
|
|
|
|||
Loading…
Reference in a new issue