Fine-tuning Gemma-4-E2B on MacBook M3

Hugging Face Forums [Unofficial] April 14, 2026

Oh… Since Transformers and TRL are currently undergoing fairly major renovations, there are just too many potential culprits in this case…


Your feedback changes the diagnosis a lot.

The main point

Your problem now looks much less like “bad hyperparameters” and much more like a mismatch between Gemma 4’s training input contract and TRL’s assistant-only training path. The strongest clue is your batch: it only contains input_ids, labels, and attention_mask, while there is an active Transformers issue showing that Gemma 4 text-only fine-tuning may still require token_type_ids and mm_token_type_ids during training, with a custom collator and remove_unused_columns=False called out as the practical workaround. (GitHub)

That means your current setup may be doing two things at once: the loss mask is partially right, but the model inputs are still incomplete. Those are different layers. TRL’s assistant_only_loss=True controls which tokens contribute to loss. It does not guarantee that every extra tensor Gemma 4 expects at training time is present in the batch. (Hugging Face)
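A quick way to see the gap between those two layers is to list which keys the batch lacks before it reaches the model. This is a minimal, framework-free sketch: the helper name missing_batch_keys is hypothetical, and the default key list simply mirrors the two tensors named in the Transformers issue.

```python
# Hypothetical helper: report which expected tensors a batch lacks.
# The default key names are the two mentioned in the Gemma 4 issue.
def missing_batch_keys(batch, required=("token_type_ids", "mm_token_type_ids")):
    return [key for key in required if key not in batch]


# A batch shaped like the one in your dump: loss labels present, extras absent.
example_batch = {
    "input_ids": [[2, 10, 11]],
    "labels": [[-100, 10, 11]],
    "attention_mask": [[1, 1, 1]],
}
print(missing_batch_keys(example_batch))
```

Running the same check on one real batch from your dataloader should show whether the collator or the column pruning is dropping anything.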

What your label dump suggests

Your decoded labels are actually useful. They suggest that the assistant answer is being supervised while the user turn is masked out. That is a good sign. It means your custom {% generation %} template is not obviously failing at the most basic level. TRL’s docs say assistant-only loss depends on templates that support {% generation %} / {% endgeneration %} masking, so your workaround is conceptually in the right direction. (Hugging Face)

But that does not prove the template is fully correct. A template can look fine when decoded and still produce a slightly wrong assistant mask, wrong end-of-turn coverage, or a mismatch with how Gemma 4 expects turns to be structured during training. TRL’s own docs make clear that assistant-only loss depends on the template returning the right assistant-token mask, and a current TRL tracking issue exists precisely because many models do not ship training-ready templates with those generation markers. (Hugging Face)

On your tokenizer observation

I cannot verify the claim that google/gemma-4-E2B-it is “for inference only.” The official Google Gemma 4 QLoRA guide still uses the tokenizer from google/gemma-4-E2B-it explicitly to get the official template. But your practical observation still makes sense in context: the official template may be fine for inference or standard conversational formatting, while still not being enough for TRL assistant-only SFT, which specifically needs generation markers for training masks. So your manual template patch is not inherently suspicious. It is just another place where mistakes become easy. (Google AI for Developers)

What I now think is happening in your case

1. Missing Gemma-4-specific batch tensors is the top suspect

This is now my number one suspect, by a clear margin. Your current batch shape matches the public Gemma 4 issue almost too well. The issue explicitly says Gemma 4 text-only fine-tuning may require both token_type_ids and mm_token_type_ids, even when your dataset is only user and assistant text. That is unusual compared with Llama or Qwen, but it is exactly the kind of multimodal-family residue that Gemma 4 currently exposes. (GitHub)

2. The use_cache=False / checkpointing path is the second suspect

This is the other big one. There is a recent Transformers issue saying Gemma 4 training with use_cache=False can corrupt attention and produce garbage logits, and Transformers 5.5.2 specifically shipped a fix for Gemma 4 in that area. TRL also defaults gradient_checkpointing=True, and that often pushes training into the uncached path. So even if your template and labels are mostly fine, the actual forward pass may still be unstable for reasons upstream of your data. (Hugging Face)

3. Your template may be “good enough to run” but still not exact

Because you had to create a custom {% generation %} template, I would treat it as plausible, not trusted. Gemma-family prompt structure uses <start_of_turn>user, <start_of_turn>model, and <end_of_turn> markers, so the format you showed is not obviously wrong. But the training mask has to align with the assistant response exactly, and tiny mismatches there can produce stubbornly bad loss without causing a clean crash. (Hugging Face)

4. MPS is likely amplifying the failure, not causing it first

Apple’s MPS backend is still documented as beta, and PyTorch’s MPS optimizer docs focus on float32 and float16, while public MPS issues still exist around bf16 support. So your suspicion about Mac precision is reasonable. I just no longer think it is the root cause. I think it is turning a brittle Gemma-4-specific setup into an explosive one. (Google AI for Developers)

Why TRL alone is not enough here

You are right that TRL now encourages assistant_only_loss rather than older custom completion collators in many setups. But that guidance is about loss masking, not about every model family’s extra forward inputs. Gemma 4 is precisely the kind of model where those two concerns diverge. So “the collator is deprecated” does not really invalidate the need for a small custom collator in your case. It only means the collator should not be the thing deciding the loss region. It can still be the thing adding missing tensors. (Hugging Face)

That distinction is the key insight for your case:

  • assistant_only_loss decides where loss applies
  • custom collator can still decide which tensors the model receives (Hugging Face)

My recommended fix order

First fix: add the missing tensors

Keep your conversational dataset and keep assistant_only_loss=True, but add a tiny collator that injects zero-filled token_type_ids and mm_token_type_ids, then set remove_unused_columns=False. That is the most directly evidence-backed change you can make right now. (GitHub)

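import torch
from transformers import default_data_collator

class Gemma4TextCollator:
    """Adds zero-filled token-type tensors that Gemma 4 may expect during training."""

    def __call__(self, features):
        # default_data_collator stacks features without padding, so every
        # example in the batch must already have the same length.
        batch = default_data_collator(features)

        # Text-only data: every position is type 0.
        if "token_type_ids" not in batch:
            batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])

        if "mm_token_type_ids" not in batch:
            batch["mm_token_type_ids"] = torch.zeros_like(batch["input_ids"])

        return batch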
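One caveat with the collator above: default_data_collator does not pad, so it only works when every example in a batch has the same length (batch size 1 is the easy case). If you go beyond that, the padding logic would look roughly like the sketch below. It uses plain Python lists rather than tensors so the logic is easy to inspect; the function name and the pad_token_id=0 default are assumptions, so take the real pad id from your tokenizer.

```python
# Framework-free sketch (plain lists, not tensors).
# -100 is the label value Hugging Face trainers exclude from the loss.
IGNORE_INDEX = -100


def pad_and_add_token_types(features, pad_token_id=0):
    """Right-pad each feature to the longest sequence and add zero-filled type ids."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
        "token_type_ids": [],
        "mm_token_type_ids": [],
    }
    for f in features:
        pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * pad)
        # Padded positions must never contribute to the loss.
        batch["labels"].append(f["labels"] + [IGNORE_INDEX] * pad)
        # Text-only data: every position is type 0.
        batch["token_type_ids"].append([0] * max_len)
        batch["mm_token_type_ids"].append([0] * max_len)
    return batch
```

In a real run you would convert each list to a tensor at the end, or apply the same idea on top of a padding collator from your installed library version.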

Second fix: disable the risky training path

For diagnosis, use:

  • gradient_checkpointing=False
  • do not force use_cache=False
  • packing=False
  • bf16=False
  • fp16=True on MPS, or very short fp32 smoke tests only
  • latest Transformers with the Gemma 4 fixes included (Hugging Face)
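To keep yourself honest while diagnosing, it can help to check your settings against that list mechanically. The sketch below is only an illustration: risky_settings is a hypothetical helper that takes a plain dict of your training arguments, and it treats an unset gradient_checkpointing as enabled because of the TRL default mentioned earlier.

```python
def risky_settings(config):
    """Return the names of settings that still point at the risky training path."""
    risky = []
    # An unset value is treated as enabled, following the TRL default noted above.
    if config.get("gradient_checkpointing", True):
        risky.append("gradient_checkpointing")
    # Only an explicit use_cache=False counts as forcing the uncached path.
    if config.get("use_cache") is False:
        risky.append("use_cache")
    if config.get("packing", False):
        risky.append("packing")
    if config.get("bf16", False):
        risky.append("bf16")
    return risky


diagnostic_config = {
    "gradient_checkpointing": False,
    "packing": False,
    "bf16": False,
    "fp16": True,
}
print(risky_settings(diagnostic_config))
```

An empty list means the run matches the diagnostic baseline; anything else is worth switching off before you interpret the loss curve.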

Third fix: verify the assistant mask mechanically

Do not trust the decoded labels alone. Ask the tokenizer/template path to return the assistant mask and check that it is non-empty and lines up exactly with the answer span. TRL’s masking behavior depends on that. (Hugging Face)
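A mechanical check could look like the sketch below. It assumes you already have three equal-length lists for one example: the token ids, the labels from the batch, and a 0/1 assistant mask. In recent Transformers versions I believe apply_chat_template can return such a mask via return_assistant_tokens_mask=True together with return_dict=True, but treat that as something to confirm against your installed version. The helper name check_assistant_mask is hypothetical.

```python
IGNORE_INDEX = -100  # label value excluded from the loss


def check_assistant_mask(input_ids, labels, assistant_mask):
    """Return a list of problems; an empty list means mask and labels agree."""
    if not (len(input_ids) == len(labels) == len(assistant_mask)):
        return ["length mismatch"]
    problems = []
    if not any(assistant_mask):
        problems.append("assistant mask is empty")
    for pos, (tok, lab, m) in enumerate(zip(input_ids, labels, assistant_mask)):
        # Supervised positions should carry the token id itself as the label.
        if m and lab != tok:
            problems.append(f"position {pos}: assistant token not supervised")
        # Everything outside the assistant span should be ignored by the loss.
        if not m and lab != IGNORE_INDEX:
            problems.append(f"position {pos}: non-assistant token supervised")
    return problems
```

It is also worth decoding the tokens at the first and last masked positions by hand, since that is where end-of-turn coverage mistakes tend to hide.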

Fourth fix: shrink the task until it overfits

Run a tiny overfit test on 32–128 samples, sequence length 256–512, batch size 1, low LR. If it still cannot overfit, that is further evidence that the stack is wrong, not the data scale. Google’s own Gemma 4 QLoRA guide already uses small defaults like batch size 1 and max length 512, which supports this conservative debugging strategy. (Google AI for Developers)
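To make the overfit test less of a judgment call, you can classify the logged losses with a small rule. This is a rough sketch only: overfit_verdict is a hypothetical helper, and the 50% drop threshold is an arbitrary assumption, not a value from any guide.

```python
import math


def overfit_verdict(losses, min_drop=0.5):
    """Classify a short loss history from a tiny overfit run."""
    if not losses:
        return "no data"
    # NaN or inf anywhere points at a broken stack, not a tuning problem.
    if any(math.isnan(x) or math.isinf(x) for x in losses):
        return "diverged"
    # Require the final loss to fall by at least min_drop relative to the first.
    if losses[-1] <= losses[0] * (1 - min_drop):
        return "overfitting"
    return "stuck"
```

A "diverged" or "stuck" verdict on a few dozen samples is the signal that the inputs or the forward path are wrong, whatever the learning rate.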

The practical alternative I would seriously consider

For your case, I would strongly consider switching from conversational messages + assistant_only_loss=True to prompt-completion format for the first stable run.

Why? Because right now you are depending on a custom Jinja training template to infer the supervised region. Prompt-completion makes the target explicit. That removes one moving part. TRL supports both, and completion-style supervision is simply less fragile when the model family’s chat-template training path is still maturing. (Hugging Face)

This does not mean your current approach is wrong in theory. It means it is fragile in practice, especially on Gemma 4 + TRL + MPS.
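If you try the prompt-completion route, converting your existing messages data is a small step. The sketch below assumes each conversation ends with the assistant turn you want to supervise, and it produces the conversational prompt/completion layout as I understand it from TRL's dataset format docs. The function name is hypothetical, and the exact field layout is worth confirming against your TRL version.

```python
def to_prompt_completion(messages):
    """Split a conversation at its final assistant turn."""
    if not messages or messages[-1]["role"] != "assistant":
        raise ValueError("expected the conversation to end with an assistant turn")
    # Everything before the last turn is context; the last turn is the target.
    return {"prompt": messages[:-1], "completion": [messages[-1]]}


example = [
    {"role": "user", "content": "What is LoRA?"},
    {"role": "assistant", "content": "A low-rank adapter method."},
]
print(to_prompt_completion(example))
```

With this layout the supervised region is explicit in the data, so the template no longer needs generation markers to define it.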

My bottom-line conclusion

Based on everything you shared, my current ranking is:

  1. Missing token_type_ids / mm_token_type_ids
  2. gradient_checkpointing / use_cache=False path
  3. Subtle custom-template mask mismatch
  4. MPS precision instability
  5. Only then: learning-rate or ordinary tuning issues (GitHub)

So my advice is no longer “try different settings.” It is:

  • keep LoRA
  • keep the simple messages dataset for now
  • add a Gemma-4-specific collator
  • set remove_unused_columns=False
  • turn checkpointing off
  • avoid the uncached path while diagnosing
  • verify the assistant mask directly
  • only then return to hyperparameter tuning (GitHub)
