mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-19 09:57:27 +02:00
commit
3cd29bd488
|
|
@ -89,18 +89,22 @@ def read_video_av(
|
|||
video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
|
||||
|
||||
# == read ==
|
||||
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
|
||||
container = av.open(filename, metadata_errors="ignore")
|
||||
assert container.streams.video is not None
|
||||
video_frames = _read_from_stream(
|
||||
video_frames,
|
||||
container,
|
||||
start_pts,
|
||||
end_pts,
|
||||
pts_unit,
|
||||
container.streams.video[0],
|
||||
{"video": 0},
|
||||
)
|
||||
try:
|
||||
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
|
||||
container = av.open(filename, metadata_errors="ignore")
|
||||
assert container.streams.video is not None
|
||||
video_frames = _read_from_stream(
|
||||
video_frames,
|
||||
container,
|
||||
start_pts,
|
||||
end_pts,
|
||||
pts_unit,
|
||||
container.streams.video[0],
|
||||
{"video": 0},
|
||||
filename=filename,
|
||||
)
|
||||
except av.AVError as e:
|
||||
print(f"[Warning] Error while reading video {filename}: {e}")
|
||||
|
||||
vframes = torch.from_numpy(video_frames).clone()
|
||||
del video_frames
|
||||
|
|
@ -120,6 +124,7 @@ def _read_from_stream(
|
|||
pts_unit: str,
|
||||
stream: "av.stream.Stream",
|
||||
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
|
||||
filename: Optional[str] = None,
|
||||
) -> List["av.frame.Frame"]:
|
||||
|
||||
if pts_unit == "sec":
|
||||
|
|
@ -159,26 +164,28 @@ def _read_from_stream(
|
|||
try:
|
||||
# TODO check if stream needs to always be the video stream here or not
|
||||
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
|
||||
except av.AVError:
|
||||
# TODO add some warnings in this case
|
||||
# print("Corrupted file?", container.name)
|
||||
except av.AVError as e:
|
||||
print(f"[Warning] Error while seeking video {filename}: {e}")
|
||||
return []
|
||||
|
||||
# == main ==
|
||||
buffer_count = 0
|
||||
frames_pts = []
|
||||
cnt = 0
|
||||
for _idx, frame in enumerate(container.decode(**stream_name)):
|
||||
frames_pts.append(frame.pts)
|
||||
video_frames[cnt] = frame.to_rgb().to_ndarray()
|
||||
cnt += 1
|
||||
if cnt >= len(video_frames):
|
||||
break
|
||||
if frame.pts >= end_offset:
|
||||
if should_buffer and buffer_count < max_buffer_size:
|
||||
buffer_count += 1
|
||||
continue
|
||||
break
|
||||
try:
|
||||
for _idx, frame in enumerate(container.decode(**stream_name)):
|
||||
frames_pts.append(frame.pts)
|
||||
video_frames[cnt] = frame.to_rgb().to_ndarray()
|
||||
cnt += 1
|
||||
if cnt >= len(video_frames):
|
||||
break
|
||||
if frame.pts >= end_offset:
|
||||
if should_buffer and buffer_count < max_buffer_size:
|
||||
buffer_count += 1
|
||||
continue
|
||||
break
|
||||
except av.AVError as e:
|
||||
print(f"[Warning] Error while reading video {filename}: {e}")
|
||||
|
||||
# garbage collection for thread leakage
|
||||
container.close()
|
||||
|
|
|
|||
Loading…
Reference in a new issue