From 9a9a6c2f3e571136978818dcd561dac915ce7157 Mon Sep 17 00:00:00 2001
From: zhengzangw
Date: Sat, 22 Jun 2024 15:54:27 +0000
Subject: [PATCH 1/2] [fix] better support local ckpt

---
 opensora/models/stdit/stdit3.py | 5 +++--
 opensora/models/vae/vae.py      | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/opensora/models/stdit/stdit3.py b/opensora/models/stdit/stdit3.py
index 8703b2d..bd9672d 100644
--- a/opensora/models/stdit/stdit3.py
+++ b/opensora/models/stdit/stdit3.py
@@ -448,7 +448,7 @@ class STDiT3(PreTrainedModel):
 @MODELS.register_module("STDiT3-XL/2")
 def STDiT3_XL_2(from_pretrained=None, **kwargs):
     force_huggingface = kwargs.pop("force_huggingface", False)
-    if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained):
+    if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
         model = STDiT3.from_pretrained(from_pretrained, **kwargs)
     else:
         config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
@@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
 
 @MODELS.register_module("STDiT3-3B/2")
 def STDiT3_3B_2(from_pretrained=None, **kwargs):
-    if from_pretrained is not None and not os.path.isdir(from_pretrained):
+    force_huggingface = kwargs.pop("force_huggingface", False)
+    if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
         model = STDiT3.from_pretrained(from_pretrained, **kwargs)
     else:
         config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)

diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py
index bf50ec8..9802b02 100644
--- a/opensora/models/vae/vae.py
+++ b/opensora/models/vae/vae.py
@@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2(
         scale=scale,
     )
 
-    if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
+    if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)):
         model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
     else:
         config = VideoAutoencoderPipelineConfig(**kwargs)

From 00fef1d1af0b431ffd4dadea684a2d59d5d880f2 Mon Sep 17 00:00:00 2001
From: Jiacheng Yang
Date: Mon, 24 Jun 2024 05:07:49 -0400
Subject: [PATCH 2/2] fix SeqParallelMultiHeadCrossAttention for consistent
 results in distributed mode (#510)

---
 opensora/models/layers/blocks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py
index 8bc7e72..5e2c13d 100644
--- a/opensora/models/layers/blocks.py
+++ b/opensora/models/layers/blocks.py
@@ -499,7 +499,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
 
         # shape:
         # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
-        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+        q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
         kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
         k, v = kv.unbind(2)
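
Notes on PATCH 1/2: os.path.isdir treats a single-file checkpoint (e.g. model.safetensors) as "not local", so it was misrouted to the HuggingFace loader; os.path.exists accepts both files and directories, and STDiT3_3B_2 gains the same force_huggingface escape hatch that STDiT3_XL_2 already had. A minimal, self-contained sketch of the dispatch rule the patch converges on, with hypothetical names (resolve_checkpoint_source is illustrative, not an Open-Sora API):

    import os

    def resolve_checkpoint_source(from_pretrained, force_huggingface=False):
        """Return "huggingface" or "local" for a checkpoint reference.

        A HuggingFace repo id ("org/name") never exists on the local
        filesystem, so os.path.exists() separates the two cases. The old
        os.path.isdir() check misrouted single-file checkpoints (e.g.
        "./ckpts/model.safetensors") to the HuggingFace loader, because a
        file is not a directory.
        """
        if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)):
            return "huggingface"
        return "local"

    # resolve_checkpoint_source("hpcai-tech/OpenSora-STDiT-v3")  -> "huggingface"
    # resolve_checkpoint_source("./ckpts/model.safetensors")     -> "local" (if the file exists)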
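
Notes on PATCH 2/2: viewing q as (1, -1, num_heads, head_dim) folds the batch dimension into the token axis, so in sequence-parallel runs queries from one sample can attend to keys of another, diverging from single-GPU results; view(B, -1, ...) keeps samples separate. A toy illustration of the failure mode, using plain PyTorch SDPA with assumed shapes rather than Open-Sora's actual attention stack:

    import torch
    import torch.nn.functional as F

    B, N, M, H, D = 2, 4, 3, 2, 8  # batch, q tokens, kv tokens, heads, head dim
    q = torch.randn(B, N, H, D)
    k = torch.randn(B, M, H, D)
    v = torch.randn(B, M, H, D)

    def attn(q, k, v):
        # scaled_dot_product_attention expects [batch, heads, tokens, head_dim]
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        )
        return out.transpose(1, 2)

    per_sample = attn(q, k, v)                       # batch dim kept: samples isolated
    flattened = attn(                                # batch folded into the token axis
        q.reshape(1, B * N, H, D),
        k.reshape(1, B * M, H, D),
        v.reshape(1, B * M, H, D),
    ).reshape(B, N, H, D)
    print(torch.allclose(per_sample, flattened))     # False: queries attended across samples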