Text to Image

Source: https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/text_to_image

This folder provides several entry points for experimenting with Qwen/Qwen-Image and Tongyi-MAI/Z-Image-Turbo using vLLM-Omni.

  • text_to_image.py: command-line script for single-image generation and advanced options.
  • gradio_demo.py: lightweight Gradio UI for interactive prompt/seed/CFG exploration.

Basic Usage

from vllm_omni.entrypoints.omni import Omni

if __name__ == "__main__":
    omni = Omni(model="Qwen/Qwen-Image")
    prompt = "a cup of coffee on the table"
    images = omni.generate(prompt)
    images[0].save("coffee.png")
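
The same omni.generate call also accepts per-request generation options. The sketch below mirrors the keyword arguments that text_to_image.py (included later on this page) passes (height, width, generator, true_cfg_scale, num_inference_steps, num_outputs_per_prompt); treat it as an illustration rather than an exhaustive API reference.

import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type

if __name__ == "__main__":
    omni = Omni(model="Qwen/Qwen-Image")
    # Seed a generator on the detected device for reproducible sampling,
    # as text_to_image.py does.
    generator = torch.Generator(device=detect_device_type()).manual_seed(42)
    images = omni.generate(
        "a cup of coffee on the table",
        height=1328,
        width=1328,
        generator=generator,
        true_cfg_scale=4.0,
        num_inference_steps=50,
        num_outputs_per_prompt=2,
    )
    for idx, img in enumerate(images):
        img.save(f"coffee_{idx}.png")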

Local CLI Usage

python text_to_image.py \
  --model Tongyi-MAI/Z-Image-Turbo \
  --prompt "a cup of coffee on the table" \
  --seed 42 \
  --cfg_scale 4.0 \
  --num_images_per_prompt 1 \
  --num_inference_steps 50 \
  --height 1024 \
  --width 1024 \
  --output outputs/coffee.png

Key Parameters

  • --prompt: text description (string).
  • --seed: integer seed for deterministic sampling.
  • --cfg_scale: true classifier-free guidance (CFG) scale (model-specific guidance strength).
  • --num_images_per_prompt: number of images generated per prompt (when greater than 1, images are saved with indexed filenames such as output_0, output_1, and so on).
  • --num_inference_steps: number of diffusion sampling steps (more steps = higher quality, slower).
  • --height/--width: output resolution (defaults to 1024x1024).
  • --output: path where the generated PNG is saved.
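
Beyond the flags above, text_to_image.py also exposes --cache_backend (cache_dit or tea_cache) for cache acceleration and --ulysses_degree for multi-GPU sequence parallelism; both map to Omni constructor arguments, as its source below shows. The following is a minimal sketch of the equivalent Python-API setup; the threshold and degree values are only example settings.

from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.entrypoints.omni import Omni

if __name__ == "__main__":
    # TeaCache threshold taken from the example config in text_to_image.py;
    # ulysses_degree=2 assumes two GPUs are available for sequence parallelism.
    omni = Omni(
        model="Qwen/Qwen-Image",
        cache_backend="tea_cache",
        cache_config={"rel_l1_thresh": 0.2},
        parallel_config=DiffusionParallelConfig(ulysses_degree=2),
    )
    images = omni.generate("a cup of coffee on the table")
    images[0].save("coffee.png")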

ℹ️ Qwen-Image currently performs best with the preset resolutions 1328x1328, 1664x928, 928x1664, 1472x1140, 1140x1472, 1584x1056, and 1056x1584. Adjust --height/--width accordingly for the most reliable results.
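
To pick one of these presets programmatically, you can reuse the aspect-ratio table maintained by gradio_demo.py below; the dictionary literal in this sketch is copied from that script.

# Preset (width, height) pairs keyed by aspect ratio, as in gradio_demo.py.
ASPECT_RATIOS: dict[str, tuple[int, int]] = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472),
    "3:2": (1584, 1056),
    "2:3": (1056, 1584),
}

width, height = ASPECT_RATIOS["16:9"]  # 1664x928
# Pass these as --width/--height on the CLI or as width=/height= to omni.generate().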

Web UI Demo

Launch the Gradio demo

python gradio_demo.py --port 7862

Then open http://127.0.0.1:7862 (or the --ip/--port you passed to the script) in your local browser to interact with the Web UI.

Example Materials

gradio_demo.py
import argparse
from functools import lru_cache

import gradio as gr
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type, is_npu

ASPECT_RATIOS: dict[str, tuple[int, int]] = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472),
    "3:2": (1584, 1056),
    "2:3": (1056, 1584),
}
ASPECT_RATIO_CHOICES = [f"{ratio} ({w}x{h})" for ratio, (w, h) in ASPECT_RATIOS.items()]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Gradio demo for Qwen-Image offline inference.")
    parser.add_argument("--model", default="Qwen/Qwen-Image", help="Diffusion model name or local path.")
    parser.add_argument(
        "--height",
        type=int,
        default=1328,
        help="Default image height (must match one of the supported presets).",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1328,
        help="Default image width (must match one of the supported presets).",
    )
    parser.add_argument("--default-prompt", default="a cup of coffee on the table", help="Initial prompt shown in UI.")
    parser.add_argument("--default-seed", type=int, default=42, help="Initial seed shown in UI.")
    parser.add_argument("--default-cfg-scale", type=float, default=4.0, help="Initial CFG scale shown in UI.")
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=50,
        help="Default number of denoising steps shown in the UI.",
    )
    parser.add_argument("--ip", default="127.0.0.1", help="Host/IP for Gradio `launch`.")
    parser.add_argument("--port", type=int, default=7862, help="Port for Gradio `launch`.")
    parser.add_argument("--share", action="store_true", help="Share the Gradio demo publicly.")
    args = parser.parse_args()
    args.aspect_ratio_label = next(
        (ratio for ratio, dims in ASPECT_RATIOS.items() if dims == (args.width, args.height)),
        None,
    )
    if args.aspect_ratio_label is None:
        supported = ", ".join(f"{ratio} ({w}x{h})" for ratio, (w, h) in ASPECT_RATIOS.items())
        parser.error(f"Unsupported resolution {args.width}x{args.height}. Please pick one of: {supported}.")
    return args


@lru_cache(maxsize=1)
def get_omni(model_name: str) -> Omni:
    # Enable VAE memory optimizations on NPU
    vae_use_slicing = is_npu()
    vae_use_tiling = is_npu()
    return Omni(
        model=model_name,
        vae_use_slicing=vae_use_slicing,
        vae_use_tiling=vae_use_tiling,
    )


def build_demo(args: argparse.Namespace) -> gr.Blocks:
    device = detect_device_type()
    omni = get_omni(args.model)

    def run_inference(
        prompt: str,
        seed_value: float,
        cfg_scale_value: float,
        resolution_choice: str,
        num_steps_value: float,
        num_images_choice: float,
    ):
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a non-empty prompt.")
        ratio_label = resolution_choice.split(" ", 1)[0]
        if ratio_label not in ASPECT_RATIOS:
            raise gr.Error(f"Unsupported aspect ratio: {ratio_label}")
        width, height = ASPECT_RATIOS[ratio_label]
        try:
            seed = int(seed_value)
            num_steps = int(num_steps_value)
            num_images = int(num_images_choice)
        except (TypeError, ValueError) as exc:
            raise gr.Error("Seed, inference steps, and number of images must be valid integers.") from exc
        if num_steps <= 0:
            raise gr.Error("Inference steps must be a positive integer.")
        if num_images not in {1, 2, 3, 4}:
            raise gr.Error("Number of images must be 1, 2, 3, or 4.")
        generator = torch.Generator(device=device).manual_seed(seed)
        images = omni.generate(
            prompt.strip(),
            height=height,
            width=width,
            generator=generator,
            true_cfg_scale=float(cfg_scale_value),
            num_inference_steps=num_steps,
            num_outputs_per_prompt=num_images,
        )
        return [img for img in images[:num_images]]

    with gr.Blocks(
        title="vLLM-Omni Web Serving Demo",
        css="""
        /* Left column button width */
        .left-column button {
            width: 100%;
        }
        /* Right preview area: fixed height, hide unnecessary buttons */
        .fixed-image {
            height: 660px;
            display: flex;
            flex-direction: column;
            justify-content: center;
            align-items: center;
        }
        .fixed-image .duplicate-button,
        .fixed-image .svelte-drgfj2 {
            display: none !important;
        }
        /* Gallery container: fill available space and center content */
        #image-gallery {
            width: 100%;
            height: 100%;
            display: flex;
            align-items: center;
            justify-content: center;
        }
        /* Gallery grid: center horizontally and vertically, set gap */
        #image-gallery .grid {
            display: flex;
            flex-wrap: wrap;
            justify-content: center;
            align-items: center;
            align-content: center;
            gap: 16px;
            width: 100%;
            height: 100%;
        }
        /* Gallery grid items: center content */
        #image-gallery .grid > div {
            display: flex;
            align-items: center;
            justify-content: center;
        }
        /* Gallery images: limit max height, maintain aspect ratio */
        .fixed-image img {
            max-height: 660px !important;
            width: auto !important;
            object-fit: contain;
        }
        """,
    ) as demo:
        gr.Markdown("# vLLM-Omni Web Serving Demo")
        gr.Markdown(f"**Model:** {args.model}")

        with gr.Row():
            with gr.Column(scale=1, elem_classes="left-column"):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    value=args.default_prompt,
                    placeholder="Describe the image you want to generate...",
                    lines=5,
                )
                seed_input = gr.Number(label="Seed", value=args.default_seed, precision=0)
                cfg_input = gr.Number(label="CFG Scale", value=args.default_cfg_scale)
                steps_input = gr.Number(
                    label="Inference Steps",
                    value=args.num_inference_steps,
                    precision=0,
                    minimum=1,
                )
                aspect_dropdown = gr.Dropdown(
                    label="Aspect Ratio (W:H)",
                    choices=ASPECT_RATIO_CHOICES,
                    value=f"{args.aspect_ratio_label} ({ASPECT_RATIOS[args.aspect_ratio_label][0]}x{ASPECT_RATIOS[args.aspect_ratio_label][1]})",
                )
                num_images = gr.Dropdown(
                    label="Number of images",
                    choices=["1", "2", "3", "4"],
                    value="1",
                )
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=2, elem_classes="fixed-image"):
                gallery = gr.Gallery(
                    label="Preview",
                    columns=2,
                    rows=2,
                    height=660,
                    allow_preview=True,
                    show_label=True,
                    elem_id="image-gallery",
                )

        generate_btn.click(
            fn=run_inference,
            inputs=[prompt_input, seed_input, cfg_input, aspect_dropdown, steps_input, num_images],
            outputs=gallery,
        )

    return demo


def main():
    args = parse_args()
    demo = build_demo(args)
    demo.launch(server_name=args.ip, server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()
text_to_image.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import time
from pathlib import Path

import torch

from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type, is_npu


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Generate an image with Qwen-Image.")
    parser.add_argument(
        "--model",
        default="Qwen/Qwen-Image",
        help="Diffusion model name or local path. Supported models: Qwen/Qwen-Image, Tongyi-MAI/Z-Image-Turbo",
    )
    parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.")
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=4.0,
        help="True classifier-free guidance scale specific to Qwen-Image.",
    )
    parser.add_argument("--height", type=int, default=1024, help="Height of generated image.")
    parser.add_argument("--width", type=int, default=1024, help="Width of generated image.")
    parser.add_argument(
        "--output",
        type=str,
        default="qwen_image_output.png",
        help="Path to save the generated image (PNG).",
    )
    parser.add_argument(
        "--num_images_per_prompt",
        type=int,
        default=1,
        help="Number of images to generate for the given prompt.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=50,
        help="Number of denoising steps for the diffusion sampler.",
    )
    parser.add_argument(
        "--cache_backend",
        type=str,
        default=None,
        choices=["cache_dit", "tea_cache"],
        help=(
            "Cache backend to use for acceleration. "
            "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). "
            "Default: None (no cache acceleration)."
        ),
    )
    parser.add_argument(
        "--ulysses_degree",
        type=int,
        default=1,
        help="Number of GPUs used for ulysses sequence parallelism.",
    )

    return parser.parse_args()


def main():
    args = parse_args()
    device = detect_device_type()
    generator = torch.Generator(device=device).manual_seed(args.seed)

    # Enable VAE memory optimizations on NPU
    vae_use_slicing = is_npu()
    vae_use_tiling = is_npu()

    # Configure cache based on backend type
    cache_config = None
    if args.cache_backend == "cache_dit":
        # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer
        # All parameters marked with [cache-dit only] in DiffusionCacheConfig
        cache_config = {
            # DBCache parameters [cache-dit only]
            "Fn_compute_blocks": 1,  # Optimized for single-transformer models
            "Bn_compute_blocks": 0,  # Number of backward compute blocks
            "max_warmup_steps": 4,  # Maximum warmup steps (works for few-step models)
            "residual_diff_threshold": 0.24,  # Higher threshold for more aggressive caching
            "max_continuous_cached_steps": 3,  # Limit to prevent precision degradation
            # TaylorSeer parameters [cache-dit only]
            "enable_taylorseer": False,  # Disabled by default (not suitable for few-step models)
            "taylorseer_order": 1,  # TaylorSeer polynomial order
            # SCM (Step Computation Masking) parameters [cache-dit only]
            "scm_steps_mask_policy": None,  # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra"
            "scm_steps_policy": "dynamic",  # SCM steps policy: "dynamic" or "static"
        }
    elif args.cache_backend == "tea_cache":
        # TeaCache configuration
        # All parameters marked with [tea_cache only] in DiffusionCacheConfig
        cache_config = {
            # TeaCache parameters [tea_cache only]
            "rel_l1_thresh": 0.2,  # Threshold for accumulated relative L1 distance
            # Note: coefficients will use model-specific defaults based on model_type
            #       (e.g., QwenImagePipeline or FluxPipeline)
        }

    parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree)
    omni = Omni(
        model=args.model,
        vae_use_slicing=vae_use_slicing,
        vae_use_tiling=vae_use_tiling,
        cache_backend=args.cache_backend,
        cache_config=cache_config,
        parallel_config=parallel_config,
    )

    # Time profiling for generation
    print(f"\n{'=' * 60}")
    print("Generation Configuration:")
    print(f"  Model: {args.model}")
    print(f"  Inference steps: {args.num_inference_steps}")
    print(f"  Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}")
    print(f"  Parallel configuration: ulysses_degree={args.ulysses_degree}")
    print(f"  Image size: {args.width}x{args.height}")
    print(f"{'=' * 60}\n")

    generation_start = time.perf_counter()
    images = omni.generate(
        args.prompt,
        height=args.height,
        width=args.width,
        generator=generator,
        true_cfg_scale=args.cfg_scale,
        num_inference_steps=args.num_inference_steps,
        num_outputs_per_prompt=args.num_images_per_prompt,
    )
    generation_end = time.perf_counter()
    generation_time = generation_end - generation_start

    # Print profiling results
    print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    suffix = output_path.suffix or ".png"
    stem = output_path.stem or "qwen_image_output"
    if args.num_images_per_prompt <= 1:
        images[0].save(output_path)
        print(f"Saved generated image to {output_path}")
    else:
        for idx, img in enumerate(images):
            save_path = output_path.parent / f"{stem}_{idx}{suffix}"
            img.save(save_path)
            print(f"Saved generated image to {save_path}")


if __name__ == "__main__":
    main()