diff --git a/gradio/app.py b/gradio/app.py index 7f64896..80fc2e8 100644 --- a/gradio/app.py +++ b/gradio/app.py @@ -23,10 +23,12 @@ import gradio as gr MODEL_TYPES = ["v1.1"] CONFIG_MAP = { - "v1.1": "configs/opensora-v1-1/inference/sample-ref.py", + "v1.1-stage2": "configs/opensora-v1-1/inference/sample-ref.py", + "v1.1-stage3": "configs/opensora-v1-1/inference/sample-ref.py", } HF_STDIT_MAP = { - "v1.1": "hpcai-tech/OpenSora-STDiT-v2-stage2", + "v1.1-stage2": "hpcai-tech/OpenSora-STDiT-v2-stage2", + "v1.1-stage3": "hpcai-tech/OpenSora-STDiT-v2-stage3", } RESOLUTION_MAP = { "360p": (360, 480), @@ -249,7 +251,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--model-type", - default="v1.1", + default="v1.1-stage3", choices=MODEL_TYPES, help=f"The type of model to run for the Gradio App, can only be {MODEL_TYPES}", )