Merge branch 'main' of github.com:hpcaitech/Open-Sora-dev into main

This commit is contained in:
zhengzangw 2024-06-27 07:11:11 +00:00
commit eb0ba30484
7 changed files with 74 additions and 32 deletions

View file

@ -19,8 +19,8 @@ Note that currently our model loading for vae and diffusion model supports two t
* load from local file path
* load from huggingface
Our config supports loading from huggingface by default.
If you wish to load from a local path, you need to set `force_huggingface=True`, for instance:
Our config supports loading from the huggingface online image by default.
If you wish to load from a local path downloaded from a huggingface image, you need to set `force_huggingface=True`, for instance:
```python
# for vae
@ -41,6 +41,7 @@ model = dict(
force_huggingface=True, # NOTE: set here
)
```
However, if you want to load a self-trained model, do not set `force_huggingface=True`, since your image won't be in the huggingface format.
## Inference

View file

@ -3,8 +3,16 @@
CMD="torchrun --standalone --nproc_per_node 1 eval/loss/eval_loss.py configs/opensora-v1-2/misc/eval_loss.py"
CKPT_PATH=$1
MODEL_NAME=$2
IMG_PATH="/mnt/jfs-hdd/sora/meta/validation/img_1k.csv"
VID_PATH="/mnt/jfs-hdd/sora/meta/validation/vid_100.csv"
IMG_PATH=$3
VID_PATH=$4
if [ -z $IMG_PATH ]; then
IMG_PATH="/mnt/jfs-hdd/sora/meta/validation/img_1k.csv"
fi
if [ -z $VID_PATH ]; then
VID_PATH="/mnt/jfs-hdd/sora/meta/validation/vid_100.csv"
fi
if [[ $CKPT_PATH == *"ema"* ]]; then
parentdir=$(dirname $CKPT_PATH)

View file

@ -163,6 +163,8 @@ class Attention(nn.Module):
if rope is not None:
self.rope = True
self.rotary_emb = rope
self.is_causal = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
@ -198,12 +200,17 @@ class Attention(nn.Module):
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
causal=self.is_causal,
)
else:
dtype = q.dtype
q = q * self.scale
attn = q @ k.transpose(-2, -1) # translate attn to float32
attn = attn.to(torch.float32)
if self.is_causal:
causal_mask = torch.tril(torch.ones_like(attn), diagonal=0)
causal_mask = torch.where(causal_mask.bool(), 0, float('-inf'))
attn += causal_mask
attn = attn.softmax(dim=-1)
attn = attn.to(dtype) # cast back attn to original dtype
attn = self.attn_drop(attn)
@ -499,7 +506,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
# shape:
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
k, v = kv.unbind(2)

View file

@ -368,12 +368,17 @@ class STDiT3(PreTrainedModel):
# for simplicity, we can adjust the height to make it divisible
if self.enable_sequence_parallelism:
sp_size = dist.get_world_size(get_sequence_parallel_group())
h_pad_size = sp_size - H % sp_size
hx_pad_size = h_pad_size * self.patch_size[1]
if H % sp_size != 0:
h_pad_size = sp_size - H % sp_size
else:
h_pad_size = 0
# pad x along the H dimension
H += h_pad_size
x = F.pad(x, (0, 0, 0, hx_pad_size))
if h_pad_size > 0:
hx_pad_size = h_pad_size * self.patch_size[1]
# pad x along the H dimension
H += h_pad_size
x = F.pad(x, (0, 0, 0, hx_pad_size))
S = H * W
base_size = round(S**0.5)

View file

@ -9,7 +9,7 @@ from opensora.models.stdit.stdit3 import STDiT3, STDiT3Config
def get_sample_data():
x = torch.rand([1, 4, 15, 20, 27], dtype=torch.bfloat16) # (B, C, T, H, W)
x = torch.rand([1, 4, 15, 20, 28], dtype=torch.bfloat16) # (B, C, T, H, W)
timestep = torch.Tensor([924.0]).to(torch.bfloat16)
y = torch.rand(1, 1, 300, 4096, dtype=torch.bfloat16)
mask = torch.ones([1, 300], dtype=torch.int32)
@ -66,6 +66,17 @@ def run_model(rank, world_size, port):
set_seed(1024)
dist_model_cfg = get_stdit3_config(enable_sequence_parallelism=True)
dist_model = STDiT3(dist_model_cfg).cuda().to(torch.bfloat16)
# ensure model weights are equal
for p1, p2 in zip(non_dist_model.parameters(), dist_model.parameters()):
assert torch.equal(p1, p2)
# ensure model weights are equal across all ranks
for p in dist_model.parameters():
p_list = [torch.zeros_like(p) for _ in range(world_size)]
dist.all_gather(p_list, p, group=dist.group.WORLD)
assert torch.equal(*p_list)
dist_out = dist_model(**data)
dist_out.mean().backward()
@ -84,9 +95,8 @@ def run_model(rank, world_size, port):
for (n1, p1), (n2, p2) in zip(non_dist_model.named_parameters(), dist_model.named_parameters()):
assert n1 == n2
if p1.grad is not None and p2.grad is not None:
if not torch.allclose(p1.grad, p2.grad, rtol=1e-2, atol=1e-4):
if dist.get_rank() == 0:
print(f"gradient of {n1} is not equal, {p1.grad} vs {p2.grad}")
if not torch.allclose(p1.grad, p2.grad, rtol=1e-2, atol=1e-4) and dist.get_rank() == 0:
print(f"gradient of {n1} is not equal, {p1.grad} vs {p2.grad}")
else:
assert p1.grad is None and p2.grad is None

