From df5668cdf1b6603c43f34d91008f2cc3c734aa1a Mon Sep 17 00:00:00 2001 From: "Gao, Ruiyuan" <905370712@qq.com> Date: Thu, 20 Feb 2025 16:40:47 +0800 Subject: [PATCH] fix bug at mha, MaskGenerator; improve ckpt_utils.py (#609) * fix bug at mha in blocks.py * fix bug in MaskGenerator * align logging style in ckpt_utils.py --- opensora/models/layers/blocks.py | 10 ++++++++++ opensora/utils/ckpt_utils.py | 4 ++-- opensora/utils/train_utils.py | 6 +++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 40e6abb..e28ef34 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -467,6 +467,11 @@ class MultiHeadCrossAttention(nn.Module): # query/value: img tokens; key: condition; mask: if padding tokens B, N, C = x.shape + if mask is None: + Bc, Nc, _ = cond.shape + assert Bc == B + mask = [Nc] * B + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) k, v = kv.unbind(2) @@ -504,6 +509,11 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention): B, SUB_N, C = x.shape # [B, TS/p, C] N = SUB_N * sp_size + if mask is None: + Bc, Nc, _ = cond.shape + assert Bc == B + mask = [Nc] * B + # shape: # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM] q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim) diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index d0607ff..78b1631 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -226,8 +226,8 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", stri state_dict = load_file(ckpt_path) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - print(f"Missing keys: {missing_keys}") - print(f"Unexpected keys: {unexpected_keys}") + get_logger().info("Missing keys: %s", missing_keys) + get_logger().info("Unexpected keys: %s", unexpected_keys) elif os.path.isdir(ckpt_path): load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict) get_logger().info("Model checkpoint loaded from %s", ckpt_path) diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py index f015b44..e542686 100644 --- a/opensora/utils/train_utils.py +++ b/opensora/utils/train_utils.py @@ -153,9 +153,9 @@ class MaskGenerator: elif mask_name == "random": mask_ratio = random.uniform(0.1, 0.9) mask = torch.rand(num_frames, device=x.device) > mask_ratio - # if mask is all False, set the last frame to True - if not mask.any(): - mask[-1] = 1 + # if mask is all False, set the last frame to True + if not mask.any(): + mask[-1] = 1 return mask