From 6a72b8910b29ab39199d8da173c2fcd9920d3173 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Mon, 24 Jun 2024 08:50:35 +0000 Subject: [PATCH] [data] added error handling to dataset --- opensora/datasets/dataloader.py | 6 ++++++ opensora/datasets/datasets.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py index 15058ac..60d0b24 100644 --- a/opensora/datasets/dataloader.py +++ b/opensora/datasets/dataloader.py @@ -111,6 +111,9 @@ def prepare_dataloader( def collate_fn_default(batch): + # filter out None + batch = [x for x in batch if x is not None] + # HACK: for loading text features use_mask = False if "mask" in batch[0] and isinstance(batch[0]["mask"], int): @@ -132,6 +135,9 @@ def collate_fn_batch(batch): """ Used only with BatchDistributedSampler """ + # filter out None + batch = [x for x in batch if x is not None] + res = torch.utils.data.default_collate(batch) # squeeze the first dimension, which is due to torch.stack() in default_collate() diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py index 8b5fdd6..b148268 100644 --- a/opensora/datasets/datasets.py +++ b/opensora/datasets/datasets.py @@ -190,7 +190,10 @@ class VariableVideoTextDataset(VideoTextDataset): return ret def __getitem__(self, index): - return self.getitem(index) + try: + return self.getitem(index) + except: + return None @DATASETS.register_module()