mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 04:37:45 +02:00
fix bug at mha, MaskGenerator; improve ckpt_utils.py (#609)
* fix bug at mha in blocks.py
* fix bug in MaskGenerator
* align logging style in ckpt_utils.py
This commit is contained in:
parent
a82687c11e
commit
df5668cdf1
|
|
@@ -467,6 +467,11 @@ class MultiHeadCrossAttention(nn.Module):
|
|||
# query/value: img tokens; key: condition; mask: if padding tokens
|
||||
B, N, C = x.shape
|
||||
|
||||
if mask is None:
|
||||
Bc, Nc, _ = cond.shape
|
||||
assert Bc == B
|
||||
mask = [Nc] * B
|
||||
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
k, v = kv.unbind(2)
|
||||
|
|
@@ -504,6 +509,11 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
|
|||
B, SUB_N, C = x.shape # [B, TS/p, C]
|
||||
N = SUB_N * sp_size
|
||||
|
||||
if mask is None:
|
||||
Bc, Nc, _ = cond.shape
|
||||
assert Bc == B
|
||||
mask = [Nc] * B
|
||||
|
||||
# shape:
|
||||
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
|
||||
q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
|
||||
|
|
|
|||
|
|
@@ -226,8 +226,8 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", stri
|
|||
|
||||
state_dict = load_file(ckpt_path)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
- print(f"Missing keys: {missing_keys}")
|
||||
- print(f"Unexpected keys: {unexpected_keys}")
|
||||
+ get_logger().info("Missing keys: %s", missing_keys)
|
||||
+ get_logger().info("Unexpected keys: %s", unexpected_keys)
|
||||
elif os.path.isdir(ckpt_path):
|
||||
load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict)
|
||||
get_logger().info("Model checkpoint loaded from %s", ckpt_path)
|
||||
|
|
|
|||
|
|
@@ -153,9 +153,9 @@ class MaskGenerator:
|
|||
elif mask_name == "random":
|
||||
mask_ratio = random.uniform(0.1, 0.9)
|
||||
mask = torch.rand(num_frames, device=x.device) > mask_ratio
|
||||
- # if mask is all False, set the last frame to True
|
||||
- if not mask.any():
|
||||
- mask[-1] = 1
|
||||
+ # if mask is all False, set the last frame to True
|
||||
+ if not mask.any():
|
||||
+ mask[-1] = 1
|
||||
|
||||
return mask
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue