mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
small improvements
This commit is contained in:
parent
60810ce6d3
commit
fdcd22257c
|
|
@ -130,7 +130,10 @@ class CSVDataset(Dataset):
|
|||
def __getitem__(self, idx):
|
||||
if idx < 0 or idx >= len(self.data_list):
|
||||
raise IndexError
|
||||
video = load_video(self.data_list[idx], self.num_frames, resolution=RESOLUTION)
|
||||
try:
|
||||
video = load_video(self.data_list[idx], self.num_frames, resolution=RESOLUTION)
|
||||
except:
|
||||
return None
|
||||
return video
|
||||
|
||||
def set_rank_and_world_size(self, rank, world_size):
|
||||
|
|
@ -191,7 +194,7 @@ def parse_args():
|
|||
"--error_message",
|
||||
type=str,
|
||||
required=False,
|
||||
default=None,
|
||||
default='error occured during captioning',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
|
@ -235,6 +238,9 @@ def infer(
|
|||
conv_mode,
|
||||
print_res=True,
|
||||
):
|
||||
# check if any video in video_list is None, if so, raise an exception
|
||||
if any([video is None for video in video_list]):
|
||||
raise Exception("Video not loaded properly")
|
||||
conv = conv_template.copy()
|
||||
conv.user_query("Describe the video in details.", is_mm=True)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue