mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
Merge branch 'dev/v1.0.1' of https://github.com/hpcaitech/Open-Sora-dev into dev/v1.0.1
This commit is contained in:
commit
3210a691cf
|
|
@ -53,7 +53,7 @@ def main(args):
|
|||
writer.writerow(["video", "text"])
|
||||
|
||||
# make sure that the prompt type matches the data type
|
||||
data_extension = dataset.data["path"].iloc[0].split(".")[-1]
|
||||
data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1]
|
||||
prompt_type = PROMPTS[args.prompt]["type"]
|
||||
if prompt_type == "image":
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from llava.conversation import conv_templates
|
|||
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.utils import disable_torch_init
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from .acceleration.llava.policy import LlavaForCausalLMPolicy
|
||||
|
|
@ -22,6 +23,38 @@ from .utils import IMG_EXTENSIONS, PROMPTS, VID_EXTENSIONS, Timer, VideoTextData
|
|||
disable_torch_init()
|
||||
|
||||
|
||||
class NoPaddingDistributedSampler(DistributedSampler):
|
||||
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False):
|
||||
super().__init__(
|
||||
dataset=dataset, num_replicas=num_replicas, rank=rank, seed=seed, shuffle=False, drop_last=False
|
||||
)
|
||||
remainder = len(self.dataset) % self.num_replicas
|
||||
if remainder > 0 and (self.rank + 1) - remainder <= 0:
|
||||
# if the dataset is not divisible by num_replicas
|
||||
# the remaining items will be allocated to the first n ranks
|
||||
self.num_samples = len(self.dataset) // self.num_replicas + 1
|
||||
else:
|
||||
self.num_samples = len(self.dataset) // self.num_replicas
|
||||
self.total_size = len(dataset)
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
# deterministically shuffle based on epoch and seed
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
|
||||
else:
|
||||
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
||||
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[: self.total_size]
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
return iter(indices)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
# ======================================================
|
||||
|
|
@ -120,16 +153,16 @@ def main(args):
|
|||
)
|
||||
|
||||
# make sure that the prompt type matches the data type
|
||||
data_extension = dataset.data["path"].iloc[0].split(".")[-1]
|
||||
data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1]
|
||||
prompt_type = PROMPTS[args.prompt]["type"]
|
||||
if prompt_type == "image":
|
||||
assert (
|
||||
data_extension.lower() in IMG_EXTENSIONS
|
||||
), "The prompt is suitable for an image dataset but the data is not image."
|
||||
), f"The prompt is suitable for an image dataset but the data is not image. The first data is of format {data_extension}"
|
||||
elif prompt_type == "video":
|
||||
assert (
|
||||
data_extension.lower() in VID_EXTENSIONS
|
||||
), "The prompt is suitable for a video dataset but the data is not video."
|
||||
), f"The prompt is suitable for a video dataset but the data is not video. The first data is of format {data_extension}"
|
||||
else:
|
||||
raise ValueError(f"Found invalid prompt type {prompt_type}")
|
||||
|
||||
|
|
@ -138,9 +171,7 @@ def main(args):
|
|||
# build sampler
|
||||
dp_rank = dist.get_rank(dp_group)
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset, rank=dp_rank, num_replicas=dp_size, shuffle=False
|
||||
)
|
||||
sampler = NoPaddingDistributedSampler(dataset, rank=dp_rank, num_replicas=dp_size)
|
||||
|
||||
# build dataloader
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
|
|
|||
Loading…
Reference in a new issue