mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
Merge branch 'dev/v1.0.1' of https://github.com/hpcaitech/Open-Sora-dev into dev/v1.0.1
This commit is contained in:
commit
3210a691cf
|
|
@ -53,7 +53,7 @@ def main(args):
|
||||||
writer.writerow(["video", "text"])
|
writer.writerow(["video", "text"])
|
||||||
|
|
||||||
# make sure that the prompt type matches the data type
|
# make sure that the prompt type matches the data type
|
||||||
data_extension = dataset.data["path"].iloc[0].split(".")[-1]
|
data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1]
|
||||||
prompt_type = PROMPTS[args.prompt]["type"]
|
prompt_type = PROMPTS[args.prompt]["type"]
|
||||||
if prompt_type == "image":
|
if prompt_type == "image":
|
||||||
assert (
|
assert (
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from llava.conversation import conv_templates
|
||||||
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
||||||
from llava.model.builder import load_pretrained_model
|
from llava.model.builder import load_pretrained_model
|
||||||
from llava.utils import disable_torch_init
|
from llava.utils import disable_torch_init
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .acceleration.llava.policy import LlavaForCausalLMPolicy
|
from .acceleration.llava.policy import LlavaForCausalLMPolicy
|
||||||
|
|
@ -22,6 +23,38 @@ from .utils import IMG_EXTENSIONS, PROMPTS, VID_EXTENSIONS, Timer, VideoTextData
|
||||||
disable_torch_init()
|
disable_torch_init()
|
||||||
|
|
||||||
|
|
||||||
|
class NoPaddingDistributedSampler(DistributedSampler):
|
||||||
|
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False):
|
||||||
|
super().__init__(
|
||||||
|
dataset=dataset, num_replicas=num_replicas, rank=rank, seed=seed, shuffle=False, drop_last=False
|
||||||
|
)
|
||||||
|
remainder = len(self.dataset) % self.num_replicas
|
||||||
|
if remainder > 0 and (self.rank + 1) - remainder <= 0:
|
||||||
|
# if the dataset is not divisible by num_replicas
|
||||||
|
# the remaining items will be allocated to the first n ranks
|
||||||
|
self.num_samples = len(self.dataset) // self.num_replicas + 1
|
||||||
|
else:
|
||||||
|
self.num_samples = len(self.dataset) // self.num_replicas
|
||||||
|
self.total_size = len(dataset)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self.shuffle:
|
||||||
|
# deterministically shuffle based on epoch and seed
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(self.seed + self.epoch)
|
||||||
|
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# remove tail of data to make it evenly divisible.
|
||||||
|
indices = indices[: self.total_size]
|
||||||
|
|
||||||
|
# subsample
|
||||||
|
indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||||
|
assert len(indices) == self.num_samples
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def main(args):
|
def main(args):
|
||||||
# ======================================================
|
# ======================================================
|
||||||
|
|
@ -120,16 +153,16 @@ def main(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure that the prompt type matches the data type
|
# make sure that the prompt type matches the data type
|
||||||
data_extension = dataset.data["path"].iloc[0].split(".")[-1]
|
data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1]
|
||||||
prompt_type = PROMPTS[args.prompt]["type"]
|
prompt_type = PROMPTS[args.prompt]["type"]
|
||||||
if prompt_type == "image":
|
if prompt_type == "image":
|
||||||
assert (
|
assert (
|
||||||
data_extension.lower() in IMG_EXTENSIONS
|
data_extension.lower() in IMG_EXTENSIONS
|
||||||
), "The prompt is suitable for an image dataset but the data is not image."
|
), f"The prompt is suitable for an image dataset but the data is not image. The first data is of format {data_extension}"
|
||||||
elif prompt_type == "video":
|
elif prompt_type == "video":
|
||||||
assert (
|
assert (
|
||||||
data_extension.lower() in VID_EXTENSIONS
|
data_extension.lower() in VID_EXTENSIONS
|
||||||
), "The prompt is suitable for a video dataset but the data is not video."
|
), f"The prompt is suitable for a video dataset but the data is not video. The first data is of format {data_extension}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Found invalid prompt type {prompt_type}")
|
raise ValueError(f"Found invalid prompt type {prompt_type}")
|
||||||
|
|
||||||
|
|
@ -138,9 +171,7 @@ def main(args):
|
||||||
# build sampler
|
# build sampler
|
||||||
dp_rank = dist.get_rank(dp_group)
|
dp_rank = dist.get_rank(dp_group)
|
||||||
dp_size = dist.get_world_size(dp_group)
|
dp_size = dist.get_world_size(dp_group)
|
||||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
sampler = NoPaddingDistributedSampler(dataset, rank=dp_rank, num_replicas=dp_size)
|
||||||
dataset, rank=dp_rank, num_replicas=dp_size, shuffle=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue