From 04d2ee0182db1e72922923049040d94f27ea3bd9 Mon Sep 17 00:00:00 2001 From: HangXu Date: Fri, 21 Jun 2024 11:15:39 +0300 Subject: [PATCH] Force fp16 input to fp32 to avoid nan output in timestep_transform --- opensora/schedulers/rf/rectified_flow.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index 58d7b48..8acaff5 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -15,6 +15,11 @@ def timestep_transform( scale=1.0, num_timesteps=1, ): + # Force fp16 input to fp32 to avoid nan output + for key in ["height", "width", "num_frames"]: + if model_kwargs[key].dtype == torch.float16: + model_kwargs[key] = model_kwargs[key].float() + t = t / num_timesteps resolution = model_kwargs["height"] * model_kwargs["width"] ratio_space = (resolution / base_resolution).sqrt()