跳转至

模型接入指南

本文档给出把一个 Diffusers 风格新模型接入 omnirt 的最短路径。

1. 新增 pipeline 类

src/omnirt/models/<model_name>/ 下创建 pipeline.py,继承 BasePipeline 并实现五阶段:

from omnirt.core.base_pipeline import BasePipeline
from omnirt.core.registry import ModelCapabilities, register_model


@register_model(
    id="my-model",
    task="text2image",
    default_backend="auto",
    resource_hint={"min_vram_gb": 12, "dtype": "fp16"},
    capabilities=ModelCapabilities(
        required_inputs=("prompt",),
        optional_inputs=("negative_prompt",),
        supported_config=(
            "model_path", "scheduler", "height", "width",
            "num_inference_steps", "guidance_scale", "seed", "dtype", "output_dir",
        ),
        default_config={
            "scheduler": "euler-discrete",
            "height": 1024, "width": 1024, "dtype": "fp16",
        },
        supported_schedulers=("euler-discrete", "ddim", "dpm-solver", "euler-ancestral"),
        adapter_kinds=("lora",),
        artifact_kind="image",
        maturity="experimental",   # experimental | beta | stable
        summary="One-sentence description that shows up in `omnirt models`.",
        example='omnirt generate --task text2image --model my-model --prompt "..." --backend cuda',
    ),
)
class MyPipeline(BasePipeline):
    def prepare_conditions(self, req): ...
    def prepare_latents(self, req, conditions): ...
    def denoise_loop(self, latents, conditions, config): ...
    def decode(self, latents): ...
    def export(self, raw, req): ...

ModelCapabilities 字段速览

capabilities 字段把模型的可用配置暴露给 omnirt modelsomnirt validate 和 Python API:

字段 用途
required_inputs inputs 里必须出现的 key;缺失时 validate() 直接报错
optional_inputs inputs 里允许但非必须的 key
supported_config config 里允许的 key;白名单外的值会被标记为 warning
default_config 用户未显式指定时的默认值(height/width/dtype 等)
supported_schedulers 本 pipeline 真正测过的 scheduler id;--scheduler 超出范围时报警
adapter_kinds 支持的 adapter 种类(目前仅 "lora"
artifact_kind "image" / "video";决定导出器选择
maturity experimental / beta / stable;CLI models 列表会显示
summary / example omnirt models 和文档用的简介和示例命令

完整定义见 src/omnirt/core/registry.py

2. 复用后端契约

通过 self.runtime.wrap_module(submodule, tag="unet") 包装对后端敏感的子模块。这样新模型无需写专属后端代码,就能拿到 compile / override / eager 三级回退,并在 RunReport.backend_timeline 里有每步的记录。

3. 注册 alias(可选)

同一个 pipeline 可以同时暴露多个 id。把别名版本用 alias_of 标出,它会在 omnirt models --format markdown 里被放进独立的「Aliases」表,不会污染主清单:

@register_model(
    id="flux2.dev",
    task="text2image",
    capabilities=ModelCapabilities(summary="Flux 2 dev text-to-image pipeline."),
)
@register_model(
    id="flux2-dev",
    task="text2image",
    capabilities=ModelCapabilities(
        summary="Flux 2 dev text-to-image pipeline.",
        alias_of="flux2.dev",
    ),
)
class Flux2Pipeline(BasePipeline):
    ...

装饰器可以堆叠,每次调用都会把一份条目写进 registry;get_model("flux2-dev")get_model("flux2.dev") 解析到同一个类。

4. 规格与约定

  • 权重:仅加载 safetensors
  • Adapters:在 __init__ 校验路径存在,在 prepare_conditions 阶段真正 materialize pipeline 后应用一次;单文件 adapter 既可以来自本地,也可以来自 Hugging Face
  • 产物:export 返回带完整文件路径的 Artifact,CLI 依赖 Artifact.path 回显

5. 把模型接入注册入口

src/omnirt/models/__init__.py_BUILTIN_MODEL_IDS 里加上你的 id,并在 ensure_registered() 里导入你的 pipeline 模块。这样 omnirt models 启动时会保证你的 pipeline 被加载。

6. 补测试

  • 单测:用 tests/unit/test_sd15_pipeline.py 作为模板,靠 fake runtime + fake Diffusers 对象跑完 BasePipeline.run()
  • Integration:tests/integration/test_<model>_{cuda,ascend}.py,依赖硬件的 case 在 conftest.py 里有自动 skip
  • Parity(可选):如果要参与跨后端验证,用 tests/parity/test_parity.py 的 latent 统计 / PSNR helper

7. 文档

  • 新增模型不需要手工改 docs/_generated/models.md —— 跑一遍 python scripts/generate_models_doc.py 即可,ModelCapabilities.summary 会自动出现在表里
  • 如果模型有特殊部署要求或已知限制,写进 支持状态 的「部分支持」或「已接入但仍待真机 smoke」小节

参考实现

src/omnirt/models/sd15/src/omnirt/models/sdxl/src/omnirt/models/svd/ 是当前覆盖面最完整的参考:分别给出 SD1.5、SDXL 和 SVD 的完整五阶段实现,并且同时展示了 fp16 variant 检测、LoRA 装配、视频帧导出等通用技巧。