mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
Force fp16 input to fp32 to avoid nan output in timestep_transform
This commit is contained in:
parent
033c2b3c82
commit
04d2ee0182
|
|
@ -15,6 +15,11 @@ def timestep_transform(
|
|||
scale=1.0,
|
||||
num_timesteps=1,
|
||||
):
|
||||
# Force fp16 input to fp32 to avoid nan output
|
||||
for key in ["height", "width", "num_frames"]:
|
||||
if model_kwargs[key].dtype == torch.float16:
|
||||
model_kwargs[key] = model_kwargs[key].float()
|
||||
|
||||
t = t / num_timesteps
|
||||
resolution = model_kwargs["height"] * model_kwargs["width"]
|
||||
ratio_space = (resolution / base_resolution).sqrt()
|
||||
|
|
|
|||
Loading…
Reference in a new issue