{
  "$type": "site.standard.document",
  "bskyPostRef": {
    "cid": "bafyreigeprhybde37x4le6kktzt6rpk56q3gnann4pdn3m6ozlsvbxdygy",
    "uri": "at://did:plc:pgryn3ephfd2xgft23qokfzt/app.bsky.feed.post/3mjhi6nqna5q2"
  },
  "path": "/t/fine-tuning-gemma-4-e2b-on-macbook-m3/175228#post_4",
  "publishedAt": "2026-04-14T12:45:54.000Z",
  "site": "https://discuss.huggingface.co",
  "tags": [
    "GitHub",
    "Hugging Face",
    "Google AI for Developers"
  ],
  "textContent": "Oh… Since Transformers and TRL are currently undergoing fairly major renovations, there are just too many potential culprits in this case…\n\n* * *\n\nYour feedback changes the diagnosis a lot.\n\n## The main point\n\nYour problem now looks much less like “bad hyperparameters” and much more like **a mismatch between Gemma 4’s training input contract and TRL’s assistant-only training path**. The strongest clue is your batch: it only contains `input_ids`, `labels`, and `attention_mask`, while there is an active Transformers issue showing that **Gemma 4 text-only fine-tuning may still require `token_type_ids` and `mm_token_type_ids` during training**, with a custom collator and `remove_unused_columns=False` called out as the practical workaround. (GitHub)\n\nThat means two things may be true of your current setup at once: the **loss mask is partially right**, but the **model inputs are still incomplete**. Those are different layers. TRL’s `assistant_only_loss=True` controls which tokens contribute to loss. It does **not** guarantee that every extra tensor Gemma 4 expects at training time is present in the batch. (Hugging Face)\n\n## What your label dump suggests\n\nYour decoded labels are actually useful. They suggest that the assistant answer is being supervised while the user turn is masked out. That is a good sign. It means your custom `{% generation %}` template is not obviously failing at the most basic level. TRL’s docs say assistant-only loss depends on templates that support `{% generation %}` / `{% endgeneration %}` masking, so your workaround is conceptually in the right direction. (Hugging Face)\n\nBut that does **not** prove the template is fully correct. A template can look fine when decoded and still produce a slightly wrong assistant mask, wrong end-of-turn coverage, or a mismatch with how Gemma 4 expects turns to be structured during training. 
TRL’s own docs make clear that assistant-only loss depends on the template returning the right assistant-token mask, and a current TRL tracking issue exists precisely because many models do not ship training-ready templates with those generation markers. (Hugging Face)\n\n## On your tokenizer observation\n\nI cannot verify the claim that `google/gemma-4-E2B-it` is “for inference only.” The official Google Gemma 4 QLoRA guide still uses the tokenizer from `google/gemma-4-E2B-it` explicitly to get the official template. But your practical observation still makes sense in context: the **official template may be fine for inference or standard conversational formatting**, while still not being enough for **TRL assistant-only SFT**, which specifically needs generation markers for training masks. So your manual template patch is not inherently suspicious. It is just another place where mistakes become easy. (Google AI for Developers)\n\n## What I now think is happening in your case\n\n### 1. Missing Gemma-4-specific batch tensors are the top suspect\n\nThis is now my number one suspect, by a clear margin. The keys in your current batch match the public Gemma 4 issue almost too well. The issue explicitly says Gemma 4 text-only fine-tuning may require both `token_type_ids` and `mm_token_type_ids`, even when your dataset is only `user` and `assistant` text. That is unusual compared with Llama or Qwen, but it is exactly the kind of multimodal-family residue that Gemma 4 currently exposes. (GitHub)\n\n### 2. The `use_cache=False` / checkpointing path is the second suspect\n\nThis is the other big one. There is a recent Transformers issue saying Gemma 4 training with `use_cache=False` can corrupt attention and produce garbage logits, and Transformers 5.5.2 specifically shipped a fix for Gemma 4 in that area. TRL also defaults to `gradient_checkpointing=True`, and that often pushes training into the uncached path. 
So even if your template and labels are mostly fine, the actual forward pass may still be unstable for reasons upstream of your data. (Hugging Face)\n\n### 3. Your template may be “good enough to run” but still not exact\n\nBecause you had to create a custom `{% generation %}` template, I would treat it as _plausible_, not _trusted_. Gemma-family prompt structure uses `<start_of_turn>user`, `<start_of_turn>model`, and `<end_of_turn>` markers, so the format you showed is not obviously wrong. But the training mask has to align with the assistant response **exactly**, and tiny mismatches there can produce stubbornly bad loss without causing a clean crash. (Hugging Face)\n\n### 4. MPS is likely amplifying the failure, not causing it first\n\nApple’s MPS backend is still documented as beta, and PyTorch’s MPS optimizer docs focus on `float32` and `float16`, while public MPS issues still exist around bf16 support. So your suspicion about Mac precision is reasonable. I just no longer think it is the root cause. I think it is turning a brittle Gemma-4-specific setup into an explosive one. (Google AI for Developers)\n\n## Why TRL alone is not enough here\n\nYou are right that TRL now encourages `assistant_only_loss` rather than older custom completion collators in many setups. But that guidance is about **loss masking**, not about **every model family’s extra forward inputs**. Gemma 4 is precisely the kind of model where those two concerns diverge. So “the collator is deprecated” does **not** really invalidate the need for a small custom collator in your case. It only means the collator should not be the thing deciding the loss region. It can still be the thing adding missing tensors. 
(Hugging Face)\n\nThat distinction is the key insight for your case:\n\n  * **assistant_only_loss** decides _where loss applies_\n  * **custom collator** can still decide _which tensors the model receives_ (Hugging Face)\n\n\n\n## My recommended fix order\n\n### First fix: add the missing tensors\n\nKeep your conversational dataset and keep `assistant_only_loss=True`, but add a tiny collator that injects zero-filled `token_type_ids` and `mm_token_type_ids`, then set `remove_unused_columns=False`. That is the most directly evidence-backed change you can make right now. (GitHub)\n\n\n    import torch\n    from transformers import default_data_collator\n\n    class Gemma4TextCollator:\n        def __call__(self, features):\n            # default_data_collator does not pad, so features must already share one length (e.g. batch size 1)\n            batch = default_data_collator(features)\n\n            # Text-only data: fill the extra type-id tensors with zeros, shaped like input_ids\n            if \"token_type_ids\" not in batch:\n                batch[\"token_type_ids\"] = torch.zeros_like(batch[\"input_ids\"])\n\n            if \"mm_token_type_ids\" not in batch:\n                batch[\"mm_token_type_ids\"] = torch.zeros_like(batch[\"input_ids\"])\n\n            return batch\n\n\n### Second fix: disable the risky training path\n\nFor diagnosis, use:\n\n  * `gradient_checkpointing=False`\n  * do not force `use_cache=False`\n  * `packing=False`\n  * `bf16=False`\n  * `fp16=True` on MPS, or very short fp32 smoke tests only\n  * latest Transformers with the Gemma 4 fixes included (Hugging Face)\n\n\n\n### Third fix: verify the assistant mask mechanically\n\nDo not trust the decoded labels alone. Ask the tokenizer/template path to return the assistant mask and check that it is non-empty and lines up exactly with the answer span. TRL’s masking behavior depends on that. (Hugging Face)\n\n### Fourth fix: shrink the task until it overfits\n\nRun a tiny overfit test on 32–128 samples, sequence length 256–512, batch size 1, low LR. If it still cannot overfit, that is further evidence that the stack is wrong, not the data scale. 
Google’s own Gemma 4 QLoRA guide already uses small defaults like batch size 1 and max length 512, which supports this conservative debugging strategy. (Google AI for Developers)\n\n## The practical alternative I would seriously consider\n\nFor your case, I would strongly consider switching from conversational `messages` + `assistant_only_loss=True` to **prompt-completion** format for the first stable run.\n\nWhy? Because right now you are depending on a custom Jinja training template to infer the supervised region. Prompt-completion makes the target explicit. That removes one moving part. TRL supports both, and completion-style supervision is simply less fragile when the model family’s chat-template training path is still maturing. (Hugging Face)\n\nThis does **not** mean your current approach is wrong in theory. It means it is fragile in practice, especially on Gemma 4 + TRL + MPS.\n\n## My bottom-line conclusion\n\nBased on everything you shared, my current ranking is:\n\n  1. **Missing `token_type_ids` / `mm_token_type_ids`**\n  2. **`gradient_checkpointing` / `use_cache=False` path**\n  3. **Subtle custom-template mask mismatch**\n  4. **MPS precision instability**\n  5. **Only then: learning-rate or ordinary tuning issues** (GitHub)\n\n\n\nSo my advice is no longer “try different settings.” It is:\n\n  * keep LoRA\n  * keep the simple `messages` dataset for now\n  * add a **Gemma-4-specific collator**\n  * set `remove_unused_columns=False`\n  * turn **checkpointing off**\n  * avoid the uncached path while diagnosing\n  * verify the assistant mask directly\n  * only then return to hyperparameter tuning (GitHub)\n\n",
  "title": "Fine-tuning Gemma-4-E2B on MacBook M3"
}