mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
Merge branch 'dev/v1.2' into feature/vbench_i2v
This commit is contained in:
commit
d4d7191488
|
|
@ -13,7 +13,8 @@ if [[ $CKPT == *"ema"* ]]; then
|
|||
else
|
||||
CKPT_BASE=$(basename $CKPT)
|
||||
fi
|
||||
LOG_BASE=logs/sample/${MODEL_NAME}_${CKPT_BASE}
|
||||
LOG_BASE=logs/samples/${MODEL_NAME}_${CKPT_BASE}
|
||||
mkdir -p logs/samples
|
||||
echo "Logging to $LOG_BASE"
|
||||
|
||||
GPUS=(0 1 2 3 4 5 6 7)
|
||||
|
|
|
|||
|
|
@ -31,14 +31,6 @@ def main():
|
|||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# == device and dtype ==
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
cfg_dtype = cfg.get("dtype", "fp32")
|
||||
assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
|
||||
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# == init distributed env ==
|
||||
colossalai.launch_from_torch({})
|
||||
DistCoordinator()
|
||||
|
|
@ -47,7 +39,7 @@ def main():
|
|||
|
||||
# == init logger ==
|
||||
logger = create_logger()
|
||||
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
|
||||
logger.info("Eval loss configuration:\n %s", pformat(cfg.to_dict()))
|
||||
|
||||
# ======================================================
|
||||
# build model & load weights
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
# set -x
|
||||
set -x
|
||||
set -e
|
||||
|
||||
CKPT=$1
|
||||
|
|
@ -22,7 +22,6 @@ echo "QUAD_FRAMES=${QUAD_FRAMES}"
|
|||
echo "OCT_FRAMES=${OCT_FRAMES}"
|
||||
|
||||
|
||||
|
||||
CMD="python scripts/inference.py configs/opensora-v1-2/inference/sample.py"
|
||||
CMD_REF="python scripts/inference-long.py configs/opensora-v1-2/inference/sample.py"
|
||||
if [[ $CKPT == *"ema"* ]]; then
|
||||
|
|
@ -31,7 +30,7 @@ if [[ $CKPT == *"ema"* ]]; then
|
|||
else
|
||||
CKPT_BASE=$(basename $CKPT)
|
||||
fi
|
||||
OUTPUT="./samples/samples_${MODEL_NAME}_${CKPT_BASE}"
|
||||
OUTPUT="/mnt/jfs-hdd/sora/samples/samples_${MODEL_NAME}_${CKPT_BASE}"
|
||||
start=$(date +%s)
|
||||
DEFAULT_BS=1
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue