Merge branch 'dev/v1.0.1' of https://github.com/hpcaitech/Open-Sora-dev into dev/v1.0.1

This commit is contained in:
Zangwei Zheng 2024-04-08 14:12:57 +08:00
commit 3210a691cf
2 changed files with 38 additions and 7 deletions

View file

@ -53,7 +53,7 @@ def main(args):
writer.writerow(["video", "text"]) writer.writerow(["video", "text"])
# make sure that the prompt type matches the data type # make sure that the prompt type matches the data type
data_extension = dataset.data["path"].iloc[0].split(".")[-1] data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1]
prompt_type = PROMPTS[args.prompt]["type"] prompt_type = PROMPTS[args.prompt]["type"]
if prompt_type == "image": if prompt_type == "image":
assert ( assert (

View file

@ -14,6 +14,7 @@ from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init from llava.utils import disable_torch_init
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm from tqdm import tqdm
from .acceleration.llava.policy import LlavaForCausalLMPolicy from .acceleration.llava.policy import LlavaForCausalLMPolicy
@ -22,6 +23,38 @@ from .utils import IMG_EXTENSIONS, PROMPTS, VID_EXTENSIONS, Timer, VideoTextData
disable_torch_init() disable_torch_init()
class NoPaddingDistributedSampler(DistributedSampler):
    """A ``DistributedSampler`` that never pads the index list with duplicates.

    The stock ``DistributedSampler`` appends repeated indices so that every
    rank receives exactly ``ceil(len(dataset) / num_replicas)`` samples.  For
    inference jobs those duplicates would produce duplicate outputs, so here
    the first ``len(dataset) % num_replicas`` ranks simply take one extra
    sample each and every index is emitted exactly once across all ranks.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, seed=0, drop_last=False):
        """
        Args:
            dataset: dataset to shard (only ``len(dataset)`` is used).
            num_replicas: number of participating ranks; defaults to the world
                size of the default process group.
            rank: rank of the current process; defaults to the global rank.
            shuffle: deterministically shuffle indices (seeded by
                ``seed + epoch``) before sharding.  Defaults to ``False``: the
                original signature declared ``True`` but unconditionally forced
                the parent to ``shuffle=False``, so this default preserves the
                effective behavior while making the flag actually work.
            seed: RNG seed used when ``shuffle`` is enabled.
            drop_last: accepted for signature compatibility but ignored — this
                sampler never drops the tail.
        """
        super().__init__(
            dataset=dataset, num_replicas=num_replicas, rank=rank, seed=seed, shuffle=shuffle, drop_last=False
        )
        # The first `remainder` ranks absorb one extra sample each; the rest
        # take the floor.  This replaces the parent's padded `num_samples`.
        remainder = len(self.dataset) % self.num_replicas
        base_count = len(self.dataset) // self.num_replicas
        self.num_samples = base_count + 1 if self.rank < remainder else base_count
        # No padding: the shared total is the true dataset length, not a
        # rounded-up multiple of `num_replicas`.
        self.total_size = len(dataset)

    def __iter__(self):
        if self.shuffle:
            # Deterministic shuffle keyed on (seed, epoch) so that every rank
            # generates the identical permutation before sharding.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        # Strided shard: rank r yields indices r, r+n, r+2n, ... which gives
        # exactly `num_samples` items per rank with no duplication.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)
@torch.inference_mode() @torch.inference_mode()
def main(args): def main(args):
# ====================================================== # ======================================================
@ -120,16 +153,16 @@ def main(args):
) )
# make sure that the prompt type matches the data type # make sure that the prompt type matches the data type
data_extension = dataset.data["path"].iloc[0].split(".")[-1] data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1]
prompt_type = PROMPTS[args.prompt]["type"] prompt_type = PROMPTS[args.prompt]["type"]
if prompt_type == "image": if prompt_type == "image":
assert ( assert (
data_extension.lower() in IMG_EXTENSIONS data_extension.lower() in IMG_EXTENSIONS
), "The prompt is suitable for an image dataset but the data is not image." ), f"The prompt is suitable for an image dataset but the data is not image. The first data is of format {data_extension}"
elif prompt_type == "video": elif prompt_type == "video":
assert ( assert (
data_extension.lower() in VID_EXTENSIONS data_extension.lower() in VID_EXTENSIONS
), "The prompt is suitable for a video dataset but the data is not video." ), f"The prompt is suitable for a video dataset but the data is not video. The first data is of format {data_extension}"
else: else:
raise ValueError(f"Found invalid prompt type {prompt_type}") raise ValueError(f"Found invalid prompt type {prompt_type}")
@ -138,9 +171,7 @@ def main(args):
# build sampler # build sampler
dp_rank = dist.get_rank(dp_group) dp_rank = dist.get_rank(dp_group)
dp_size = dist.get_world_size(dp_group) dp_size = dist.get_world_size(dp_group)
sampler = torch.utils.data.distributed.DistributedSampler( sampler = NoPaddingDistributedSampler(dataset, rank=dp_rank, num_replicas=dp_size)
dataset, rank=dp_rank, num_replicas=dp_size, shuffle=False
)
# build dataloader # build dataloader
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(