Merge pull request #530 from hpcaitech/hotfix/read

handle av error
This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-06-22 21:27:13 +08:00 committed by GitHub
commit 3cd29bd488

View file

@ -89,18 +89,22 @@ def read_video_av(
video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
# == read ==
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
container = av.open(filename, metadata_errors="ignore")
assert container.streams.video is not None
video_frames = _read_from_stream(
video_frames,
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
try:
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
container = av.open(filename, metadata_errors="ignore")
assert container.streams.video is not None
video_frames = _read_from_stream(
video_frames,
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
filename=filename,
)
except av.AVError as e:
print(f"[Warning] Error while reading video {filename}: {e}")
vframes = torch.from_numpy(video_frames).clone()
del video_frames
@ -120,6 +124,7 @@ def _read_from_stream(
pts_unit: str,
stream: "av.stream.Stream",
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
filename: Optional[str] = None,
) -> List["av.frame.Frame"]:
if pts_unit == "sec":
@ -159,26 +164,28 @@ def _read_from_stream(
try:
# TODO check if stream needs to always be the video stream here or not
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
except av.AVError:
# TODO add some warnings in this case
# print("Corrupted file?", container.name)
except av.AVError as e:
print(f"[Warning] Error while seeking video {filename}: {e}")
return []
# == main ==
buffer_count = 0
frames_pts = []
cnt = 0
for _idx, frame in enumerate(container.decode(**stream_name)):
frames_pts.append(frame.pts)
video_frames[cnt] = frame.to_rgb().to_ndarray()
cnt += 1
if cnt >= len(video_frames):
break
if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size:
buffer_count += 1
continue
break
try:
for _idx, frame in enumerate(container.decode(**stream_name)):
frames_pts.append(frame.pts)
video_frames[cnt] = frame.to_rgb().to_ndarray()
cnt += 1
if cnt >= len(video_frames):
break
if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size:
buffer_count += 1
continue
break
except av.AVError as e:
print(f"[Warning] Error while reading video {filename}: {e}")
# garbage collection for thread leakage
container.close()