Looking for guidance: trying to build a model from TrOCR's encoder + Google's mT5 multilingual decoder, but the model fails to overfit on a single data sample
Hmm… Hypothesis that the problem is occurring on the decoder side:
My overall conclusion
I do not think you should change the tokenizer or abandon the current TrOCR-encoder + mT5-decoder setup yet.
I think the current evidence says something more specific:
- the architecture can work, because some targets overfit correctly;
- the decoder is still under-adapted on harder lines;
- and tokens like <extra_id_0> are not random accidents. They are a very specific sign that the T5-family decoder is falling back to its pretraining behavior when OCR grounding is weak. T5-family tokenizers include extra_ids special tokens, and the original T5 pretraining objective uses sentinel tokens as part of span corruption. (huggingface.co)
So my main recommendation is:
Keep the current setup, but change how you adapt and evaluate the decoder.
Why <extra_id_0> appears at all
This is the first thing to understand.
<extra_id_0> is a built-in special token in the T5 and mT5 tokenizer family. It is not some random OCR artifact. T5 was pretrained with a denoising objective that literally teaches the model to emit sentinel tokens like <extra_id_0> when reconstructing masked spans. That means these tokens have very strong pretrained priors. (arxiv.org)
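To make that prior concrete, here is a minimal word-level sketch of T5-style span corruption. It assumes whole words as tokens and hand-picked spans (the real objective samples spans over sentencepiece tokens), and it rebuilds the "Thank you for inviting me to your party last week" example from the T5 paper. Notice that in this format every target sequence begins with <extra_id_0>, which is one plausible reason that sentinel tends to be the very first token a weakly grounded decoder emits.

```python
def span_corrupt(tokens, spans):
    """Build a T5-style (input, target) pair by masking the given spans.

    tokens: list of word-level tokens
    spans:  sorted, non-overlapping (start, end) index pairs to mask
    """
    inputs, targets = [], []
    cursor = 0
    for sentinel_id, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{sentinel_id}>"
        inputs.extend(tokens[cursor:start])  # keep the unmasked text
        inputs.append(sentinel)              # the whole span becomes one sentinel
        targets.append(sentinel)             # the target announces the span...
        targets.extend(tokens[start:end])    # ...then reproduces its content
        cursor = end
    inputs.extend(tokens[cursor:])
    targets.append(f"<extra_id_{len(spans)}>")  # a final sentinel closes the target
    return inputs, targets


words = "Thank you for inviting me to your party last week .".split()
inp, tgt = span_corrupt(words, [(2, 4), (8, 9)])
print(" ".join(inp))  # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(tgt))  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```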
So when your OCR model is uncertain, what can happen is:
- the image signal is not strong enough to dominate,
- the decoder falls back to familiar pretrained behavior,
- sentinel tokens and repetitive continuations leak into generation.
That is why your outputs look like:
- <extra_id_0>
- then repeated “लिए”
- then repeated “और”
- and other locally high-probability continuations
This is a decoder grounding problem, not a token-coverage problem.
Why some lines fit and others fail
This is the second key idea.
If the only issue were random initialization, you would mostly see run-to-run differences:
- one run works,
- another run does not.
But what you are seeing is also line-to-line variation:
- some target texts overfit nicely,
- others collapse badly.
That means there are multiple causes at once.
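One way to separate the two kinds of variation is to rerun the single-sample overfit test for each line under a few seeds and tabulate the outcomes. This sketch assumes a hypothetical results dictionary mapping a line id to one pass/fail boolean per seed; how you define "pass" (exact match, CER below some threshold) is your choice.

```python
def classify_lines(results):
    """Label each line by how its overfit outcome varies across seeds.

    results: hypothetical dict mapping line_id -> list of booleans,
             one per random seed (True = the overfit run reproduced the target).
    """
    report = {}
    for line_id, runs in results.items():
        passes = sum(runs)
        if passes == len(runs):
            report[line_id] = "always fits"
        elif passes == 0:
            report[line_id] = "never fits"
        else:
            report[line_id] = "seed-dependent"
    return report


results = {
    "line_01": [True, True, True],
    "line_17": [False, False, False],
    "line_42": [True, False, True],
}
for line_id, label in classify_lines(results).items():
    print(line_id, label)
```

Many "seed-dependent" lines point toward initialization; many "never fits" lines point toward target difficulty and decoder under-adaptation.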
Cause 1: random bridge initialization
Hugging Face’s encoder-decoder docs explicitly note that in warm-started hybrids, the decoder-side cross-attention can be randomly initialized and must be fine-tuned downstream. So yes, some instability is expected. (huggingface.co)
Cause 2: target difficulty is uneven
Some lines are easier:
- shorter,
- cleaner,
- more common vocabulary,
- fewer punctuation marks,
- easier crops,
- more printed than handwritten.
Some are harder:
- longer,
- more punctuation,
- noisier handwriting,
- denser ligatures,
- rarer word combinations.
The hard lines require stronger and more stable image grounding. So they expose decoder weakness faster.
Cause 3: your current trainable slice is still too narrow
This is the main practical issue.
Your two freezing strategies both let the model learn some bridge behavior, but they do not give the decoder enough freedom to fully reshape sequence generation for hard OCR lines.
That is why the model can sometimes stick closely to the target and sometimes fail badly.
So my interpretation is:
random initialization contributes, but the bigger story is under-adaptation of the decoder in a tiny-data regime.
Why your two freezing strategies behave this way
Strategy 1
Train only:
- EncDecAttention
- lm_head
- shared
- projection
This helps the model learn:
- how to inject image features into the decoder,
- and how to map decoder states into output tokens.
That is often enough for easy examples.
But it does not fully change the decoder’s internal sequence dynamics.
So if the image signal is weak, the decoder still falls back to pretrained behavior.
Strategy 2
Add:
- DenseReluDense
This is broader and better than Strategy 1.
But it still leaves other important parts constrained, especially self-attention-driven sequence behavior.
So it can still fail on harder examples.
That is why both strategies can show “sometimes good, sometimes terrible” behavior.
They are not wrong. They are just not broad enough yet for the hard lines.
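To see which decoder parts each strategy leaves untouched, you can match the strategy's name patterns against parameter names. This sketch uses a hand-written list of names that follows the Hugging Face T5/mT5 convention (decoder.block.{i}.layer.{j}.SelfAttention / EncDecAttention / DenseReluDense); the bare "projection" name is a hypothetical stand-in for your bridge layer. In a real run you would use the names from model.named_parameters() instead, and your wrapper's prefixes may differ.

```python
# Substring patterns for the two freezing strategies described above.
STRATEGIES = {
    "strategy_1": ["EncDecAttention", "lm_head", "shared", "projection"],
    "strategy_2": ["EncDecAttention", "lm_head", "shared", "projection", "DenseReluDense"],
}

# Hypothetical parameter names in the T5/mT5 naming convention (one decoder block shown).
PARAM_NAMES = [
    "shared.weight",
    "projection.weight",
    "decoder.block.0.layer.0.SelfAttention.q.weight",
    "decoder.block.0.layer.0.layer_norm.weight",
    "decoder.block.0.layer.1.EncDecAttention.q.weight",
    "decoder.block.0.layer.1.layer_norm.weight",
    "decoder.block.0.layer.2.DenseReluDense.wi_0.weight",
    "decoder.block.0.layer.2.layer_norm.weight",
    "decoder.final_layer_norm.weight",
    "lm_head.weight",
]


def submodule(name):
    """Map a decoder parameter name to the sub-module it belongs to."""
    parts = name.split(".")
    if parts[1] == "block":  # decoder.block.{i}.layer.{j}.<Submodule>....
        return parts[5]
    return parts[1]          # e.g. decoder.final_layer_norm.weight


def frozen_decoder_parts(param_names, patterns):
    """Decoder sub-modules with at least one parameter that no pattern unfreezes."""
    return sorted({
        submodule(n)
        for n in param_names
        if n.startswith("decoder.") and not any(p in n for p in patterns)
    })


for strategy, patterns in STRATEGIES.items():
    print(strategy, frozen_decoder_parts(PARAM_NAMES, patterns))
```

In both strategies SelfAttention (and the layer norms) stays frozen, which is the constraint on sequence behavior described above.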
My recommended solutions, in order
Solution 1. Keep the current architecture
Do not switch tokenizer. Do not switch away from mT5 yet. Do not switch away from the TrOCR encoder yet.
Reason:
- one-sample overfit success proves the wiring can work;
- <extra_id_0> means decoder fallback, not missing Hindi token support. (huggingface.co)
This is the highest-confidence recommendation.
Solution 2. Use a broader decoder-side adaptation strategy
This is the most important practical change.
Recommended next freeze schedule
Freeze:
- entire encoder
Train:
- enc_to_dec_proj
- all EncDecAttention layers
- lm_head
- shared
- all parameters in the last 2 decoder blocks
This is better than both of your current strategies because it gives the decoder more freedom to change:
- sequence behavior,
- grounding behavior,
- and output token dynamics.
I would use this as the next main training strategy.
Why this makes sense:
- the encoder already gives usable image features;
- the fragile part is still the decoder-side bridge and generation;
- Hugging Face’s docs already point to cross-attention as the new component that often needs fine-tuning in warm-started hybrids. (huggingface.co)
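A minimal sketch of this freeze schedule as a name-based rule. It assumes the TrOCR encoder's parameters live under "encoder.", the bridge is called enc_to_dec_proj (as in VisionEncoderDecoderModel), and the mT5 decoder uses the usual decoder.block.{i} naming. The helper should_train is hypothetical; adjust the prefixes to whatever model.named_parameters() reports for your model.

```python
import re

# Captures the block index in names like "decoder.block.7.layer.0.SelfAttention.q.weight".
BLOCK_RE = re.compile(r"^decoder\.block\.(\d+)\.")


def should_train(name, num_decoder_blocks, last_k=2):
    """Return True if a parameter should be trainable under the proposed schedule."""
    if name.startswith("encoder."):
        return False  # entire encoder stays frozen
    if "enc_to_dec_proj" in name or "EncDecAttention" in name:
        return True   # bridge projection + cross-attention in every block
    if name.startswith(("lm_head.", "shared.")):
        return True   # output head and shared embeddings
    match = BLOCK_RE.match(name)
    if match and int(match.group(1)) >= num_decoder_blocks - last_k:
        return True   # every parameter in the last `last_k` decoder blocks
    return False      # earlier self-attention/FFN, norms, etc. stay frozen
```

With a PyTorch model you would apply it as: for name, p in model.named_parameters(): p.requires_grad = should_train(name, num_blocks), where num_blocks is the decoder depth (for a T5/mT5 config this is typically num_decoder_layers).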
What I would not do
Do not unfreeze the encoder yet.
That is too early for your data size and not where the failure signal is pointing.
Solution 3. Suppress sentinel tokens during validation and inference
This is a very useful guardrail.
Hugging Face generation utilities support bad_words_ids, which lets you block specific tokens or token sequences during generation. Since <extra_id_n> tokens should never be valid OCR output for your task, you can suppress them during validation and inference. (huggingface.co)
Example idea:
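```python
# Assumes `tokenizer`, `model`, `pixel_values` and `max_new_tokens` already exist in your pipeline.
# T5/mT5 tokenizers define 100 sentinels by default: <extra_id_0> ... <extra_id_99>.
extra_tokens = [f"<extra_id_{i}>" for i in range(100)]
# One single-token entry per sentinel; convert_tokens_to_ids guarantees
# exactly one id per sentinel string.
bad_words_ids = [[tid] for tid in tokenizer.convert_tokens_to_ids(extra_tokens)]

generated_ids = model.generate(
    pixel_values=pixel_values,
    max_new_tokens=max_new_tokens,
    num_beams=1,                  # greedy decoding, no beam search
    do_sample=False,
    bad_words_ids=bad_words_ids,  # never emit any <extra_id_n>
)
```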
Important caution:
- this is not the real fix,
- it is a guardrail.
It prevents the most obviously invalid decoder fallback behavior from polluting your evaluation, while you keep working on the actual training problem.
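For single-token entries like the sentinels, the effect of bad_words_ids at each decoding step is roughly the toy logic below: banned ids get a score of minus infinity before the next token is chosen. The vocabulary and scores are made up. The sentinel still wins on the raw scores; masking only hands the step to the runner-up, which is why this is a guardrail rather than a fix.

```python
import math


def mask_banned(logits, banned_ids):
    """Copy next-token scores, setting banned single-token ids to -inf."""
    masked = list(logits)
    for token_id in banned_ids:
        masked[token_id] = -math.inf
    return masked


def greedy_pick(logits):
    """Index of the highest score, i.e. greedy decoding for one step."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Toy 5-token vocabulary; id 4 stands in for <extra_id_0>.
logits = [0.1, 1.2, 0.3, 0.9, 2.5]
print(greedy_pick(logits))                    # 4: the sentinel wins unmasked
print(greedy_pick(mask_banned(logits, [4])))  # 1: runner-up once id 4 is banned
```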
Solution 4. Split your tests into easy lines and hard lines
Right now your model feels “unpredictable” because you are mentally averaging together different difficulty levels.
Do this instead:
Easy probe set
Use lines that are:
- shorter,
- cleaner,
- more printed,
- less punctuation-heavy,
- more common vocabulary.
Hard probe set
Use lines that are:
- longer,
- more punctuation-heavy,
- noisier handwriting,
- more complex Devanagari forms,
- more unusual vocabulary.
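If you want to build the two probe sets semi-automatically, a crude score over the criteria above is enough to start. Everything here is hypothetical: the weights, the is_handwritten flag (assumed to come from your own metadata), and counting the Devanagari danda and double danda as punctuation are illustrative choices. It ignores harder-to-measure factors such as ligature density or crop quality, so review the resulting sets by eye.

```python
import string

# Treat ASCII punctuation plus the Devanagari danda / double danda as punctuation.
PUNCT = set(string.punctuation) | {"।", "॥"}


def difficulty_score(text, is_handwritten):
    """Hypothetical heuristic: longer, punctuation-heavy, handwritten = harder.

    len() counts Unicode code points, not visual characters.
    """
    punct_count = sum(ch in PUNCT for ch in text)
    return len(text) + 5 * punct_count + (20 if is_handwritten else 0)


def split_probes(samples, k=10):
    """samples: list of (line_id, text, is_handwritten). Returns (easy_ids, hard_ids)."""
    ranked = sorted(samples, key=lambda s: difficulty_score(s[1], s[2]))
    return [s[0] for s in ranked[:k]], [s[0] for s in ranked[-k:]]


samples = [
    ("a", "नमस्ते", False),
    ("b", "यह एक लंबी पंक्ति है, जिसमें विराम चिह्न भी हैं।", True),
    ("c", "भारत", False),
]
easy, hard = split_probes(samples, k=1)
print(easy, hard)  # ['c'] ['b']
```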
Then run the same overfit test on both.
This will tell you much more than one mixed impression.
If easy lines fit but hard lines do not, then the explanation is not “just random init.” It is:
- random init,
- plus hard-target difficulty,
- plus decoder under-adaptation.
Solution 5. Add three diagnostics
These three diagnostics will make your debugging much clearer.
A. Sentinel-token rate
Track how often predictions contain <extra_id_0> or any <extra_id_n>.
This tells you whether the decoder is still falling back to T5 pretraining behavior.
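A minimal sketch, assuming predictions are decoded strings. One caveat to check in your setup: the standard T5/mT5 tokenizers typically register the sentinels as special tokens, so decoding with skip_special_tokens=True would strip exactly what you are trying to count. For this metric, decode with skip_special_tokens=False and measure on generations made without the bad_words_ids guardrail.

```python
import re

# Matches any T5/mT5 sentinel such as <extra_id_0> or <extra_id_57>.
SENTINEL_RE = re.compile(r"<extra_id_\d+>")


def sentinel_rate(predictions):
    """Fraction of decoded predictions containing at least one sentinel token."""
    if not predictions:
        return 0.0
    hits = sum(1 for text in predictions if SENTINEL_RE.search(text))
    return hits / len(predictions)


preds = ["<extra_id_0> लिए लिए", "नमस्ते दुनिया", "और और और <extra_id_1>", "भारत"]
print(sentinel_rate(preds))  # 0.5
```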
B. Length ratio
Track:
len(prediction) / len(reference)
If this ratio explodes, repetition and EOS failure are dominating.
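A small sketch of the length-ratio check over decoded strings. Here "length" means Python len, i.e. Unicode code points, and the 1.5 cut-off for flagging a line is an arbitrary illustrative threshold, not a tuned value.

```python
def length_ratios(predictions, references):
    """len(prediction) / len(reference) per line, guarding against empty references."""
    return [len(p) / max(len(r), 1) for p, r in zip(predictions, references)]


def flag_runaway(predictions, references, threshold=1.5):
    """Indices of lines whose prediction is much longer than its reference."""
    ratios = length_ratios(predictions, references)
    return [i for i, ratio in enumerate(ratios) if ratio > threshold]


refs = ["नमस्ते", "भारत"]
preds = ["नमस्ते", "भारत भारत भारत भारत"]
print(length_ratios(preds, refs))  # [1.0, 4.75]
print(flag_runaway(preds, refs))   # [1]
```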
C. Target token length
Track tokenized target length for each line.
Hard examples often cluster here.
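To check whether failures cluster by target length, you can bucket lines by tokenized length and look at the fit rate per bucket. This sketch assumes you already have one token length per line (for example len(tokenizer(text).input_ids)) and one pass/fail flag per line from your overfit runs; the bucket width of 10 tokens is arbitrary.

```python
from collections import defaultdict


def fit_rate_by_length(token_lengths, fitted, bucket_size=10):
    """Map each length bucket's lower bound to the fraction of lines that fit."""
    buckets = defaultdict(list)
    for length, ok in zip(token_lengths, fitted):
        buckets[(length // bucket_size) * bucket_size].append(ok)
    return {start: sum(oks) / len(oks) for start, oks in sorted(buckets.items())}


lengths = [8, 12, 15, 22, 27, 31]
fitted = [True, True, False, False, False, False]
print(fit_rate_by_length(lengths, fitted))  # {0: 1.0, 10: 0.5, 20: 0.0, 30: 0.0}
```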
These three numbers will be more informative than loss alone.
Solution 6. Tighten generation length
A flat max_new_tokens=64 is probably too blunt.
My recommendation is:
- compute the 95th percentile target token length in your dataset,
- then set max_new_tokens = p95 + 4.
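A minimal sketch of that rule, assuming you have a list of tokenized target lengths (for example len(tokenizer(text).input_ids), which for T5/mT5 tokenizers includes the trailing </s>). It uses the nearest-rank definition of the 95th percentile with integer arithmetic; numpy.percentile interpolates by default, so it can give a slightly different value.

```python
def p95_max_new_tokens(token_lengths, margin=4):
    """Nearest-rank 95th percentile of target token lengths, plus a safety margin."""
    ordered = sorted(token_lengths)
    rank = (95 * len(ordered) + 99) // 100  # ceil(0.95 * n) without float rounding
    return ordered[rank - 1] + margin


lengths = list(range(1, 21))        # toy lengths 1..20
print(p95_max_new_tokens(lengths))  # 23  (p95 = 19, plus 4)
```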
Why:
- long ceilings give unstable models more room to loop,
- shorter, data-driven ceilings reduce runaway repetition.
This is a practical recommendation based on the failure pattern you are seeing.
Solution 7. Move to LoRA if the above still fails
If the broader decoder-side adaptation still gives unstable behavior, my next recommendation is:
- LoRA on cross-attention
- plus LoRA on the last 2 decoder blocks
This is not just a generic modern preference. Recent low-resource Indic OCR work uses LoRA-style parameter-efficient adaptation, and the original LoRA paper explains why adapting only a small low-rank slice is often more stable and much cheaper than broad full fine-tuning. (aclanthology.org)
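If you get to this step, the main decision is which Linear layers receive adapters. The sketch below selects, from module names, the projections of cross-attention in every block plus the attention and feed-forward projections of the last 2 decoder blocks. It assumes the T5/mT5 module naming (q, k, v, o in attention; wi_0, wi_1, wo in mT5's gated feed-forward). lora_targets is a hypothetical helper; its output should be usable as target_modules in PEFT's LoraConfig, which matches full module names.

```python
import re

# Linear sub-layers inside a T5/mT5 decoder block, e.g.
# "decoder.block.7.layer.0.SelfAttention.q" or "decoder.block.7.layer.2.DenseReluDense.wo".
LINEAR_RE = re.compile(
    r"(?:^|\.)decoder\.block\.(\d+)\.layer\.\d+\."
    r"(SelfAttention|EncDecAttention|DenseReluDense)\.(q|k|v|o|wi|wi_0|wi_1|wo)$"
)


def lora_targets(module_names, num_decoder_blocks, last_k=2):
    """Cross-attention in every block + all attention/FFN projections in the last k blocks."""
    first_tuned = num_decoder_blocks - last_k
    targets = []
    for name in module_names:
        match = LINEAR_RE.search(name)
        if not match:
            continue  # norms, dropout, encoder modules, etc.
        block, kind = int(match.group(1)), match.group(2)
        if kind == "EncDecAttention" or block >= first_tuned:
            targets.append(name)
    return targets


names = [
    "decoder.block.0.layer.0.SelfAttention.q",
    "decoder.block.0.layer.1.EncDecAttention.k",
    "decoder.block.6.layer.0.SelfAttention.v",
    "decoder.block.7.layer.2.DenseReluDense.wo",
    "decoder.block.7.layer.2.DenseReluDense.dropout",
    "decoder.block.7.layer.0.layer_norm",
]
print(lora_targets(names, num_decoder_blocks=8))
```

In practice you would feed it the names from model.named_modules() and pass the result as LoraConfig(r=8, lora_alpha=16, target_modules=...) before calling get_peft_model; the rank and alpha values here are placeholders, not tuned settings.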
So if the broader partial fine-tuning still collapses, I would move to LoRA before touching the encoder.
Solution 8. Keep 210 samples for debugging, but move to 1000+ once decoder behavior stabilizes
Your 210-sample dataset is still useful. It is a good debugging set.
But it is not enough to expect stable OCR behavior from a large hybrid model across both printed and handwritten Hindi.
That is why recent low-resource Indic OCR work uses PEFT and synthetic data, and why public Indic handwriting resources are much larger. This is not a sign that your approach is bad; it is a sign that your current data regime is very small. (cvit.iiit.ac.in)
So my recommendation is:
- keep the 210 samples for trainer and decoder debugging,
- move to 1000+ line crops once the decoder stops collapsing into sentinels and repetitions.
My final recommendation stack
If I compress everything into the clearest sequence, it is this:
First
Keep the architecture.
Second
Train:
- encoder frozen,
- all cross-attention trainable,
- last 2 decoder blocks trainable,
- lm_head,
- shared,
- projection.
Third
Suppress <extra_id_n> tokens during validation and inference with bad_words_ids.
Fourth
Use easy vs hard probe sets and log:
- sentinel-token rate,
- length ratio,
- target token length.
Fifth
If instability persists, move to LoRA on decoder-side blocks.
Sixth
Scale to 1000+ line crops once decoder behavior becomes sane.
My one-sentence summary
Your problem is not “wrong tokenizer.” It is decoder fallback to T5 priors plus under-adapted generation on a tiny mixed dataset.
So my recommended solution is:
Keep the current setup, broaden decoder-side adaptation to the last 2 decoder blocks plus all cross-attention, suppress sentinel tokens during validation, and use LoRA next if the decoder is still unstable.
That is the path I think best fits what you are seeing now.