Merge branch 'dev/v1.0.1' of https://github.com/hpcaitech/Open-Sora-dev into dev/v1.0.1

This commit is contained in:
Zangwei Zheng 2024-04-08 20:42:46 +08:00
commit 36268b1c0f
5 changed files with 210 additions and 18 deletions

View file

@ -0,0 +1,2 @@
from .llama import LlavaLlamaForCausalLMPolicy
from .mistral import LlavaMistralForCausalLMPolicy

View file

@ -1,26 +1,25 @@
import warnings
from typing import Dict, Union
import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["LlavaPolicy", "LlavaForCausalLMPolicy"]
__all__ = ["LlavaLlamaPolicy", "LlavaLlamaForCausalLMPolicy"]
class LlavaPolicy(Policy):
class LlavaLlamaPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
self.model.config.vocab_size
self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
# if vocab_size % world_size != 0:
# new_vocab_size = vocab_size + world_size - vocab_size % world_size
# self.model.resize_token_embeddings(new_vocab_size)
return self.model
@ -29,10 +28,6 @@ class LlavaPolicy(Policy):
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@ -83,7 +78,7 @@ class LlavaPolicy(Policy):
return self.model
class LlavaForCausalLMPolicy(LlavaPolicy):
class LlavaLlamaForCausalLMPolicy(LlavaLlamaPolicy):
def module_policy(self):
from transformers import LlamaForCausalLM

View file

@ -0,0 +1,113 @@
import warnings
from typing import Dict, Union
import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["LlavaMistralPolicy", "LlavaMistralForCausalLMPolicy"]
class LlavaMistralPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralModel
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
"Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
policy[MistralDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
),
],
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=MistralModel,
)
return policy
def postprocess(self):
return self.model
class LlavaMistralForCausalLMPolicy(LlavaMistralPolicy):
def module_policy(self):
from transformers import MistralForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
MistralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
)
]
)
}
policy.update(new_item)
return policy

View file

@ -17,7 +17,7 @@ from llava.utils import disable_torch_init
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from .acceleration.llava.policy import LlavaForCausalLMPolicy
from .acceleration.llava.policies import LlavaLlamaForCausalLMPolicy, LlavaMistralForCausalLMPolicy
from .utils import IMG_EXTENSIONS, PROMPTS, VID_EXTENSIONS, Timer, VideoTextDataset, collate_fn
disable_torch_init()
@ -93,14 +93,22 @@ def main(args):
# ======================================================
# 3. Apply system optimization
# ======================================================
# create huggingface model as normal
tp_size = dist.get_world_size(tp_group)
shard_config = ShardConfig(
tensor_parallel_process_group=tp_group if tp_size > 1 else None,
enable_tensor_parallelism=True if tp_size > 1 else False,
)
shard_former = ShardFormer(shard_config=shard_config)
model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda()
# check the model type
model_name = model.__class__.__name__
print(model_name)
if model_name == "LlavaLlamaForCausalLM":
model = shard_former.optimize(model, policy=LlavaLlamaForCausalLMPolicy())[0].cuda()
elif model_name == "LlavaMistralForCausalLM":
model = shard_former.optimize(model, policy=LlavaMistralForCausalLMPolicy())[0].cuda()
else:
print(f"The shardformer policy for {model_name} is not implemented, skip")
torch.cuda.empty_cache()
# ======================================================
@ -209,7 +217,6 @@ def main(args):
if args.profile:
encode_time = []
frame_extraction_time = []
generate_time = []
output_length = []
total_time = []
@ -292,7 +299,6 @@ def main(args):
print(output_length)
num_samples_after_warmup = total_num_videos - args.bs * args.profile_warmup * dp_size
print(f"throughput (samples/s): {num_samples_after_warmup / sum(total_time)}")
print(f"average frame extraction time per sample: {sum(frame_extraction_time) / num_samples_after_warmup}")
print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}")
print(f"average generate time per sample: {sum(generate_time) / num_samples_after_warmup}")
print(f"average number of tokens characters per sample: {sum(output_length) / num_samples_after_warmup}")

76
tools/datasets/split.py Normal file
View file

@ -0,0 +1,76 @@
import argparse
from typing import List
import pandas as pd
from mmengine.config import Config
from opensora.datasets.bucket import Bucket
def split_by_bucket(
bucket: Bucket,
input_files: List[str],
output_path: str,
limit: int,
frame_interval: int,
):
print(f"Split {len(input_files)} files into {len(bucket)} buckets")
total_limit = len(bucket) * limit
bucket_cnt = {}
# get all bucket id
for hw_id, d in bucket.ar_criteria.items():
for t_id, v in d.items():
for ar_id in v.keys():
bucket_id = (hw_id, t_id, ar_id)
bucket_cnt[bucket_id] = 0
output_df = None
# split files
for path in input_files:
df = pd.read_csv(path)
if output_df is None:
output_df = pd.DataFrame(columns=df.columns)
for i in range(len(df)):
row = df.iloc[i]
t, h, w = row["num_frames"], row["height"], row["width"]
bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
if bucket_id is None:
continue
if bucket_cnt[bucket_id] < limit:
bucket_cnt[bucket_id] += 1
output_df = pd.concat(
[output_df, pd.DataFrame([row])], ignore_index=True
)
if len(output_df) >= total_limit:
break
if len(output_df) >= total_limit:
break
assert len(output_df) <= total_limit
if len(output_df) == total_limit:
print(f"All buckets are full ({total_limit} samples)")
else:
print(f"Only {len(output_df)} files are used")
output_df.to_csv(output_path, index=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, nargs="+")
parser.add_argument("-o", "--output", required=True)
parser.add_argument("-c", "--config", required=True)
parser.add_argument("-l", "--limit", default=200, type=int)
args = parser.parse_args()
assert args.limit > 0
cfg = Config.fromfile(args.config)
bucket_config = cfg.bucket_config
# rewrite bucket_config
for ar, d in bucket_config.items():
for frames, t in d.items():
p, bs = t
if p > 0.0:
p = 1.0
d[frames] = (p, bs)
bucket = Bucket(bucket_config)
split_by_bucket(
bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval
)