Text-to-Video

Source: https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/text_to_video

The Wan-AI/Wan2.2-T2V-A14B-Diffusers pipeline generates short videos from text prompts.

Local CLI usage

python text_to_video.py \
  --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
  --negative_prompt "<optional quality filter>" \
  --height 480 \
  --width 640 \
  --num_frames 32 \
  --guidance_scale 4.0 \
  --guidance_scale_high 3.0 \
  --num_inference_steps 40 \
  --fps 16 \
  --output t2v_out.mp4

Key parameters

  • --prompt: text description of the video (string).
  • --height/--width: output resolution (defaults to 720x1280). Dimensions should be compatible with the Wan VAE downsampling factor (multiples of 8).
  • --num_frames: number of frames (Wan default is 81).
  • --guidance_scale / --guidance_scale_high: CFG scales, applied to the low-noise and high-noise stages respectively.
  • --negative_prompt: optional list of artifacts to suppress (the PR demo uses a long Chinese string).
  • --boundary_ratio: boundary split ratio between the low- and high-noise DiTs.
  • --fps: frames per second of the saved MP4 (requires diffusers export_to_video).
  • --output: path to save the generated video.
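
These flags map onto the Omni API that text_to_video.py (listed below) wraps. The following is a minimal programmatic sketch mirroring that script: the constructor arguments and the guidance_scale_2 keyword (the high-noise CFG behind --guidance_scale_high) are taken from the script itself rather than from a separately documented API, and the values are illustrative only.

import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type

# Values mirror the CLI example above and are illustrative only.
omni = Omni(
    model="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
    boundary_ratio=0.875,  # --boundary_ratio: low-/high-noise DiT split
    flow_shift=5.0,        # --flow_shift: 5.0 for 720p, 12.0 for 480p
)

frames = omni.generate(
    "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    negative_prompt="",
    height=480,
    width=640,
    num_frames=32,
    guidance_scale=4.0,    # --guidance_scale (low-noise CFG)
    guidance_scale_2=3.0,  # --guidance_scale_high (high-noise CFG)
    num_inference_steps=40,
    generator=torch.Generator(device=detect_device_type()).manual_seed(42),
)

# `frames` may come back as an np.ndarray or a torch.Tensor; see the
# export path in text_to_video.py below for turning it into an MP4.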

Example material

text_to_video.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
from pathlib import Path

import numpy as np
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type, is_npu


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Generate a video with Wan2.2 T2V.")
    parser.add_argument(
        "--model",
        default="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        help="Diffusers Wan2.2 model ID or local path.",
    )
    parser.add_argument("--prompt", default="A serene lakeside sunrise with mist over the water.", help="Text prompt.")
    parser.add_argument("--negative_prompt", default="", help="Negative prompt.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--guidance_scale", type=float, default=4.0, help="CFG scale (applied to low/high).")
    parser.add_argument("--guidance_scale_high", type=float, default=None, help="Optional separate CFG for high-noise.")
    parser.add_argument("--height", type=int, default=720, help="Video height.")
    parser.add_argument("--width", type=int, default=1280, help="Video width.")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames (Wan default is 81).")
    parser.add_argument("--num_inference_steps", type=int, default=40, help="Sampling steps.")
    parser.add_argument("--boundary_ratio", type=float, default=0.875, help="Boundary split ratio for low/high DiT.")
    parser.add_argument(
        "--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)."
    )
    parser.add_argument("--output", type=str, default="wan22_output.mp4", help="Path to save the video (mp4).")
    parser.add_argument("--fps", type=int, default=24, help="Frames per second for the output video.")
    return parser.parse_args()


def main():
    args = parse_args()
    device = detect_device_type()
    generator = torch.Generator(device=device).manual_seed(args.seed)

    # Enable VAE memory optimizations on NPU
    vae_use_slicing = is_npu()
    vae_use_tiling = is_npu()

    omni = Omni(
        model=args.model,
        vae_use_slicing=vae_use_slicing,
        vae_use_tiling=vae_use_tiling,
        boundary_ratio=args.boundary_ratio,
        flow_shift=args.flow_shift,
    )

    frames = omni.generate(
        args.prompt,
        negative_prompt=args.negative_prompt,
        height=args.height,
        width=args.width,
        generator=generator,
        guidance_scale=args.guidance_scale,
        guidance_scale_2=args.guidance_scale_high,
        num_inference_steps=args.num_inference_steps,
        num_frames=args.num_frames,
    )

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        from diffusers.utils import export_to_video
    except ImportError:
        raise ImportError("diffusers is required for export_to_video.")

    # frames may be np.ndarray (preferred) or torch.Tensor
    # export_to_video expects a list of frames with values in [0, 1]
    if isinstance(frames, torch.Tensor):
        video_tensor = frames.detach().cpu()
        if video_tensor.dim() == 5:
            # [B, C, F, H, W] or [B, F, H, W, C]
            if video_tensor.shape[1] in (3, 4):
                video_tensor = video_tensor[0].permute(1, 2, 3, 0)
            else:
                video_tensor = video_tensor[0]
        elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4):
            video_tensor = video_tensor.permute(1, 2, 3, 0)
        # If float, assume [-1,1] and normalize to [0,1]
        if video_tensor.is_floating_point():
            video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5
        video_array = video_tensor.float().numpy()
    else:
        video_array = frames
        if hasattr(video_array, "shape") and video_array.ndim == 5:
            video_array = video_array[0]

    # Convert 4D array (frames, H, W, C) to list of frames for export_to_video
    if isinstance(video_array, np.ndarray) and video_array.ndim == 4:
        video_array = list(video_array)

    export_to_video(video_array, str(output_path), fps=args.fps)
    print(f"Saved generated video to {output_path}")


if __name__ == "__main__":
    main()