diff --git a/eval/human_eval/launch.sh b/eval/human_eval/launch.sh index c5ec0a2..a753529 100644 --- a/eval/human_eval/launch.sh +++ b/eval/human_eval/launch.sh @@ -13,7 +13,8 @@ if [[ $CKPT == *"ema"* ]]; then else CKPT_BASE=$(basename $CKPT) fi -LOG_BASE=logs/sample/${MODEL_NAME}_${CKPT_BASE} +LOG_BASE=logs/samples/${MODEL_NAME}_${CKPT_BASE} +mkdir -p logs/samples echo "Logging to $LOG_BASE" GPUS=(0 1 2 3 4 5 6 7) diff --git a/eval/loss/eval_loss.py b/eval/loss/eval_loss.py index a90d400..8a97f0e 100644 --- a/eval/loss/eval_loss.py +++ b/eval/loss/eval_loss.py @@ -31,14 +31,6 @@ def main(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - # == device and dtype == - device = "cuda" if torch.cuda.is_available() else "cpu" - cfg_dtype = cfg.get("dtype", "fp32") - assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" - dtype = to_torch_dtype(cfg.get("dtype", "bf16")) - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - # == init distributed env == colossalai.launch_from_torch({}) DistCoordinator() @@ -47,7 +39,7 @@ def main(): # == init logger == logger = create_logger() - logger.info("Training configuration:\n %s", pformat(cfg.to_dict())) + logger.info("Eval loss configuration:\n %s", pformat(cfg.to_dict())) # ====================================================== # build model & load weights diff --git a/eval/sample.sh b/eval/sample.sh index 0eae70d..c414f22 100644 --- a/eval/sample.sh +++ b/eval/sample.sh @@ -1,6 +1,6 @@ #!/bin/bash -# set -x +set -x set -e CKPT=$1 @@ -22,7 +22,6 @@ echo "QUAD_FRAMES=${QUAD_FRAMES}" echo "OCT_FRAMES=${OCT_FRAMES}" - CMD="python scripts/inference.py configs/opensora-v1-2/inference/sample.py" CMD_REF="python scripts/inference-long.py configs/opensora-v1-2/inference/sample.py" if [[ $CKPT == *"ema"* ]]; then @@ -31,7 +30,7 @@ if [[ $CKPT == *"ema"* ]]; then else CKPT_BASE=$(basename $CKPT) fi -OUTPUT="./samples/samples_${MODEL_NAME}_${CKPT_BASE}" +OUTPUT="/mnt/jfs-hdd/sora/samples/samples_${MODEL_NAME}_${CKPT_BASE}" start=$(date +%s) DEFAULT_BS=1