View file

@ -4,7 +4,7 @@ Human labeling of videos is expensive and time-consuming. We adopt powerful imag
## PLLaVA Captioning
To balance captioning speed and performance, we chose the 13B version of PLLaVA configured with 2*2 spatial pooling. We feed it with 4 frames evenly extracted from the video.
To balance captioning speed and performance, we chose the 13B version of PLLaVA configured with 2*2 spatial pooling. We feed it with 4 frames evenly extracted from the video. We accelerate its inference via (1) batching and (2) offloading frame extraction to a separate process, so that GPU computation and frame extraction happen in parallel.
### Installation
Install the required dependencies by following our [installation instructions](../../docs/installation.md)'s "Data Dependencies" and "PLLaVA Captioning" sections.

View file

@ -29,15 +29,20 @@ def process_single_row(row, args):
# check mp4 integrity
# if not is_intact_video(video_path, logger=logger):
# return False
if "timestamp" in row:
timestamp = row["timestamp"]
if not (timestamp.startswith("[") and timestamp.endswith("]")):
try:
if "timestamp" in row:
timestamp = row["timestamp"]
if not (timestamp.startswith("[") and timestamp.endswith("]")):
return False
scene_list = eval(timestamp)
scene_list = [(FrameTimecode(s, fps=1), FrameTimecode(t, fps=1)) for s, t in scene_list]
else:
scene_list = [None]
if args.drop_invalid_timestamps:
return True
except Exception as e:
if args.drop_invalid_timestamps:
return False
scene_list = eval(timestamp)
scene_list = [(FrameTimecode(s, fps=1), FrameTimecode(t, fps=1)) for s, t in scene_list]
else:
scene_list = [None]
if "relpath" in row:
save_dir = os.path.dirname(os.path.join(args.save_dir, row["relpath"]))
@ -61,7 +66,7 @@ def process_single_row(row, args):
shorter_size=shorter_size,
logger=logger,
)
return True
def split_video(
video_path,
@ -108,7 +113,10 @@ def split_video(
fname_wo_ext = os.path.splitext(fname)[0]
# TODO: fname pattern
save_path = os.path.join(save_dir, f"{fname_wo_ext}_scene-{idx}.mp4")
if os.path.exists(save_path):
# print_log(f"File '{save_path}' already exists. Skip.", logger=logger)
continue
# ffmpeg cmd
cmd = [FFMPEG_PATH]
@ -134,7 +142,7 @@ def split_video(
# cmd += ['-vf', f"scale='if(gt(iw,ih),{shorter_size},trunc(ow/a/2)*2)':-2"]
cmd += ["-map", "0:v", save_path]
# print(cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout, stderr = proc.communicate()
# stdout = stdout.decode("utf-8")
@ -163,7 +171,7 @@ def parse_args():
)
parser.add_argument("--num_workers", type=int, default=None, help="#workers for pandarallel")
parser.add_argument("--disable_parallel", action="store_true", help="disable parallel processing")
parser.add_argument("--drop_invalid_timestamps", action="store_true", help="drop rows with invalid timestamps")
args = parser.parse_args()
return args
@ -175,7 +183,7 @@ def main():
print(f"Meta file '{meta_path}' not found. Exit.")
exit()
# create logger
# create save_dir
os.makedirs(args.save_dir, exist_ok=True)
# initialize pandarallel
@ -189,10 +197,13 @@ def main():
# process
meta = pd.read_csv(args.meta_path)
if not args.disable_parallel:
meta.parallel_apply(process_single_row_partial, axis=1)
results = meta.parallel_apply(process_single_row_partial, axis=1)
else:
meta.apply(process_single_row_partial, axis=1)
results = meta.apply(process_single_row_partial, axis=1)
if args.drop_invalid_timestamps:
meta = meta[results]
assert args.meta_path.endswith("timestamp.csv"), "Only support *timestamp.csv"
meta.to_csv(args.meta_path.replace("timestamp.csv", "correct_timestamp.csv"), index=False)
print(f"Corrected timestamp file saved to '{args.meta_path.replace('timestamp.csv', 'correct_timestamp.csv')}'")
if __name__ == "__main__":
main()