diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index 58d7b48..8acaff5 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -15,6 +15,11 @@ def timestep_transform( scale=1.0, num_timesteps=1, ): + # Force fp16 input to fp32 to avoid nan output + for key in ["height", "width", "num_frames"]: + if model_kwargs[key].dtype == torch.float16: + model_kwargs[key] = model_kwargs[key].float() + t = t / num_timesteps resolution = model_kwargs["height"] * model_kwargs["width"] ratio_space = (resolution / base_resolution).sqrt()