mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-06 04:00:01 +02:00
Merge branch 'main' of github.com:hpcaitech/Open-Sora-dev into main
This commit is contained in:
commit
eb0ba30484
|
|
@ -19,8 +19,8 @@ Note that currently our model loading for vae and diffusion model supports two t
|
|||
* load from local file path
|
||||
* load from huggingface
|
||||
|
||||
Our config supports loading from huggingface by default.
|
||||
If you wish to load from a local path, you need to set `force_huggingface=True`, for instance:
|
||||
By default, our config loads the model weights online from the Hugging Face Hub.
|
||||
If you wish to load from a local path that contains a checkpoint downloaded from Hugging Face, you need to set `force_huggingface=True`, for instance:
|
||||
|
||||
```python
|
||||
# for vae
|
||||
|
|
@ -41,6 +41,7 @@ model = dict(
|
|||
force_huggingface=True, # NOTE: set here
|
||||
)
|
||||
```
|
||||
However, if you want to load a self-trained model, do not set `force_huggingface=True`, since your checkpoint won't be in Hugging Face format.
|
||||
|
||||
## Inference
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,16 @@
|
|||
CMD="torchrun --standalone --nproc_per_node 1 eval/loss/eval_loss.py configs/opensora-v1-2/misc/eval_loss.py"
|
||||
CKPT_PATH=$1
|
||||
MODEL_NAME=$2
|
||||
IMG_PATH="/mnt/jfs-hdd/sora/meta/validation/img_1k.csv"
|
||||
VID_PATH="/mnt/jfs-hdd/sora/meta/validation/vid_100.csv"
|
||||
IMG_PATH=$3
|
||||
VID_PATH=$4
|
||||
|
||||
if [ -z $IMG_PATH ]; then
|
||||
IMG_PATH="/mnt/jfs-hdd/sora/meta/validation/img_1k.csv"
|
||||
fi
|
||||
|
||||
if [ -z $VID_PATH ]; then
|
||||
VID_PATH="/mnt/jfs-hdd/sora/meta/validation/vid_100.csv"
|
||||
fi
|
||||
|
||||
if [[ $CKPT_PATH == *"ema"* ]]; then
|
||||
parentdir=$(dirname $CKPT_PATH)
|
||||
|
|
|
|||
|
|
@ -163,6 +163,8 @@ class Attention(nn.Module):
|
|||
if rope is not None:
|
||||
self.rope = True
|
||||
self.rotary_emb = rope
|
||||
|
||||
self.is_causal = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
|
|
@ -198,12 +200,17 @@ class Attention(nn.Module):
|
|||
v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=self.is_causal,
|
||||
)
|
||||
else:
|
||||
dtype = q.dtype
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1) # translate attn to float32
|
||||
attn = attn.to(torch.float32)
|
||||
if self.is_causal:
|
||||
causal_mask = torch.tril(torch.ones_like(attn), diagonal=0)
|
||||
causal_mask = torch.where(causal_mask.bool(), 0, float('-inf'))
|
||||
attn += causal_mask
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = attn.to(dtype) # cast back attn to original dtype
|
||||
attn = self.attn_drop(attn)
|
||||
|
|
@ -499,7 +506,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
|
|||
|
||||
# shape:
|
||||
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
|
||||
k, v = kv.unbind(2)
|
||||
|
|
|
|||
|
|
@ -368,12 +368,17 @@ class STDiT3(PreTrainedModel):
|
|||
# for simplicity, we can adjust the height to make it divisible
|
||||
if self.enable_sequence_parallelism:
|
||||
sp_size = dist.get_world_size(get_sequence_parallel_group())
|
||||
h_pad_size = sp_size - H % sp_size
|
||||
hx_pad_size = h_pad_size * self.patch_size[1]
|
||||
if H % sp_size != 0:
|
||||
h_pad_size = sp_size - H % sp_size
|
||||
else:
|
||||
h_pad_size = 0
|
||||
|
||||
# pad x along the H dimension
|
||||
H += h_pad_size
|
||||
x = F.pad(x, (0, 0, 0, hx_pad_size))
|
||||
if h_pad_size > 0:
|
||||
hx_pad_size = h_pad_size * self.patch_size[1]
|
||||
|
||||
# pad x along the H dimension
|
||||
H += h_pad_size
|
||||
x = F.pad(x, (0, 0, 0, hx_pad_size))
|
||||
|
||||
S = H * W
|
||||
base_size = round(S**0.5)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from opensora.models.stdit.stdit3 import STDiT3, STDiT3Config
|
|||
|
||||
|
||||
def get_sample_data():
|
||||
x = torch.rand([1, 4, 15, 20, 27], dtype=torch.bfloat16) # (B, C, T, H, W)
|
||||
x = torch.rand([1, 4, 15, 20, 28], dtype=torch.bfloat16) # (B, C, T, H, W)
|
||||
timestep = torch.Tensor([924.0]).to(torch.bfloat16)
|
||||
y = torch.rand(1, 1, 300, 4096, dtype=torch.bfloat16)
|
||||
mask = torch.ones([1, 300], dtype=torch.int32)
|
||||
|
|
@ -66,6 +66,17 @@ def run_model(rank, world_size, port):
|
|||
set_seed(1024)
|
||||
dist_model_cfg = get_stdit3_config(enable_sequence_parallelism=True)
|
||||
dist_model = STDiT3(dist_model_cfg).cuda().to(torch.bfloat16)
|
||||
|
||||
# ensure model weights are equal
|
||||
for p1, p2 in zip(non_dist_model.parameters(), dist_model.parameters()):
|
||||
assert torch.equal(p1, p2)
|
||||
|
||||
# ensure model weights are equal across all ranks
|
||||
for p in dist_model.parameters():
|
||||
p_list = [torch.zeros_like(p) for _ in range(world_size)]
|
||||
dist.all_gather(p_list, p, group=dist.group.WORLD)
|
||||
assert torch.equal(*p_list)
|
||||
|
||||
dist_out = dist_model(**data)
|
||||
dist_out.mean().backward()
|
||||
|
||||
|
|
@ -84,9 +95,8 @@ def run_model(rank, world_size, port):
|
|||
for (n1, p1), (n2, p2) in zip(non_dist_model.named_parameters(), dist_model.named_parameters()):
|
||||
assert n1 == n2
|
||||
if p1.grad is not None and p2.grad is not None:
|
||||
if not torch.allclose(p1.grad, p2.grad, rtol=1e-2, atol=1e-4):
|
||||
if dist.get_rank() == 0:
|
||||
print(f"gradient of {n1} is not equal, {p1.grad} vs {p2.grad}")
|
||||
if not torch.allclose(p1.grad, p2.grad, rtol=1e-2, atol=1e-4) and dist.get_rank() == 0:
|
||||
print(f"gradient of {n1} is not equal, {p1.grad} vs {p2.grad}")
|
||||
else:
|
||||
assert p1.grad is None and p2.grad is None
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Human labeling of videos is expensive and time-consuming. We adopt powerful imag
|
|||
|
||||
## PLLaVA Captioning
|
||||
|
||||
To balance captioning speed and performance, we chose the 13B version of PLLaVA configured with 2*2 spatial pooling. We feed it with 4 frames evenly extracted from the video.
|
||||
To balance captioning speed and performance, we chose the 13B version of PLLaVA configured with 2*2 spatial pooling. We feed it with 4 frames evenly extracted from the video. We accelerate its inference by (1) batching and (2) offloading frame extraction to a separate process, so that GPU computation and frame extraction happen in parallel.
|
||||
|
||||
### Installation
|
||||
Install the required dependencies by following the "Data Dependencies" and "PLLaVA Captioning" sections of our [installation instructions](../../docs/installation.md).
|
||||
|
|
|
|||
|
|
@ -29,15 +29,20 @@ def process_single_row(row, args):
|
|||
# check mp4 integrity
|
||||
# if not is_intact_video(video_path, logger=logger):
|
||||
# return False
|
||||
|
||||
if "timestamp" in row:
|
||||
timestamp = row["timestamp"]
|
||||
if not (timestamp.startswith("[") and timestamp.endswith("]")):
|
||||
try:
|
||||
if "timestamp" in row:
|
||||
timestamp = row["timestamp"]
|
||||
if not (timestamp.startswith("[") and timestamp.endswith("]")):
|
||||
return False
|
||||
scene_list = eval(timestamp)
|
||||
scene_list = [(FrameTimecode(s, fps=1), FrameTimecode(t, fps=1)) for s, t in scene_list]
|
||||
else:
|
||||
scene_list = [None]
|
||||
if args.drop_invalid_timestamps:
|
||||
return True
|
||||
except Exception as e:
|
||||
if args.drop_invalid_timestamps:
|
||||
return False
|
||||
scene_list = eval(timestamp)
|
||||
scene_list = [(FrameTimecode(s, fps=1), FrameTimecode(t, fps=1)) for s, t in scene_list]
|
||||
else:
|
||||
scene_list = [None]
|
||||
|
||||
if "relpath" in row:
|
||||
save_dir = os.path.dirname(os.path.join(args.save_dir, row["relpath"]))
|
||||
|
|
@ -61,7 +66,7 @@ def process_single_row(row, args):
|
|||
shorter_size=shorter_size,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def split_video(
|
||||
video_path,
|
||||
|
|
@ -108,7 +113,10 @@ def split_video(
|
|||
fname_wo_ext = os.path.splitext(fname)[0]
|
||||
# TODO: fname pattern
|
||||
save_path = os.path.join(save_dir, f"{fname_wo_ext}_scene-{idx}.mp4")
|
||||
|
||||
if os.path.exists(save_path):
|
||||
# print_log(f"File '{save_path}' already exists. Skip.", logger=logger)
|
||||
continue
|
||||
|
||||
# ffmpeg cmd
|
||||
cmd = [FFMPEG_PATH]
|
||||
|
||||
|
|
@ -134,7 +142,7 @@ def split_video(
|
|||
# cmd += ['-vf', f"scale='if(gt(iw,ih),{shorter_size},trunc(ow/a/2)*2)':-2"]
|
||||
|
||||
cmd += ["-map", "0:v", save_path]
|
||||
|
||||
# print(cmd)
|
||||
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
stdout, stderr = proc.communicate()
|
||||
# stdout = stdout.decode("utf-8")
|
||||
|
|
@ -163,7 +171,7 @@ def parse_args():
|
|||
)
|
||||
parser.add_argument("--num_workers", type=int, default=None, help="#workers for pandarallel")
|
||||
parser.add_argument("--disable_parallel", action="store_true", help="disable parallel processing")
|
||||
|
||||
parser.add_argument("--drop_invalid_timestamps", action="store_true", help="drop rows with invalid timestamps")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
|
@ -175,7 +183,7 @@ def main():
|
|||
print(f"Meta file '{meta_path}' not found. Exit.")
|
||||
exit()
|
||||
|
||||
# create logger
|
||||
# create save_dir
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
|
||||
# initialize pandarallel
|
||||
|
|
@ -189,10 +197,13 @@ def main():
|
|||
# process
|
||||
meta = pd.read_csv(args.meta_path)
|
||||
if not args.disable_parallel:
|
||||
meta.parallel_apply(process_single_row_partial, axis=1)
|
||||
results = meta.parallel_apply(process_single_row_partial, axis=1)
|
||||
else:
|
||||
meta.apply(process_single_row_partial, axis=1)
|
||||
|
||||
|
||||
results = meta.apply(process_single_row_partial, axis=1)
|
||||
if args.drop_invalid_timestamps:
|
||||
meta = meta[results]
|
||||
assert args.meta_path.endswith("timestamp.csv"), "Only support *timestamp.csv"
|
||||
meta.to_csv(args.meta_path.replace("timestamp.csv", "correct_timestamp.csv"), index=False)
|
||||
print(f"Corrected timestamp file saved to '{args.meta_path.replace('timestamp.csv', 'correct_timestamp.csv')}'")
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue