added camera motion to gradio (#135)

This commit is contained in:
Frank Lee 2024-06-13 18:28:20 +08:00 committed by GitHub
parent e94d3bfdef
commit 98223cf899

View file

@ -197,7 +197,7 @@ device = torch.device("cuda")
vae, text_encoder, stdit, scheduler = build_models(args.model_type, config, enable_optimization=args.enable_optimization)
def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, use_timestep_transform, reference_image, seed, sampling_steps, cfg_scale):
def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, seed, sampling_steps, cfg_scale):
torch.manual_seed(seed)
with torch.inference_mode():
# ======================
@ -267,6 +267,11 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st
# process scores
use_motion_strength = use_motion_strength and mode != "Text2Image"
if camera_motion != "none":
batch_prompts = [
f"{prompt} camera motion: {camera_motion}."
for prompt in batch_prompts
]
batch_prompts = append_score_to_prompts(
batch_prompts,
aes=aesthetic_score if use_aesthetic_score else None,
@ -302,7 +307,6 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st
scheduler_kwargs.pop('type')
scheduler_kwargs['num_sampling_steps'] = sampling_steps
scheduler_kwargs['cfg_scale'] = cfg_scale
scheduler_kwargs['use_timestep_transform'] = use_timestep_transform
scheduler.__init__(
**scheduler_kwargs
@ -344,7 +348,7 @@ def run_image_inference(
aesthetic_score,
use_motion_strength,
use_aesthetic_score,
use_timestep_transform,
camera_motion,
reference_image,
seed,
sampling_steps,
@ -359,7 +363,7 @@ def run_image_inference(
aesthetic_score,
use_motion_strength,
use_aesthetic_score,
use_timestep_transform,
camera_motion,
reference_image,
seed,
sampling_steps,
@ -375,7 +379,7 @@ def run_video_inference(
aesthetic_score,
use_motion_strength,
use_aesthetic_score,
use_timestep_transform,
camera_motion,
reference_image,
seed,
sampling_steps,
@ -394,7 +398,7 @@ def run_video_inference(
aesthetic_score,
use_motion_strength,
use_aesthetic_score,
use_timestep_transform,
camera_motion,
reference_image,
seed,
sampling_steps,
@ -498,7 +502,21 @@ def main():
)
use_aesthetic_score = gr.Checkbox(value=True, label="Enable")
use_timestep_transform = gr.Checkbox(value=True, label="Use Time Transform")
camera_motion = gr.Radio(
value="none",
label="Camera Motion",
choices=[
"none",
"pan right",
"pan left",
"tilt up",
"tilt down",
"zoom in",
"zoom out",
"static"
],
interactive=True
)
reference_image = gr.Image(
@ -519,12 +537,12 @@ def main():
image_gen_button.click(
fn=run_image_inference,
inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, use_timestep_transform, reference_image, seed, sampling_steps, cfg_scale],
inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, seed, sampling_steps, cfg_scale],
outputs=reference_image
)
video_gen_button.click(
fn=run_video_inference,
inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, use_timestep_transform, reference_image, seed, sampling_steps, cfg_scale],
inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, seed, sampling_steps, cfg_scale],
outputs=output_video
)