From 11978080f4c61638bd53751b7be6a0da35540af3 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 8 Apr 2024 10:07:43 +0800 Subject: [PATCH 1/2] enabled uneven data sharding for captioning (#41) * enabled uneven data sharding for captioning * poslih --- tools/caption/caption_llava.py | 37 +++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/tools/caption/caption_llava.py b/tools/caption/caption_llava.py index 3ee21b6..b93b76f 100644 --- a/tools/caption/caption_llava.py +++ b/tools/caption/caption_llava.py @@ -14,6 +14,7 @@ from llava.conversation import conv_templates from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init +from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm from .acceleration.llava.policy import LlavaForCausalLMPolicy @@ -22,6 +23,38 @@ from .utils import IMG_EXTENSIONS, PROMPTS, VID_EXTENSIONS, Timer, VideoTextData disable_torch_init() +class NoPaddingDistributedSampler(DistributedSampler): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False): + super().__init__( + dataset=dataset, num_replicas=num_replicas, rank=rank, seed=seed, shuffle=False, drop_last=False + ) + remainder = len(self.dataset) % self.num_replicas + if remainder > 0 and (self.rank + 1) - remainder <= 0: + # if the dataset is not divisible by num_replicas + # the remaining items will be allocated to the first n ranks + self.num_samples = len(self.dataset) // self.num_replicas + 1 + else: + self.num_samples = len(self.dataset) // self.num_replicas + self.total_size = len(dataset) + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + return iter(indices) + + @torch.inference_mode() def main(args): # ====================================================== @@ -138,9 +171,7 @@ def main(args): # build sampler dp_rank = dist.get_rank(dp_group) dp_size = dist.get_world_size(dp_group) - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, rank=dp_rank, num_replicas=dp_size, shuffle=False - ) + sampler = NoPaddingDistributedSampler(dataset, rank=dp_rank, num_replicas=dp_size) # build dataloader dataloader = torch.utils.data.DataLoader( From f0d4d55ed7a6fd731e975f5473d5360898a11f86 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 8 Apr 2024 10:38:14 +0800 Subject: [PATCH 2/2] fixed captioning data format check (#42) --- tools/caption/caption_gpt4.py | 2 +- tools/caption/caption_llava.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/caption/caption_gpt4.py b/tools/caption/caption_gpt4.py index 667000f..f22c296 100644 --- a/tools/caption/caption_gpt4.py +++ b/tools/caption/caption_gpt4.py @@ -53,7 +53,7 @@ def main(args): writer.writerow(["video", "text"]) # make sure that the prompt type matches the data type - data_extension = dataset.data["path"].iloc[0].split(".")[-1] + data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1] prompt_type = PROMPTS[args.prompt]["type"] if prompt_type == "image": assert ( diff --git a/tools/caption/caption_llava.py b/tools/caption/caption_llava.py index b93b76f..2f75595 100644 --- a/tools/caption/caption_llava.py +++ b/tools/caption/caption_llava.py @@ -153,16 +153,16 @@ def main(args): ) # make sure that the prompt type matches the data type - data_extension = dataset.data["path"].iloc[0].split(".")[-1] + data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1] prompt_type = PROMPTS[args.prompt]["type"] if prompt_type == "image": assert ( data_extension.lower() in IMG_EXTENSIONS - ), "The prompt is suitable for an image dataset but the data is not image." + ), f"The prompt is suitable for an image dataset but the data is not image. The first data is of format {data_extension}" elif prompt_type == "video": assert ( data_extension.lower() in VID_EXTENSIONS - ), "The prompt is suitable for a video dataset but the data is not video." + ), f"The prompt is suitable for a video dataset but the data is not video. The first data is of format {data_extension}" else: raise ValueError(f"Found invalid prompt type {prompt_type}")