Looking for guidance. Trying to create a model with TrOCR's encoder + Google's mT5 multilingual decoder but model fails to overfit on a single data sample

Hugging Face Forums [Unofficial] April 1, 2026

Well, I’m not sure if the following patch will work as-is, but for now:


For your case, the right conclusion is:

210 samples is a valid proof-of-concept dataset, but only if you treat it as a pipeline test, not a quality benchmark. Your current notebook is close enough to keep. The main problem is no longer encoder-decoder wiring. The main problem is the training control logic and the fact that 160 training lines is far too little data for full fine-tuning of hundreds of millions of parameters. In warm-started encoder-decoder hybrids, the cross-attention bridge can be randomly initialized and needs downstream fine-tuning, and in mT5 the decoder starts from pad_token_id, which matches the direction you already fixed in your notebook. (Hugging Face)

My direct answer to your three questions

1. Is 210 samples okay?

Yes, for a proof that the pipeline is capable of learning. No, for judging final model quality or choosing the best architecture.

A 210-line dataset can answer a narrow but useful question:

“Can this TrOCR-encoder plus mT5-decoder pipeline learn stable Hindi line recognition on real data without collapsing?”

That is a good proof-of-concept question. It is much narrower than “Is this good OCR?” and that is the correct way to use 210 samples. Public TrOCR documentation says the raw model is intended for single text-line images, so your use of line crops is aligned with the model family’s intended use. At the same time, public Indic OCR resources are much larger. IIIT-INDIC-HW-WORDS reports 872K handwritten instances across 8 Indic scripts, and recent low-resource Indic OCR work explicitly leans on synthetic data or parameter-efficient adaptation because small real datasets are not enough by themselves. (Hugging Face)

So the right framing is:

  • 210 lines = enough to validate that the training recipe is sane.
  • 1000+ lines = much better for deciding whether the model scales.
  • many thousands = where quality conclusions start to matter.

2. What CER and WER should you aim for?

For your current dataset size, use CER as the primary metric and WER as a secondary metric.

That is because CER is character-level edit distance and is smoother on tiny validation sets, while WER is harsher and noisier on short lines. Hugging Face’s CER implementation defines CER as character-level edit distance normalized by reference length, and WER is the word-level analogue based on substitutions, deletions, and insertions. Both are “lower is better,” and both can exceed 1.0 when insertions are large. So “CER below 1.0” is not a meaningful success threshold. A CER around 0.98 is still very poor. (GitHub)
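To make the "CER can exceed 1.0" point concrete, here is a minimal sketch using a hand-rolled Levenshtein distance. It is not the Hugging Face or jiwer implementation, and the strings are made-up ASCII stand-ins. Note that Python counts code points, so in Devanagari each matra or combining sign counts as its own character.

```python
def edit_distance(hyp: str, ref: str) -> int:
    """Levenshtein distance between hypothesis and reference (unit cost per edit)."""
    prev_row = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        row = [i]
        for j, r in enumerate(ref, start=1):
            cost = 0 if h == r else 1
            row.append(min(prev_row[j] + 1,          # drop a hypothesis character
                           row[j - 1] + 1,           # add a missing reference character
                           prev_row[j - 1] + cost))  # match or substitute
        prev_row = row
    return prev_row[-1]


def cer(hyp: str, ref: str) -> float:
    """Character error rate: edits divided by reference length."""
    return edit_distance(hyp, ref) / len(ref)


reference = "abcd"
one_typo = "abxd"              # a single substitution
runaway = "abcd" + "z" * 6     # correct text followed by six repeated junk characters

print(cer(one_typo, reference))   # 0.25
print(cer(runaway, reference))    # 1.5
```

The second case gets every reference character right and still scores 1.5, because each extra character is an edit while the denominator stays at the reference length.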

These are the thresholds I would use for your proof-of-concept, and these are engineering thresholds, not official published cutoffs:

Minimum green light

  • validation CER < 0.50
  • validation WER < 0.80
  • predictions are readable
  • no repeated-character collapse
  • both printed and handwritten lines improve

Better green light

  • validation CER around 0.25 to 0.40
  • validation WER around 0.45 to 0.70

Strong green light

  • validation CER below roughly 0.20 to 0.25
  • validation WER below roughly 0.40 to 0.50

For 210 mixed-domain samples, I would call the pipeline “ready to scale” once you can reliably beat CER 0.5 and show visibly readable predictions on both printed and handwritten validation lines. The important word is reliably. One lucky split is not enough.

3. Which layers should you freeze, and how should the training class be designed?

For your exact architecture, freezing the decoder and training only the encoder is the wrong direction.

The decoder side contains:

  • the Hindi text generation behavior,
  • the new cross-attention bridge that connects image features to text generation,
  • and the autoregressive dynamics that are currently causing metric instability.

Hugging Face’s encoder-decoder docs explicitly say cross-attention may be randomly initialized in these hybrids and must be fine-tuned downstream. That means the most important adaptation is usually decoder-side cross-attention and output-side behavior, not encoder-only retraining. (Hugging Face)

So for your notebook, the correct first training stage is:

  • freeze the entire visual encoder
  • train enc_to_dec_proj if it exists
  • train decoder cross-attention
  • train lm_head
  • train shared embeddings
  • optionally train decoder layer norms

That is the right first-stage design for your current notebook.

My thoughts after checking your training cells

Your notebook is now in a much better place than before. These are the good parts:

  • you moved to trocr-small-stage1 plus mt5-small
  • the custom wrapper is now sane enough to test
  • the overfit test already proved the bridge can learn
  • the real trainer freezes the encoder and trains decoder-side bridge/output parameters

That is the right direction.

The weak part is the trainer design, not the architecture.

The biggest trainer problem

Right now, in your notebook:

  • validation loss is computed every epoch,
  • CER/WER are only computed every 5 epochs,
  • the scheduler follows validation loss,
  • but the best-checkpoint logic follows CER.

That is a mismatch.

For seq2seq OCR, teacher-forced loss and free-generation CER can move in different directions. The model can keep lowering validation loss while actual OCR output gets worse. That is especially common in tiny-data autoregressive setups. So a scheduler driven by val_loss and checkpointing driven by CER can easily train past the best OCR model.

That is the main reason I do not trust the 50-epoch result as a true judgment of the architecture.

What I think your current CER curve actually means

You said:

  • CER started around 1.2 to 1.3
  • improved to about 0.98
  • then worsened to about 1.5

That pattern usually means:

  • the model is learning something at first,
  • then overfitting or decoding drift sets in,
  • and later epochs add insertions or repetitive garbage.

Since CER is edit-distance based, late insertions can easily push it above 1.0. So this is not “almost good but not quite.” It is “still poor overall, with a brief early improvement that was not preserved.” (GitHub)

The good news is that this pattern usually points to training-control problems, not “your architecture cannot work.”

The exact changes I would make now

1. Compute CER and WER every epoch

On a validation set of about 45 to 50 lines, the extra compute is small. The benefit is large.

In your trainer, set:

generate_every_n_epochs = 1

That alone makes your best-epoch detection much more trustworthy.

2. Use CER as the one metric that controls training

Use validation CER for all three:

  • scheduler stepping
  • best-checkpoint saving
  • early stopping

Do not split those across val_loss and CER.

For this dataset size and task, CER is the best control metric. WER is still useful, but mainly as a reporting metric.

3. Add early stopping

Do not run 50 fixed epochs on 160 training lines.

Use:

  • num_epochs = 20 or 25
  • patience = 5
  • min_delta = 0.005 on CER

The best checkpoint will probably appear earlier than epoch 50. Right now your notebook is not designed to stop there.

4. Use parameter groups, not one flat AdamW group

In your notebook, all trainable parameters currently use one LR and one weight decay. That is too blunt.

Hugging Face’s training docs note that biases and LayerNorm parameters are usually excluded from weight decay. I would go one step further and also give the bridge and cross-attention a slightly higher learning rate than the rest of the decoder-side trainable weights. (Hugging Face)

A good split is:

  • bridge and cross-attention: LR 2e-4, weight decay 0.01
  • lm_head and other trainable decoder weights: LR 1e-4, weight decay 0.01
  • biases, norms, shared embeddings: LR 1e-4, weight decay 0.0

5. Normalize text before CER and WER

This matters more for Hindi than people expect.

Before computing metrics, normalize both prediction and reference with:

  • Unicode NFC normalization
  • .strip()
  • whitespace collapse

That removes avoidable Unicode and spacing noise from the metric.
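As a small illustration of why this matters for Devanagari, the sketch below compares two encodings of the same nukta letter. The `normalize_text` helper mirrors the three steps listed above; the sample strings are illustrative, not taken from your dataset.

```python
import re
import unicodedata


def normalize_text(text: str) -> str:
    """NFC-normalize, trim, and collapse internal whitespace runs to one space."""
    text = unicodedata.normalize("NFC", text)
    return re.sub(r"\s+", " ", text.strip())


# The same visible letter (qa) stored in two different ways:
precomposed = "\u0958"         # DEVANAGARI LETTER QA as one code point
decomposed = "\u0915\u093C"    # KA followed by the combining NUKTA sign

print(precomposed == decomposed)                                   # False
print(normalize_text(precomposed) == normalize_text(decomposed))   # True
print(repr(normalize_text("  a \t b\n")))                          # 'a b'
```

Without normalization, a prediction and a reference that render identically can still be charged an edit, which adds noise that has nothing to do with recognition quality.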

6. Check whether max_length=64 is truncating your labels

In your dataset class, you set max_length=64.

That may be too small for some Hindi lines. If targets are being truncated, the model can never predict the full reference correctly, and your metrics are capped by preprocessing rather than training.

Before the next run, print:

  • max tokenized label length
  • 95th percentile tokenized length
  • number of samples hitting max_length

If many lines are hitting 64, increase it.

7. Split metrics by printed versus handwritten

This is essential in your case.

Because your dataset mixes printed and handwritten Hindi lines, one combined CER can hide a lot. A model could improve strongly on printed lines and fail on handwritten lines, while the aggregate metric still looks “okay.”

So report:

  • overall CER/WER
  • printed-only CER/WER
  • handwritten-only CER/WER

That will tell you much more than a single global score.
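One way to get those three numbers is sketched below. It assumes you can attach a domain label such as "printed" or "handwritten" to each validation sample. That label, the function names, and the toy strings are all hypothetical and not part of your notebook.

```python
from collections import defaultdict


def edit_distance(hyp, ref):
    """Levenshtein distance between two sequences (strings or token lists)."""
    prev_row = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        row = [i]
        for j, r in enumerate(ref, start=1):
            row.append(min(prev_row[j] + 1, row[j - 1] + 1, prev_row[j - 1] + (h != r)))
        prev_row = row
    return prev_row[-1]


def error_rate_by_domain(predictions, references, domains, split=list):
    """Pooled error rate per domain plus an "overall" entry.

    split=list compares characters (CER); split=str.split compares words (WER).
    """
    edits, units = defaultdict(int), defaultdict(int)
    for pred, ref, domain in zip(predictions, references, domains):
        hyp_units, ref_units = split(pred), split(ref)
        distance = edit_distance(hyp_units, ref_units)
        for key in (domain, "overall"):
            edits[key] += distance
            units[key] += len(ref_units)
    return {key: edits[key] / units[key] if units[key] else 0.0 for key in edits}


# Toy data: both printed lines are perfect, the handwritten line is entirely wrong.
preds = ["ab cd", "ef gh", "zz zz"]
refs = ["ab cd", "ef gh", "ij kl"]
domains = ["printed", "printed", "handwritten"]

print(error_rate_by_domain(preds, refs, domains))                   # CER per domain
print(error_rate_by_domain(preds, refs, domains, split=str.split))  # WER per domain
```

In this toy case the overall CER is about 0.27 while the handwritten CER is 0.8, which is exactly the kind of gap a single aggregate score hides.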

The freeze schedule I recommend

This is the schedule I would actually use.

Stage A. Your first real-data training stage

Freeze:

  • entire encoder

Train:

  • enc_to_dec_proj
  • decoder cross-attention
  • lm_head
  • shared embeddings
  • decoder norms

This is the best first-stage setup for your current notebook.

Stage B. If Stage A improves but plateaus

Keep encoder frozen.

Also unfreeze:

  • the last 2 decoder blocks

That gives the text side more flexibility without exploding trainable parameters.
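A rough sketch of the Stage B step follows. It selects parameters by name, relying on the T5/mT5 naming pattern `decoder.block.<index>.`. The helper name is hypothetical, and the stand-in parameter objects exist only so the example runs without loading a model; with your wrapper you would pass something like `model.mt5.named_parameters()` instead.

```python
import re
from types import SimpleNamespace

# T5/mT5 parameter names look like "decoder.block.7.layer.0.SelfAttention.q.weight".
BLOCK_RE = re.compile(r"(?:^|\.)decoder\.block\.(\d+)\.")


def unfreeze_last_decoder_blocks(named_params, n_last=2):
    """Set requires_grad=True for parameters in the last `n_last` decoder blocks.

    `named_params` is an iterable of (name, parameter) pairs.
    Returns the names that were unfrozen.
    """
    tagged = []
    for name, param in named_params:
        match = BLOCK_RE.search(name)
        tagged.append((name, param, int(match.group(1)) if match else None))

    block_ids = [idx for _, _, idx in tagged if idx is not None]
    if not block_ids:
        return []
    cutoff = max(block_ids) - n_last + 1

    unfrozen = []
    for name, param, idx in tagged:
        if idx is not None and idx >= cutoff:
            param.requires_grad = True
            unfrozen.append(name)
    return unfrozen


# Stand-in parameters (assumption: 8 decoder blocks, as in a small T5-style decoder).
names = [f"decoder.block.{i}.layer.0.SelfAttention.q.weight" for i in range(8)]
names += ["encoder.block.7.layer.0.SelfAttention.q.weight", "lm_head.weight"]
fake_params = [(n, SimpleNamespace(requires_grad=False)) for n in names]

print(unfreeze_last_decoder_blocks(fake_params, n_last=2))
```

Parameters that become trainable after the optimizer was built also have to be registered with it, either by rebuilding the optimizer or via `optimizer.add_param_group(...)`. Otherwise they receive gradients but are never updated.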

Stage C. Only after 1000+ samples

Only then unfreeze:

  • the top 2 encoder blocks
  • at a much smaller LR, like 1e-5

Do not full-fine-tune the encoder now.

What I would not do

Do not freeze the decoder and train encoder only. That freezes exactly the side that needs to adapt to Hindi output and to the new cross-attention bridge. (Hugging Face)

Whether you should use 210 or 1000 next

My answer is:

  • keep the 210-sample experiment
  • but use it only as a training-recipe validation run
  • do not use it to choose the final model design

Once the trainer is fixed, if you can get below about CER 0.5 and the outputs look readable on both printed and handwritten lines, then I would move to 1000+ line crops immediately.

If, after fixing the trainer, CER still stays around 1.0, I would not scale yet. I would first try a parameter-efficient adaptation path, not full fine-tuning.

That recommendation is aligned with recent low-resource OCR work. Nayana uses LoRA to adapt OCR models across low-resource Indic languages, including Hindi, and the original LoRA paper explains why freezing most pretrained weights and training a small adapter can be more stable and much cheaper than full fine-tuning. (ACL Anthology)

What result would convince me that your pipeline is ready

For this specific proof-of-concept, I would call it “ready for larger datasets” if all of these are true:

  • validation CER is below 0.50
  • validation WER is below 0.80
  • outputs are readable and mostly non-repetitive
  • printed and handwritten subsets both improve
  • the best checkpoint is found early and can be restored
  • the result is stable across at least 2 different random splits

That last point matters because your current dataset is tiny. One split can flatter or punish you.
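A minimal sketch of seeded, reproducible splits is shown below. The 160/50 sizes are an assumption based on the numbers mentioned above, and `make_split` is a hypothetical helper rather than anything in your notebook.

```python
import random


def make_split(n_samples, n_val, seed):
    """Shuffle indices with a fixed seed and cut off a validation slice."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)   # local RNG, global random state untouched
    return indices[n_val:], indices[:n_val]


for seed in (0, 1):
    train_idx, val_idx = make_split(n_samples=210, n_val=50, seed=seed)
    print(f"seed={seed}: {len(train_idx)} train / {len(val_idx)} val")
```

You would then train once per seed and compare the best validation CER of each run. If one split lands below 0.50 and the other stays near 1.0, the result is not stable yet. Because the dataset mixes printed and handwritten lines, it may also be worth checking that each validation slice contains both.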

My honest bottom line

Your current case looks like this:

  • The architecture is no longer the main problem.
  • The trainer is now the main problem.
  • 210 samples is enough to validate the pipeline, but not enough to judge model quality.
  • Full fine-tuning ~400M parameters on 160 training lines is the wrong experiment.
  • Freezing the decoder and training the encoder only is the wrong direction for this model.
  • A staged frozen-encoder, decoder-side adaptation strategy is the correct next step. (Hugging Face)

So my advice is:

  1. keep the current architecture,
  2. fix the trainer,
  3. rerun the 210-line proof with CER-driven control, early stopping, and per-domain metrics,
  4. then move to 1000+ samples only after you get a stable result.

Here is the minimal patch I would make to your last two cells only.

This patch keeps your current model and dataset flow intact. It changes the trainer so that CER becomes the main control metric, generation runs every epoch, the frozen-encoder / decoder-side adaptation strategy stays in place, and AdamW uses a more standard weight-decay split for norm and bias parameters. That matches how warm-started vision-encoder-decoder hybrids typically need decoder-side fine-tuning, how mT5 starts decoding from pad_token_id, and how Hugging Face examples group AdamW parameters. (Hugging Face)

It also fixes the specific failure mode your current trainer has: loss can keep improving while free-generation OCR gets worse. CER is a character-level edit-distance metric where lower is better, and insertion-heavy outputs can drive it above 1.0, so it is the better signal for checkpointing and early stopping in your setup. (GitHub)

Replace Cell 14 with this

import os
import re
import random
import unicodedata
import torch
import torch.nn as nn
from jiwer import wer as compute_wer_score


def normalize_text(text: str) -> str:
    text = unicodedata.normalize("NFC", text)
    text = text.strip()
    text = re.sub(r"\s+", " ", text)
    return text


class HindiOCRTrainer:

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        tokenizer,
        device,
        output_dir,
        num_epochs=20,
        learning_rate=1e-4,
        grad_clip_norm=1.0,
        patience=5,
        min_delta=0.005,
        max_new_tokens=64,
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.tokenizer = tokenizer
        self.device = device
        self.output_dir = output_dir
        self.num_epochs = num_epochs
        self.grad_clip_norm = grad_clip_norm
        self.patience = patience
        self.min_delta = min_delta
        self.max_new_tokens = max_new_tokens

        os.makedirs(output_dir, exist_ok=True)

        # Restore dropout after single-sample overfit test
        dropout_rate = model.mt5.config.dropout_rate
        for m in model.modules():
            if isinstance(m, nn.Dropout):
                m.p = dropout_rate

        # -----------------------------
        # Freeze strategy: Stage A
        # -----------------------------
        # Freeze full encoder
        for p in model.encoder.parameters():
            p.requires_grad = False

        # Train decoder cross-attention + output-side params
        for name, p in model.mt5.named_parameters():
            p.requires_grad = (
                ("EncDecAttention" in name) or
                ("lm_head" in name) or
                ("shared" in name) or
                ("layer_norm" in name)
            )

        # Train encoder->decoder projection if present
        if model.enc_to_dec_proj is not None:
            for p in model.enc_to_dec_proj.parameters():
                p.requires_grad = True

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"Trainable params: {trainable:,} / {total:,}")

        # -----------------------------
        # Optimizer with param groups
        # -----------------------------
        bridge_params = []
        decay_params = []
        no_decay_params = []

        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue

            if "enc_to_dec_proj" in name or "EncDecAttention" in name:
                bridge_params.append(p)
            elif any(x in name for x in ["bias", "LayerNorm.weight", "layer_norm.weight", "shared"]):
                no_decay_params.append(p)
            else:
                decay_params.append(p)

        optimizer_groups = []
        if bridge_params:
            optimizer_groups.append({
                "params": bridge_params,
                "lr": 2e-4,
                "weight_decay": 0.01,
            })
        if decay_params:
            optimizer_groups.append({
                "params": decay_params,
                "lr": learning_rate,
                "weight_decay": 0.01,
            })
        if no_decay_params:
            optimizer_groups.append({
                "params": no_decay_params,
                "lr": learning_rate,
                "weight_decay": 0.0,
            })

        self.optimizer = torch.optim.AdamW(optimizer_groups)

        # CER is the control signal, not val_loss
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            factor=0.5,
            patience=2,
            min_lr=1e-6,
        )

        self.best_cer = float("inf")
        self.best_wer = float("inf")
        self.best_checkpoint_path = os.path.join(self.output_dir, "best_model.pt")
        self.training_log = []
        self.bad_epochs = 0

    # ============================================================
    # TRAIN
    # ============================================================
    def _train_one_epoch(self):
        self.model.train()
        total_loss = 0.0

        for batch in self.train_loader:
            pixel_values = batch["pixel_values"].to(self.device)
            labels = batch["labels"].to(self.device)

            self.optimizer.zero_grad()

            outputs = self.model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(self.train_loader)

    # ============================================================
    # VALIDATE
    # ============================================================
    def _validate(self, epoch):
        self.model.eval()
        total_loss = 0.0
        all_predictions = []
        all_ground_truths = []

        with torch.no_grad():
            for batch in self.val_loader:
                pixel_values = batch["pixel_values"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = self.model(pixel_values=pixel_values, labels=labels)
                total_loss += outputs.loss.item()

                generated_ids = self.model.generate(
                    pixel_values=pixel_values,
                    max_new_tokens=self.max_new_tokens,
                    num_beams=1,
                    do_sample=False,
                )

                clean_labels = labels.clone()
                clean_labels[clean_labels == -100] = self.tokenizer.pad_token_id

                pred_texts = self.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True
                )
                gt_texts = self.tokenizer.batch_decode(
                    clean_labels, skip_special_tokens=True
                )

                pred_texts = [normalize_text(x) for x in pred_texts]
                gt_texts = [normalize_text(x) for x in gt_texts]

                all_predictions.extend(pred_texts)
                all_ground_truths.extend(gt_texts)

        avg_val_loss = total_loss / len(self.val_loader)
        cer = self._compute_cer(all_predictions, all_ground_truths)
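        # jiwer takes (reference, hypothesis) in that order. Recent jiwer versions
        # raise on empty reference strings, so check that no ground-truth line
        # normalizes to "".
        wer = compute_wer_score(all_ground_truths, all_predictions)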
        exact_match = sum(p == g for p, g in zip(all_predictions, all_ground_truths)) / max(1, len(all_predictions))

        return avg_val_loss, cer, wer, exact_match, all_predictions, all_ground_truths

    # ============================================================
    # CER
    # ============================================================
    def _compute_cer(self, predictions, ground_truths):
        total_edits = 0
        total_chars = 0

        for pred, gt in zip(predictions, ground_truths):
            total_edits += self._edit_distance(pred, gt)
            total_chars += len(gt)

        return 0.0 if total_chars == 0 else total_edits / total_chars

    def _edit_distance(self, s1, s2):
        m, n = len(s1), len(s2)
        dp = list(range(n + 1))
        for i in range(1, m + 1):
            prev = dp[0]
            dp[0] = i
            for j in range(1, n + 1):
                temp = dp[j]
                if s1[i - 1] == s2[j - 1]:
                    dp[j] = prev
                else:
                    dp[j] = 1 + min(prev, dp[j], dp[j - 1])
                prev = temp
        return dp[n]

    # ============================================================
    # CHECKPOINT
    # ============================================================
    def _save_checkpoint(self, epoch, is_best=False):
        checkpoint = {
            "epoch": epoch,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "best_cer": self.best_cer,
            "best_wer": self.best_wer,
            "training_log": self.training_log,
        }

        if is_best:
            torch.save(checkpoint, self.best_checkpoint_path)
            print(f"  ✅ Best model saved → {self.best_checkpoint_path}")

        if epoch % 5 == 0:
            periodic_path = os.path.join(self.output_dir, f"checkpoint_epoch_{epoch}.pt")
            torch.save(checkpoint, periodic_path)
            print(f"  💾 Periodic checkpoint saved → {periodic_path}")

    # ============================================================
    # MAIN LOOP
    # ============================================================
    def train(self):
        print("=" * 60)
        print(f"Starting training for {self.num_epochs} epochs")
        print("Generation metrics computed every epoch")
        print("Primary control metric: CER")
        print("Target CER for PoC: < 0.50")
        print("=" * 60)

        for epoch in range(1, self.num_epochs + 1):
            avg_train_loss = self._train_one_epoch()
            avg_val_loss, cer, wer, exact_match, predictions, ground_truths = self._validate(epoch)

            # Scheduler follows CER, not val_loss
            self.scheduler.step(cer)
            current_lr = self.optimizer.param_groups[0]["lr"]

            log_entry = {
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "cer": cer,
                "wer": wer,
                "exact_match": exact_match,
                "lr": current_lr,
            }
            self.training_log.append(log_entry)

            print(
                f"Epoch {epoch:>2}/{self.num_epochs} | "
                f"Train Loss: {avg_train_loss:.4f} | "
                f"Val Loss: {avg_val_loss:.4f} | "
                f"CER: {cer:.4f} | "
                f"WER: {wer:.4f} | "
                f"EM: {exact_match:.3f} | "
                f"LR: {current_lr:.2e}"
            )

            # Qualitative preview
            print("  Sample predictions:")
            indices = random.sample(range(len(predictions)), min(3, len(predictions)))
            for i in indices:
                print(f"    GT:   '{ground_truths[i]}'")
                print(f"    Pred: '{predictions[i]}'")

            improved = cer < (self.best_cer - self.min_delta)

            if improved:
                self.best_cer = cer
                self.best_wer = wer
                self.bad_epochs = 0
                print(f"  🎯 New best CER: {cer:.4f}")
                if cer < 0.50:
                    print("  ✅ TARGET REACHED: CER < 0.50")
            else:
                self.bad_epochs += 1
                print(f"  No CER improvement. Patience: {self.bad_epochs}/{self.patience}")

            # Save once per epoch. Calling _save_checkpoint twice would write the
            # periodic checkpoint twice whenever a best epoch is also a multiple of 5.
            self._save_checkpoint(epoch, is_best=improved)

            if self.bad_epochs >= self.patience:
                print("  ⏹️ Early stopping triggered.")
                break

        print("\n" + "=" * 60)
        print("Training complete.")
        print(f"Best CER: {self.best_cer:.4f}")
        print(f"Best WER: {self.best_wer:.4f}")

        if os.path.exists(self.best_checkpoint_path):
            checkpoint = torch.load(self.best_checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state"])
            print("✅ Restored best checkpoint into model.")

        return self.training_log

    # ============================================================
    # OPTIONAL RESUME
    # ============================================================
    def load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state"])
        self.best_cer = checkpoint["best_cer"]
        self.best_wer = checkpoint["best_wer"]
        self.training_log = checkpoint.get("training_log", [])
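        start_epoch = checkpoint["epoch"] + 1
        print(f"✅ Resumed from epoch {checkpoint['epoch']} | Best CER so far: {self.best_cer:.4f}")
        # Note: train() as written always loops from epoch 1, and the early-stopping
        # counter (bad_epochs) is not stored in the checkpoint, so start_epoch is
        # informational unless you adapt train() to use it.
        return start_epoch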

Replace Cell 15 with this

trainer = HindiOCRTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    tokenizer=tokenizer,
    device=device,
    output_dir="/content/drive/MyDrive/trocr-checkpoints",
    num_epochs=20,
    learning_rate=1e-4,
    grad_clip_norm=1.0,
    patience=5,
    min_delta=0.005,
    max_new_tokens=64,
)

training_log = trainer.train()

What changed, and why

The smallest important changes are these:

  • CER/WER now run every epoch, so the best OCR checkpoint cannot be skipped between 5-epoch intervals.
  • Scheduler now follows CER, not validation loss.
  • Early stopping stops the run once CER stops improving.
  • Parameter groups give the bridge and cross-attention a slightly higher LR, while excluding norm and bias-style parameters from weight decay in the usual Hugging Face pattern. (GitHub)
  • Text normalization is applied before metrics, which matters for Unicode-heavy scripts like Devanagari.
  • The class still keeps your current frozen encoder + decoder-side adaptation strategy, which is the right first stage for a warm-started vision-text hybrid where decoder-side cross-attention is the newly learned bridge. (Hugging Face)

One small check before you rerun

Your dataset cell still uses max_length=64 for labels. Before the next run, quickly inspect how many targets are hitting that cap. If a noticeable fraction of your Hindi lines are truncated at 64 tokens, raise it first, and raise max_new_tokens in the trainer to match so that generation is not cut off at the old limit. Otherwise training can improve while truncated references keep the best achievable CER artificially poor.

Use this once, anywhere after dataset creation:

lengths = []
for i in range(len(train_dataset)):
    ids = train_dataset[i]["labels"]
    valid = (ids != -100).sum().item()
    lengths.append(valid)

print("max label tokens:", max(lengths))
print("p95 label tokens:", sorted(lengths)[int(0.95 * len(lengths))])
print("num hitting 64:", sum(x >= 64 for x in lengths))

What I would expect after this patch

If the pipeline is healthy, you should see this pattern:

  • training loss goes down,
  • validation CER improves in the first several epochs,
  • the best checkpoint appears before the final epoch,
  • and early stopping restores that checkpoint instead of letting later degradation define the run.

If after this patch CER still stays around 1.0 or worse, I would not jump to “different tokenizer” or “different architecture.” I would first test one of these two moves:

  • keep the same freeze strategy and move to 1000+ line crops, or
  • keep the same data and switch to LoRA-style adaptation on the decoder-side trainable blocks, which is a common low-resource adaptation path in recent Indic OCR work. (ACL Anthology)
