{
  "$type": "site.standard.document",
  "bskyPostRef": {
    "cid": "bafyreigdfypdqx5gxwiwotq5mlyis4ykz7jspbcfu57nuuk24fvz5bubni",
    "uri": "at://did:plc:pgryn3ephfd2xgft23qokfzt/app.bsky.feed.post/3mgqn36x3px72"
  },
  "path": "/t/recommended-way-of-feeding-pre-computed-embeds-to-generate-of-vlms/174136#post_2",
  "publishedAt": "2026-03-10T23:26:53.000Z",
  "site": "https://discuss.huggingface.co",
  "tags": [
    "Hugging Face",
    "GitHub",
    "vLLM"
  ],
  "textContent": "There’s no smart method I can recommend…\nProbably using a backend that already has a cache, like vLLM, is relatively smart. Below, if using Transformers:\n\n* * *\n\n## Recommendation\n\nFor **Qwen3.5 in Transformers** , the recommended way is **not** to look for a public `image_embeds=` argument on `generate()`. The current public API for `Qwen3_5ForConditionalGeneration` exposes `pixel_values`, `image_grid_thw`, `mm_token_type_ids`, and the generic `inputs_embeds`, but **not** a first-class visual-embedding input. The documented example is still the normal `processor.apply_chat_template(...)` → `model.generate(**inputs)` path. (Hugging Face)\n\nSo the practical recommendation is:\n\n  1. **Precompute at the model’s visual-feature boundary** — the output of `get_image_features(...)` or its Qwen3.5 equivalent internal visual feature extraction boundary.\n  2. **Cache the visual features together with`image_grid_thw`** and the exact model/processor revision.\n  3. **Later, inject those cached visual features through a thin wrapper/subclass** that reproduces the stock multimodal fusion path, instead of trying to pass a bare image-embedding tensor directly to `generate()`. (Hugging Face)\n\n\n\n## Why this is the right abstraction\n\n`inputs_embeds` in the public forward signature is a **generic token-embedding escape hatch** : it means “I already have the full sequence embeddings for the model input.” It is not documented as “visual embeddings go here.” By contrast, the multimodal API explicitly documents `pixel_values`, `image_grid_thw`, and `mm_token_type_ids`, which tells you that Qwen3.5 expects a structured multimodal input contract, not just an opaque tensor. (Hugging Face)\n\nThat contract matters because `image_grid_thw` is part of how the model understands the image feature layout in the language model space. 
The Qwen3.5 docs describe `image_grid_thw` as the temporal, height, and width dimensions of each image’s feature shape in the LLM, and vLLM’s multimodal input docs say explicitly that `image_grid_thw` is needed to calculate positional encoding for Qwen-family image-embedding inputs. (Hugging Face)\n\n## What you should cache\n\nFor **Qwen3.5 specifically**, the safest cache object is:\n\n  * the **LM-ready visual feature tensor(s)** produced at the visual-feature boundary,\n  * `image_grid_thw`,\n  * model ID + exact revision,\n  * processor ID + exact revision,\n  * any processor settings that affect visual tokenization/resolution. (Hugging Face)\n\n\n\nThat is a better boundary than caching raw `pixel_values`, because `pixel_values` is still the input to the expensive visual path. It is also a better boundary than caching the final full-sequence `inputs_embeds`, because full-sequence embeddings are tied much more tightly to one exact prompt layout and one exact generation-preparation path. (Hugging Face)\n\n## Why I would not make raw `inputs_embeds` your main interface\n\nPeople do try this, but it is brittle. A Hugging Face issue on Qwen2-VL shows users manually constructing `inputs_embeds` for multimodal generation and hitting regressions across versions, and another issue shows problems with generic `inputs_embeds` plus cache/past-key-value generation. Those issues do not prove that the approach is impossible, but they are a strong signal that `inputs_embeds` is the **low-level plumbing**, not the most stable public abstraction to build around. (GitHub)\n\nThere is also an architectural reason to be cautious: recent Transformers release notes state that 3D position IDs for vision-language models were unified under a common interface, which means code that manually reconstructs multimodal positions is exactly the kind of code that can get broken by framework changes. 
(GitHub)\n\n## The best pattern inside Transformers\n\nInside Transformers, I would treat this as a **small adapter layer**:\n\n  * Build the prompt normally with `apply_chat_template(...)`.\n  * Keep `input_ids`, `attention_mask`, `mm_token_type_ids`, and `image_grid_thw`.\n  * Replace the model’s visual-feature computation step with your cached features.\n  * Let the rest of the stock multimodal forward/generation path continue unchanged. (Hugging Face)\n\n\n\nConceptually, the adapter looks like this:\n\n\n    cache = build_visual_cache(image)   # precompute once\n    out = generate_from_visual_cache(cache, prompt)   # reuse later (helper names are illustrative)\n\n\nInternally, that adapter may use `inputs_embeds`, but the caller should not have to think in terms of raw sequence embeddings.\n\n## The best pattern if you want first-class support\n\nIf you want a **public, supported API** for image embeddings rather than a local wrapper, **vLLM is ahead of Transformers here**. Its multimodal input docs explicitly support image embedding inputs and require the extra Qwen metadata, including `image_grid_thw` for positional encoding. For **Qwen3-VL**, vLLM goes further and states that `image_embeds` should contain both the base image embedding and DeepStack features. (vLLM)\n\nThat is a useful design clue even for Qwen3.5: the right object is usually **not** “a naked image embedding tensor,” but a **typed multimodal visual package plus required metadata**.\n\n## Important distinction: Qwen3.5 vs Qwen3-VL\n\nFor **Qwen3.5**, the public docs expose the multimodal forward inputs but do **not** document DeepStack outputs on the model page. For **Qwen3-VL**, the docs explicitly document `get_image_features(...)`, `deepstack_features`, and the multimodal forward signature. So:\n\n  * for **Qwen3.5**, cache the visual features at the visual-feature boundary plus `image_grid_thw`;\n  * for **Qwen3-VL**, cache **both** the base visual outputs and `deepstack_features`, plus `image_grid_thw`. 
(Hugging Face)\n\n\n\n## Clear answer to your question\n\nIf your question is:\n\n> “What is the recommended way of providing precomputed `image_embeds` to Qwen3.5?”\n\nThen the answer is:\n\n  * **There is no documented first-class `image_embeds=` entry point on `generate()` for Qwen3.5 in Transformers today.** (Hugging Face)\n  * **The recommended engineering pattern is to cache the model’s visual features plus `image_grid_thw`, then inject them through a thin wrapper/subclass that preserves the stock multimodal path.** (Hugging Face)\n  * **If you want a first-class embedding-input API instead of a wrapper, use a runtime that already exposes it, such as vLLM.** (vLLM)\n\n\n\nThe shortest way to think about it is:\n\n> Do not treat the cache as “image_embeds.pt”.\n> Treat it as a **reusable multimodal visual package** for a specific Qwen-family model revision. (Hugging Face)\n\n* * *\n\n## **Qwen3.5-specific** template\n\nThis template hard-codes the patch point as **`model.model.get_image_features`**, because the current Qwen3.5 conditional-generation stack routes multimodal feature extraction through the inner `Qwen3_5Model`, not the outer wrapper. The public Qwen3.5 docs also show that the supported multimodal inputs are `pixel_values`, `image_grid_thw`, and `mm_token_type_ids`, not a first-class `image_embeds=` argument. Recent Transformers releases also changed VLM 3D position handling, so keeping the stock path intact is the safer pattern. (Hugging Face)\n\nThe code below validates the intended approach in a strict way:\n\n  * run a baseline from the real image,\n  * cache the output of `get_image_features(...)`,\n  * run the cached path with the **inner** feature method patched,\n  * trap the real vision tower so the run fails if visual recomputation happens. 
(Hugging Face)\n\n\n\n\n    # Compact Qwen3.5 best-practice template:\n    # precompute visual features once, cache them, then reuse them later for generation.\n    #\n    # Why this version avoids crash:\n    # - For Qwen3.5, patch model.model.get_image_features(...), not the outer model.\n    # - The cached run also traps model.model.visual.forward(...), so it will fail\n    #   if the real vision tower is called by mistake.\n    #\n    # References:\n    # - Qwen3.5 docs:\n    #   https://huggingface.co/docs/transformers/model_doc/qwen3_5\n    # - Transformers releases (3D position-id / generation-path changes):\n    #   https://github.com/huggingface/transformers/releases\n    # - Sample image:\n    #   https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\n    #\n    # deps:\n    #   pip install -U \"torch>=2.3\" \"transformers>=5.3.0\" accelerate pillow\n    #\n    # Notes:\n    # - CUDA: prefers bfloat16 if supported, else float16\n    # - CPU: uses float32\n    # - No argparse\n    # - Low-memory friendly: uses a small public checkpoint and caps image pixels if supported\n    # - This validates the pattern; it is not an official image_embeds= API\n\n    import gc\n    import types\n    from pathlib import Path\n\n    import torch\n    from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration\n\n\n    # ----------------------------\n    # User settings\n    # ----------------------------\n    MODEL_ID = \"Qwen/Qwen3.5-0.8B\"\n    IMAGE_URL = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\"\n    CACHE_PATH = Path(\"qwen35_visual_cache.pt\")\n\n    PROMPT = \"Describe the image clearly in 2 short sentences.\"\n    MAX_NEW_TOKENS = 64\n\n\n    # ----------------------------\n    # Helpers\n    # ----------------------------\n    class AttrDict(dict):\n        __getattr__ = dict.get\n        __setattr__ = dict.__setitem__\n\n\n    def 
pick_device_and_dtype():\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n        else:\n            device = torch.device(\"cpu\")\n            dtype = torch.float32\n        return device, dtype\n\n\n    def maybe_make_processor(model_id: str):\n        # Reduce visual tokens on small RAM / VRAM setups if supported by this processor version.\n        try:\n            return AutoProcessor.from_pretrained(\n                model_id,\n                min_pixels=256 * 28 * 28,\n                max_pixels=512 * 28 * 28,\n            )\n        except TypeError:\n            return AutoProcessor.from_pretrained(model_id)\n\n\n    def move_batch_to_device(batch, device):\n        out = {}\n        for k, v in batch.items():\n            out[k] = v.to(device) if torch.is_tensor(v) else v\n        return out\n\n\n    def cpu_clone_tree(x):\n        if x is None:\n            return None\n        if torch.is_tensor(x):\n            return x.detach().cpu().contiguous()\n        if isinstance(x, list):\n            return [cpu_clone_tree(v) for v in x]\n        if isinstance(x, tuple):\n            return tuple(cpu_clone_tree(v) for v in x)\n        if isinstance(x, dict):\n            return {k: cpu_clone_tree(v) for k, v in x.items()}\n        return x\n\n\n    def runtime_cast_tree(x, device, float_dtype):\n        if x is None:\n            return None\n        if torch.is_tensor(x):\n            y = x.to(device)\n            if torch.is_floating_point(y):\n                y = y.to(float_dtype)\n            return y\n        if isinstance(x, list):\n            return [runtime_cast_tree(v, device, float_dtype) for v in x]\n        if isinstance(x, tuple):\n            return tuple(runtime_cast_tree(v, device, float_dtype) for v in x)\n        if isinstance(x, dict):\n            return {k: runtime_cast_tree(v, device, float_dtype) for k, v in 
x.items()}\n        return x\n\n\n    def total_nbytes(x):\n        if x is None:\n            return 0\n        if torch.is_tensor(x):\n            return x.numel() * x.element_size()\n        if isinstance(x, (list, tuple)):\n            return sum(total_nbytes(v) for v in x)\n        if isinstance(x, dict):\n            return sum(total_nbytes(v) for v in x.values())\n        return 0\n\n\n    def format_bytes(n):\n        units = [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"]\n        n = float(n)\n        i = 0\n        while n >= 1024 and i < len(units) - 1:\n            n /= 1024.0\n            i += 1\n        return f\"{n:.2f} {units[i]}\"\n\n\n    def build_messages(prompt_text, image_url):\n        return [\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image\", \"image\": image_url},\n                    {\"type\": \"text\", \"text\": prompt_text},\n                ],\n            }\n        ]\n\n\n    def build_inputs(processor, messages):\n        batch = processor.apply_chat_template(\n            messages,\n            tokenize=True,\n            add_generation_prompt=True,\n            return_dict=True,\n            return_tensors=\"pt\",\n        )\n        batch.pop(\"token_type_ids\", None)\n        return batch\n\n\n    def decode_new_tokens(processor, prompt_input_ids, generated_ids):\n        trimmed = [\n            out_ids[len(in_ids):]\n            for in_ids, out_ids in zip(prompt_input_ids, generated_ids)\n        ]\n        return processor.batch_decode(\n            trimmed,\n            skip_special_tokens=True,\n            clean_up_tokenization_spaces=False,\n        )[0].strip()\n\n\n    def run_generate(model, processor, inputs, label):\n        with torch.inference_mode():\n            generated_ids = model.generate(\n                **inputs,\n                max_new_tokens=MAX_NEW_TOKENS,\n                do_sample=False,\n                use_cache=True,\n           
 )\n        text = decode_new_tokens(processor, inputs[\"input_ids\"], generated_ids)\n        print(f\"\\n[{label}]\")\n        print(text)\n        return text\n\n\n    # ----------------------------\n    # Qwen3.5-specific cache boundary\n    # ----------------------------\n    def make_visual_cache(model, processor_inputs, cache_path: Path):\n        \"\"\"\n        Qwen3.5-specific:\n        the multimodal feature extraction lives on model.model.get_image_features(...).\n        \"\"\"\n        owner = model.model\n\n        with torch.inference_mode():\n            image_outputs = owner.get_image_features(\n                pixel_values=processor_inputs[\"pixel_values\"],\n                image_grid_thw=processor_inputs[\"image_grid_thw\"],\n                return_dict=True,\n            )\n\n        cache = {\n            \"model_id\": MODEL_ID,\n            \"image_grid_thw\": cpu_clone_tree(processor_inputs[\"image_grid_thw\"]),\n            \"prompt_skeleton\": {\n                \"input_ids\": cpu_clone_tree(processor_inputs[\"input_ids\"]),\n                \"attention_mask\": cpu_clone_tree(processor_inputs[\"attention_mask\"]),\n                \"mm_token_type_ids\": cpu_clone_tree(processor_inputs.get(\"mm_token_type_ids\")),\n            },\n            # Keep the full visual output object as a plain dict.\n            \"visual_outputs\": cpu_clone_tree(dict(image_outputs)),\n        }\n\n        torch.save(cache, cache_path)\n\n        print(\"\\n[cache stats]\")\n        print(\"pixel_values bytes :\", format_bytes(total_nbytes(processor_inputs[\"pixel_values\"])))\n        print(\"cached visual bytes:\", format_bytes(total_nbytes(cache[\"visual_outputs\"])))\n        print(\"cache file         :\", str(cache_path.resolve()))\n\n        return cache\n\n\n    # ----------------------------\n    # Qwen3.5-specific patching\n    # ----------------------------\n    def install_cached_visual_patch(model, cache, device, float_dtype):\n        \"\"\"\n    
    Patch the INNER Qwen3.5 model, not the outer generation wrapper.\n        Also trap the real vision tower to prove cached reuse is actually happening.\n        \"\"\"\n        owner = model.model  # <-- this is the important fix\n        original_get_image_features = owner.get_image_features\n        original_visual_forward = owner.visual.forward\n\n        patch_state = {\n            \"patched_calls\": 0,\n            \"real_visual_calls\": 0,\n        }\n\n        def patched_get_image_features(self, pixel_values=None, image_grid_thw=None, **kwargs):\n            patch_state[\"patched_calls\"] += 1\n            return AttrDict(runtime_cast_tree(cache[\"visual_outputs\"], device, float_dtype))\n\n        def trapped_visual_forward(self, *args, **kwargs):\n            patch_state[\"real_visual_calls\"] += 1\n            raise RuntimeError(\n                \"Real Qwen3.5 vision tower was called during cached run. \"\n                \"The cached path did not bypass visual recomputation.\"\n            )\n\n        owner.get_image_features = types.MethodType(patched_get_image_features, owner)\n        owner.visual.forward = types.MethodType(trapped_visual_forward, owner.visual)\n        return owner, original_get_image_features, original_visual_forward, patch_state\n\n\n    def restore_cached_visual_patch(owner, original_get_image_features, original_visual_forward):\n        owner.get_image_features = original_get_image_features\n        owner.visual.forward = original_visual_forward\n\n\n    def make_cached_inputs(cache, device, float_dtype):\n        \"\"\"\n        Keep the multimodal branch active with a tiny non-empty sentinel pixel_values tensor.\n        The patched model.model.get_image_features(...) 
ignores it completely.\n        \"\"\"\n        prompt = cache[\"prompt_skeleton\"]\n\n        out = {\n            \"input_ids\": prompt[\"input_ids\"].to(device),\n            \"attention_mask\": prompt[\"attention_mask\"].to(device),\n            \"image_grid_thw\": cache[\"image_grid_thw\"].to(device),\n            \"pixel_values\": torch.zeros((1,), device=device, dtype=float_dtype),\n        }\n\n        if prompt.get(\"mm_token_type_ids\") is not None:\n            out[\"mm_token_type_ids\"] = prompt[\"mm_token_type_ids\"].to(device)\n\n        return out\n\n\n    # ----------------------------\n    # Main\n    # ----------------------------\n    def main():\n        device, dtype = pick_device_and_dtype()\n\n        print(\"[runtime]\")\n        print(\"device:\", device)\n        print(\"dtype :\", dtype)\n\n        processor = maybe_make_processor(MODEL_ID)\n        model = Qwen3_5ForConditionalGeneration.from_pretrained(\n            MODEL_ID,\n            torch_dtype=dtype,\n            low_cpu_mem_usage=True,\n            attn_implementation=\"sdpa\",\n        ).to(device)\n        model.eval()\n\n        # 1) Baseline run\n        messages = build_messages(PROMPT, IMAGE_URL)\n        inputs = move_batch_to_device(build_inputs(processor, messages), device)\n\n        print(\"\\n[input keys]\")\n        print(sorted(inputs.keys()))\n        print(\"image_grid_thw:\", inputs[\"image_grid_thw\"].tolist())\n\n        baseline_text = run_generate(model, processor, inputs, \"baseline / normal image path\")\n\n        # 2) Build cache at the visual-feature boundary\n        _ = make_visual_cache(model, inputs, CACHE_PATH)\n\n        # Simulate \"later\"\n        if \"pixel_values\" in inputs:\n            del inputs[\"pixel_values\"]\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n        # 3) Cached run\n        cache = torch.load(CACHE_PATH, map_location=\"cpu\")\n        if cache[\"model_id\"] != 
MODEL_ID:\n            raise ValueError(f\"Cache was created for {cache['model_id']}, current model is {MODEL_ID}.\")\n\n        owner, original_get_image_features, original_visual_forward, patch_state = install_cached_visual_patch(\n            model=model,\n            cache=cache,\n            device=device,\n            float_dtype=dtype,\n        )\n\n        try:\n            cached_inputs = make_cached_inputs(cache, device, dtype)\n            cached_text = run_generate(model, processor, cached_inputs, \"cached visual path\")\n        finally:\n            restore_cached_visual_patch(owner, original_get_image_features, original_visual_forward)\n\n        # 4) Validation\n        print(\"\\n[validation]\")\n        print(\"patched get_image_features calls:\", patch_state[\"patched_calls\"])\n        print(\"real visual.forward calls      :\", patch_state[\"real_visual_calls\"])\n        print(\"baseline == cached             :\", baseline_text == cached_text)\n        print(\"cache file                     :\", str(CACHE_PATH.resolve()))\n\n        if patch_state[\"patched_calls\"] < 1:\n            raise RuntimeError(\"Patched model.model.get_image_features was never called.\")\n\n        if patch_state[\"real_visual_calls\"] != 0:\n            raise RuntimeError(\"The real Qwen3.5 vision tower ran during the cached path.\")\n\n        print(\"\\n[result]\")\n        if baseline_text == cached_text:\n            print(\"Success: cached visual features reproduced the same output.\")\n        else:\n            print(\"Cached path succeeded and bypassed visual recomputation.\")\n            print(\"Text differs from baseline, which can still happen across kernels/dtypes/versions.\")\n\n\n    if __name__ == \"__main__\":\n        main()\n",
  "title": "Recommended way of feeding pre-computed embeds to generate of VLMs"
}