mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-20 17:35:58 +02:00
Merge branch 'dev/v1.0.1' of https://github.com/hpcaitech/Open-Sora-dev into dev/v1.0.1
This commit is contained in:
commit
36268b1c0f
2
tools/caption/acceleration/llava/policies/__init__.py
Normal file
2
tools/caption/acceleration/llava/policies/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .llama import LlavaLlamaForCausalLMPolicy
|
||||
from .mistral import LlavaMistralForCausalLMPolicy
|
||||
|
|
@ -1,26 +1,25 @@
|
|||
import warnings
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["LlavaPolicy", "LlavaForCausalLMPolicy"]
|
||||
__all__ = ["LlavaLlamaPolicy", "LlavaLlamaForCausalLMPolicy"]
|
||||
|
||||
|
||||
class LlavaPolicy(Policy):
|
||||
class LlavaLlamaPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
self.model.config.vocab_size
|
||||
self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
# if vocab_size % world_size != 0:
|
||||
# new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
# self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
|
|
@ -29,10 +28,6 @@ class LlavaPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
decoder_attribute_replacement = {
|
||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
|
|
@ -83,7 +78,7 @@ class LlavaPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
|
||||
class LlavaForCausalLMPolicy(LlavaPolicy):
|
||||
class LlavaLlamaForCausalLMPolicy(LlavaLlamaPolicy):
|
||||
def module_policy(self):
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
113
tools/caption/acceleration/llava/policies/mistral.py
Normal file
113
tools/caption/acceleration/llava/policies/mistral.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
import warnings
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["LlavaMistralPolicy", "LlavaMistralForCausalLMPolicy"]
|
||||
|
||||
|
||||
class LlavaMistralPolicy(Policy):
    """Shardformer policy for the Mistral backbone of a LLaVA model.

    When tensor parallelism is enabled, the attention and MLP projections of
    every ``MistralDecoderLayer`` are swapped for ``Linear1D_Col`` /
    ``Linear1D_Row`` and the ``embed_tokens`` module of ``MistralModel`` is
    swapped for ``VocabParallelEmbedding1D``.
    """

    def config_sanity_check(self):
        # No configuration checks are performed for this policy.
        pass

    def preprocess(self):
        """Pad the vocabulary to a multiple of the tensor-parallel size.

        Returns:
            The model held by this policy (resized in place when needed).
        """
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            remainder = vocab_size % world_size

            if remainder != 0:
                # Round the vocabulary size up to the next multiple of world_size.
                padded_vocab_size = vocab_size + world_size - remainder
                self.model.resize_token_embeddings(padded_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the module-replacement policy for the Mistral decoder.

        Returns:
            A mapping from module class to its ``ModulePolicyDescription``.
            It is empty when tensor parallelism is disabled.
        """
        from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralModel

        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            # Sequence parallelism is switched off here, with a warning.
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn(
                "Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
            )

        if self.shard_config.enable_tensor_parallelism:
            config = self.model.config
            tp_size = self.shard_config.tensor_parallel_size

            # Per-rank attention attributes: each value is divided by the
            # tensor-parallel size.
            decoder_attribute_replacement = {
                "self_attn.hidden_size": config.hidden_size // tp_size,
                "self_attn.num_heads": config.num_attention_heads // tp_size,
                "self_attn.num_key_value_heads": config.num_key_value_heads // tp_size,
            }

            # (suffix, replacement module) pairs, in the order they are registered.
            layer_plan = (
                ("self_attn.q_proj", Linear1D_Col),
                ("self_attn.k_proj", Linear1D_Col),
                ("self_attn.v_proj", Linear1D_Col),
                ("self_attn.o_proj", Linear1D_Row),
                ("mlp.gate_proj", Linear1D_Col),
                ("mlp.up_proj", Linear1D_Col),
                ("mlp.down_proj", Linear1D_Row),
            )

            policy[MistralDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix=suffix,
                        target_module=target_module,
                    )
                    for suffix, target_module in layer_plan
                ],
            )

            embedding_description = SubModuleReplacementDescription(
                suffix="embed_tokens",
                target_module=VocabParallelEmbedding1D,
            )
            self.append_or_create_submodule_replacement(
                description=embedding_description,
                policy=policy,
                target_key=MistralModel,
            )

        return policy

    def postprocess(self):
        """Return the model unchanged; no post-sharding work is done."""
        return self.model
|
||||
|
||||
|
||||
class LlavaMistralForCausalLMPolicy(LlavaMistralPolicy):
    """Extends ``LlavaMistralPolicy`` with a replacement for the LM head."""

    def module_policy(self):
        """Return the base policy plus an ``lm_head`` entry.

        When tensor parallelism is enabled, ``lm_head`` of
        ``MistralForCausalLM`` is replaced by ``Linear1D_Col`` constructed
        with ``gather_output=True``. Otherwise the base policy is returned
        as is.
        """
        from transformers import MistralForCausalLM

        policy = super().module_policy()

        if not self.shard_config.enable_tensor_parallelism:
            return policy

        # Add the causal-LM head on top of the decoder policy.
        lm_head_description = SubModuleReplacementDescription(
            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
        )
        policy[MistralForCausalLM] = ModulePolicyDescription(
            sub_module_replacement=[lm_head_description]
        )
        return policy
|
||||
|
|
@ -17,7 +17,7 @@ from llava.utils import disable_torch_init
|
|||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from .acceleration.llava.policy import LlavaForCausalLMPolicy
|
||||
from .acceleration.llava.policies import LlavaLlamaForCausalLMPolicy, LlavaMistralForCausalLMPolicy
|
||||
from .utils import IMG_EXTENSIONS, PROMPTS, VID_EXTENSIONS, Timer, VideoTextDataset, collate_fn
|
||||
|
||||
disable_torch_init()
|
||||
|
|
@ -93,14 +93,22 @@ def main(args):
|
|||
# ======================================================
|
||||
# 3. Apply system optimization
|
||||
# ======================================================
|
||||
# create huggingface model as normal
|
||||
tp_size = dist.get_world_size(tp_group)
|
||||
shard_config = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group if tp_size > 1 else None,
|
||||
enable_tensor_parallelism=True if tp_size > 1 else False,
|
||||
)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda()
|
||||
|
||||
# check the model type
|
||||
model_name = model.__class__.__name__
|
||||
print(model_name)
|
||||
if model_name == "LlavaLlamaForCausalLM":
|
||||
model = shard_former.optimize(model, policy=LlavaLlamaForCausalLMPolicy())[0].cuda()
|
||||
elif model_name == "LlavaMistralForCausalLM":
|
||||
model = shard_former.optimize(model, policy=LlavaMistralForCausalLMPolicy())[0].cuda()
|
||||
else:
|
||||
print(f"The shardformer policy for {model_name} is not implemented, skip")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# ======================================================
|
||||
|
|
@ -209,7 +217,6 @@ def main(args):
|
|||
|
||||
if args.profile:
|
||||
encode_time = []
|
||||
frame_extraction_time = []
|
||||
generate_time = []
|
||||
output_length = []
|
||||
total_time = []
|
||||
|
|
@ -292,7 +299,6 @@ def main(args):
|
|||
print(output_length)
|
||||
num_samples_after_warmup = total_num_videos - args.bs * args.profile_warmup * dp_size
|
||||
print(f"throughput (samples/s): {num_samples_after_warmup / sum(total_time)}")
|
||||
print(f"average frame extraction time per sample: {sum(frame_extraction_time) / num_samples_after_warmup}")
|
||||
print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}")
|
||||
print(f"average generate time per sample: {sum(generate_time) / num_samples_after_warmup}")
|
||||
print(f"average number of tokens characters per sample: {sum(output_length) / num_samples_after_warmup}")
|
||||
|
|
|
|||
76
tools/datasets/split.py
Normal file
76
tools/datasets/split.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
import argparse
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from mmengine.config import Config
|
||||
|
||||
from opensora.datasets.bucket import Bucket
|
||||
|
||||
|
||||
def split_by_bucket(
    bucket: Bucket,
    input_files: List[str],
    output_path: str,
    limit: int,
    frame_interval: int,
):
    """Select up to ``limit`` rows per bucket from CSV files and save them.

    Rows are visited in file order. Each row's ``num_frames``, ``height`` and
    ``width`` are passed to ``bucket.get_bucket_id``; rows that map to no
    bucket (``None``) or to a bucket that already holds ``limit`` rows are
    skipped. Scanning stops once ``len(bucket) * limit`` rows are selected.

    Args:
        bucket: Object exposing ``ar_criteria`` (nested mapping whose key
            triples form the bucket ids), ``__len__`` and ``get_bucket_id``.
        input_files: Paths of the CSV files to read; must not be empty.
        output_path: Destination CSV path (written without the index column).
        limit: Maximum number of rows kept per bucket id.
        frame_interval: Forwarded unchanged to ``bucket.get_bucket_id``.

    Raises:
        ValueError: If ``input_files`` is empty.
    """
    if not input_files:
        # Without any file there are no columns to write a header from.
        raise ValueError("split_by_bucket needs at least one input CSV file")

    print(f"Split {len(input_files)} files into {len(bucket)} buckets")
    total_limit = len(bucket) * limit

    # One counter per (hw_id, t_id, ar_id) bucket id.
    bucket_cnt = {
        (hw_id, t_id, ar_id): 0
        for hw_id, by_time in bucket.ar_criteria.items()
        for t_id, by_ar in by_time.items()
        for ar_id in by_ar
    }

    header = None  # columns of the first file, used if nothing is selected
    kept_frames = []  # one slice of selected rows per contributing file
    num_kept = 0

    for path in input_files:
        df = pd.read_csv(path)
        if header is None:
            header = df.columns

        # Collect row positions and slice once per file: this avoids the
        # quadratic cost of concatenating row by row and keeps column dtypes.
        kept_positions = []
        sizes = zip(df["num_frames"], df["height"], df["width"])
        for pos, (t, h, w) in enumerate(sizes):
            bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
            if bucket_id is None:
                continue
            if bucket_cnt[bucket_id] < limit:
                bucket_cnt[bucket_id] += 1
                kept_positions.append(pos)
                num_kept += 1
                if num_kept >= total_limit:
                    break

        if kept_positions:
            kept_frames.append(df.iloc[kept_positions])
        if num_kept >= total_limit:
            break

    if kept_frames:
        output_df = pd.concat(kept_frames, ignore_index=True)
    else:
        output_df = pd.DataFrame(columns=header)

    # Internal invariant: the loops above stop as soon as the cap is reached.
    assert len(output_df) <= total_limit
    if len(output_df) == total_limit:
        print(f"All buckets are full ({total_limit} samples)")
    else:
        print(f"Only {len(output_df)} files are used")
    output_df.to_csv(output_path, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: read one or more CSV files, load the bucket
    # configuration from a config file, and write the per-bucket sample.
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, nargs="+")
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-c", "--config", required=True)
    parser.add_argument("-l", "--limit", default=200, type=int)
    args = parser.parse_args()
    # Validate with parser.error rather than assert, which is stripped under -O.
    if args.limit <= 0:
        parser.error("--limit must be a positive integer")

    cfg = Config.fromfile(args.config)
    bucket_config = cfg.bucket_config
    # Rewrite bucket_config: every entry whose first element is positive gets
    # it replaced by 1.0; the second element is kept as is.
    # NOTE(review): the first element looks like a keep-probability - confirm
    # against opensora.datasets.bucket.Bucket.
    for d in bucket_config.values():
        for frames, t in d.items():
            p, bs = t
            if p > 0.0:
                p = 1.0
            # Reassigning an existing key is safe while iterating items().
            d[frames] = (p, bs)
    bucket = Bucket(bucket_config)
    split_by_bucket(
        bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval
    )
|
||||
Loading…
Reference in a new issue