{
  "$type": "site.standard.document",
  "bskyPostRef": {
    "cid": "bafyreicpaitftp2g4nngo5vepqqszy26fsk3isolgaiek6rs2vmfc7l56e",
    "uri": "at://did:plc:pgryn3ephfd2xgft23qokfzt/app.bsky.feed.post/3miflmuahptg2"
  },
  "path": "/t/looking-for-guidance-trying-to-create-a-model-with-trocrs-encoder-googles-mt5-multilingual-decoder-but-model-fails-to-overfit-on-a-single-data-sample/174634#post_4",
  "publishedAt": "2026-04-01T00:52:43.000Z",
  "site": "https://discuss.huggingface.co",
  "tags": [
    "Hugging Face",
    "GitHub",
    "ACL Anthology"
  ],
  "textContent": "Well, I’m not sure if the following patch will work as-is, but for now:\n\n* * *\n\nFor your case, the right conclusion is:\n\n**210 samples is a valid proof-of-concept dataset, but only if you treat it as a pipeline test, not a quality benchmark.**\nYour current notebook is close enough to keep. The main problem is no longer encoder-decoder wiring. The main problem is the **training control logic** and the fact that **160 training lines is far too little data for full fine-tuning of hundreds of millions of parameters**. In warm-started encoder-decoder hybrids, the cross-attention bridge can be randomly initialized and needs downstream fine-tuning, and in mT5 the decoder starts from `pad_token_id`, which matches the direction you already fixed in your notebook. (Hugging Face)\n\n## My direct answer to your three questions\n\n### 1. Is 210 samples okay?\n\nYes, **for a proof that the pipeline is capable of learning**.\nNo, **for judging final model quality or choosing the best architecture**.\n\nA 210-line dataset can answer a narrow but useful question:\n\n> “Can this TrOCR-encoder plus mT5-decoder pipeline learn stable Hindi line recognition on real data without collapsing?”\n\nThat is a good proof-of-concept question. It is much narrower than “Is this good OCR?” and that is the correct way to use 210 samples. Public TrOCR documentation says the raw model is intended for **single text-line images** , so your use of line crops is aligned with the model family’s intended use. At the same time, public Indic OCR resources are much larger. IIIT-INDIC-HW-WORDS reports **872K handwritten instances** across 8 Indic scripts, and recent low-resource Indic OCR work explicitly leans on synthetic data or parameter-efficient adaptation because small real datasets are not enough by themselves. 
(Hugging Face)\n\nSo the right framing is:\n\n  * **210 lines** = enough to validate that the training recipe is sane.\n  * **1000+ lines** = much better for deciding whether the model scales.\n  * **many thousands** = where quality conclusions start to matter.\n\n\n\n### 2. What CER and WER should you aim for?\n\nFor your current dataset size, use **CER as the primary metric** and **WER as a secondary metric**.\n\nThat is because CER is character-level edit distance and is smoother on tiny validation sets, while WER is harsher and noisier on short lines. Hugging Face’s CER implementation defines CER as character-level edit distance normalized by reference length, and WER is the word-level analogue based on substitutions, deletions, and insertions. Both are “lower is better,” and both can exceed `1.0` when insertions are large. So “CER below 1.0” is not a meaningful success threshold. A CER around `0.98` is still very poor. (GitHub)\n\nThese are the thresholds I would use for **your proof-of-concept** , and these are **engineering thresholds** , not official published cutoffs:\n\n**Minimum green light**\n\n  * validation **CER < 0.50**\n  * validation **WER < 0.80**\n  * predictions are readable\n  * no repeated-character collapse\n  * both printed and handwritten lines improve\n\n\n\n**Better green light**\n\n  * validation **CER around 0.25 to 0.40**\n  * validation **WER around 0.45 to 0.70**\n\n\n\n**Strong green light**\n\n  * validation **CER < 0.20 to 0.25**\n  * validation **WER < 0.40 to 0.50**\n\n\n\nFor 210 mixed-domain samples, I would call the pipeline “ready to scale” once you can **reliably beat CER 0.5** and **show visibly readable predictions** on both printed and handwritten validation lines. The important word is **reliably**. One lucky split is not enough.\n\n### 3. 
Which layers should you freeze, and how should the training class be designed?\n\nFor your exact architecture, **freezing the decoder and training only the encoder is the wrong direction**.\n\nThe decoder side contains:\n\n  * the Hindi text generation behavior,\n  * the new cross-attention bridge that connects image features to text generation,\n  * and the autoregressive dynamics that are currently causing metric instability.\n\n\n\nHugging Face’s encoder-decoder docs explicitly say cross-attention may be randomly initialized in these hybrids and must be fine-tuned downstream. That means the most important adaptation is usually **decoder-side cross-attention and output-side behavior** , not encoder-only retraining. (Hugging Face)\n\nSo for **your notebook** , the correct first training stage is:\n\n  * freeze the **entire visual encoder**\n  * train `enc_to_dec_proj` if it exists\n  * train decoder **cross-attention**\n  * train `lm_head`\n  * train shared embeddings\n  * optionally train decoder layer norms\n\n\n\nThat is the right first-stage design for your current notebook.\n\n## My thoughts after checking your training cells\n\nYour notebook is now in a much better place than before. These are the good parts:\n\n  * you moved to `trocr-small-stage1` plus `mt5-small`\n  * the custom wrapper is now sane enough to test\n  * the overfit test already proved the bridge can learn\n  * the real trainer freezes the encoder and trains decoder-side bridge/output parameters\n\n\n\nThat is the right direction.\n\nThe weak part is the **trainer design** , not the architecture.\n\n### The biggest trainer problem\n\nRight now, in your notebook:\n\n  * **validation loss** is computed every epoch,\n  * **CER/WER** are only computed every 5 epochs,\n  * the scheduler follows **validation loss** ,\n  * but the best-checkpoint logic follows **CER**.\n\n\n\nThat is a mismatch.\n\nFor seq2seq OCR, teacher-forced loss and free-generation CER can move in different directions. 
The model can keep lowering validation loss while actual OCR output gets worse. That is especially common in tiny-data autoregressive setups. So a scheduler driven by `val_loss` and checkpointing driven by `CER` can easily train past the best OCR model.\n\nThat is the main reason I do not trust the 50-epoch result as a true judgment of the architecture.\n\n## What I think your current CER curve actually means\n\nYou said:\n\n  * CER started around `1.2` to `1.3`\n  * improved to about `0.98`\n  * then worsened to about `1.5`\n\n\n\nThat pattern usually means:\n\n  * the model is learning something at first,\n  * then overfitting or decoding drift sets in,\n  * and later epochs add insertions or repetitive garbage.\n\n\n\nSince CER is edit-distance based, late insertions can easily push it above `1.0`. So this is not “almost good but not quite.” It is “still poor overall, with a brief early improvement that was not preserved.” (GitHub)\n\nThe good news is that this pattern usually points to **training-control problems** , not “your architecture cannot work.”\n\n## The exact changes I would make now\n\n## 1. Compute CER and WER every epoch\n\nOn a validation set of about 45 to 50 lines, the extra compute is small. The benefit is large.\n\nIn your trainer, set:\n\n\n    generate_every_n_epochs = 1\n\n\nThat alone makes your best-epoch detection much more trustworthy.\n\n## 2. Use CER as the one metric that controls training\n\nUse **validation CER** for all three:\n\n  * scheduler stepping\n  * best-checkpoint saving\n  * early stopping\n\n\n\nDo not split those across `val_loss` and CER.\n\nFor this dataset size and task, CER is the best control metric. WER is still useful, but mainly as a reporting metric.\n\n## 3. Add early stopping\n\nDo not run 50 fixed epochs on 160 training lines.\n\nUse:\n\n  * `num_epochs = 20` or `25`\n  * `patience = 5`\n  * `min_delta = 0.005` on CER\n\n\n\nThe best checkpoint will probably appear earlier than epoch 50. 
Right now your notebook is not designed to stop there.\n\n## 4. Use parameter groups, not one flat AdamW group\n\nIn your notebook, all trainable parameters currently use one LR and one weight decay. That is too blunt.\n\nHugging Face’s training docs note that biases and LayerNorm parameters are usually excluded from weight decay. I would go one step further and also give the bridge and cross-attention a slightly higher learning rate than the rest of the decoder-side trainable weights. (Hugging Face)\n\nA good split is:\n\n  * **bridge and cross-attention** : LR `2e-4`, weight decay `0.01`\n  * **lm_head and other trainable decoder weights** : LR `1e-4`, weight decay `0.01`\n  * **biases, norms, shared embeddings** : LR `1e-4`, weight decay `0.0`\n\n\n\n## 5. Normalize text before CER and WER\n\nThis matters more for Hindi than people expect.\n\nBefore computing metrics, normalize both prediction and reference with:\n\n  * Unicode NFC normalization\n  * `.strip()`\n  * whitespace collapse\n\n\n\nThat removes avoidable Unicode and spacing noise from the metric.\n\n## 6. Check whether `max_length=64` is truncating your labels\n\nIn your dataset class, you set `max_length=64`.\n\nThat may be too small for some Hindi lines. If targets are being truncated, the model can never predict the full reference correctly, and your metrics are capped by preprocessing rather than training.\n\nBefore the next run, print:\n\n  * max tokenized label length\n  * 95th percentile tokenized length\n  * number of samples hitting `max_length`\n\n\n\nIf many lines are hitting 64, increase it.\n\n## 7. Split metrics by printed versus handwritten\n\nThis is essential in your case.\n\nBecause your dataset mixes printed and handwritten Hindi lines, one combined CER can hide a lot. 
A model could improve strongly on printed lines and fail on handwritten lines, while the aggregate metric still looks “okay.”\n\nSo report:\n\n  * overall CER/WER\n  * printed-only CER/WER\n  * handwritten-only CER/WER\n\n\n\nThat will tell you much more than a single global score.\n\n## The freeze schedule I recommend\n\nThis is the schedule I would actually use.\n\n### Stage A. Your first real-data training stage\n\nFreeze:\n\n  * entire encoder\n\n\n\nTrain:\n\n  * `enc_to_dec_proj`\n  * decoder cross-attention\n  * `lm_head`\n  * shared embeddings\n  * decoder norms\n\n\n\nThis is the best first-stage setup for your current notebook.\n\n### Stage B. If Stage A improves but plateaus\n\nKeep encoder frozen.\n\nAlso unfreeze:\n\n  * the **last 2 decoder blocks**\n\n\n\nThat gives the text side more flexibility without exploding trainable parameters.\n\n### Stage C. Only after 1000+ samples\n\nOnly then unfreeze:\n\n  * the **top 2 encoder blocks**\n  * at a much smaller LR, like `1e-5`\n\n\n\nDo **not** full-fine-tune the encoder now.\n\n### What I would not do\n\nDo not freeze the decoder and train encoder only. That freezes exactly the side that needs to adapt to Hindi output and to the new cross-attention bridge. (Hugging Face)\n\n## Whether you should use 210 or 1000 next\n\nMy answer is:\n\n  * keep the **210-sample experiment**\n  * but use it only as a **training-recipe validation run**\n  * do **not** use it to choose the final model design\n\n\n\nOnce the trainer is fixed, if you can get below about `CER 0.5` and the outputs look readable on both printed and handwritten lines, then I would move to **1000+ line crops** immediately.\n\nIf, after fixing the trainer, CER still stays around `1.0`, I would not scale yet. I would first try a **parameter-efficient adaptation path** , not full fine-tuning.\n\nThat recommendation is aligned with recent low-resource OCR work. 
Nayana uses **LoRA** to adapt OCR models across low-resource Indic languages, including Hindi, and the original LoRA paper explains why freezing most pretrained weights and training a small adapter can be more stable and much cheaper than full fine-tuning. (ACL Anthology)\n\n## What result would convince me that your pipeline is ready\n\nFor this specific proof-of-concept, I would call it “ready for larger datasets” if all of these are true:\n\n  * validation **CER is below 0.50**\n  * validation **WER is below 0.80**\n  * outputs are readable and mostly non-repetitive\n  * printed and handwritten subsets both improve\n  * the best checkpoint is found early and can be restored\n  * the result is stable across at least **2 different random splits**\n\n\n\nThat last point matters because your current dataset is tiny. One split can flatter or punish you.\n\n## My honest bottom line\n\nYour current case looks like this:\n\n  * **The architecture is no longer the main problem.**\n  * **The trainer is now the main problem.**\n  * **210 samples is enough to validate the pipeline, but not enough to judge model quality.**\n  * **Full fine-tuning of ~400M parameters on 160 training lines is the wrong experiment.**\n  * **Freezing the decoder and training the encoder only is the wrong direction for this model.**\n  * **A staged frozen-encoder, decoder-side adaptation strategy is the correct next step.** (Hugging Face)\n\n\n\nSo my advice is:\n\n  1. keep the current architecture,\n  2. fix the trainer,\n  3. rerun the 210-line proof with CER-driven control, early stopping, and per-domain metrics,\n  4. then move to 1000+ samples only after you get a stable result.\n\n\n\n* * *\n\nHere is the **minimal patch** I would make to your **last two cells only**.\n\nThis patch keeps your current model and dataset flow intact. 
It changes the trainer so that **CER becomes the main control metric**, generation runs **every epoch**, the **frozen-encoder / decoder-side adaptation** strategy stays in place, and AdamW uses a more standard **weight-decay split** for norm and bias parameters. That matches how warm-started vision-encoder-decoder hybrids typically need decoder-side fine-tuning, how mT5 expects decoding to start from `pad_token_id`, and how Hugging Face examples group AdamW parameters. (Hugging Face)\n\nIt also fixes the specific failure mode your current trainer has: loss can keep improving while free-generation OCR gets worse. CER is a character-level edit-distance metric where **lower is better**, and insertion-heavy outputs can drive it above `1.0`, so it is the better signal for checkpointing and early stopping in your setup. (GitHub)\n\n## Replace Cell 14 with this\n\n\n    import os\n    import re\n    import random\n    import unicodedata\n    import torch\n    import torch.nn as nn\n    from jiwer import wer as compute_wer_score\n\n\n    def normalize_text(text: str) -> str:\n        text = unicodedata.normalize(\"NFC\", text)\n        text = text.strip()\n        text = re.sub(r\"\\s+\", \" \", text)\n        return text\n\n\n    class HindiOCRTrainer:\n\n        def __init__(\n            self,\n            model,\n            train_loader,\n            val_loader,\n            tokenizer,\n            device,\n            output_dir,\n            num_epochs=20,\n            learning_rate=1e-4,\n            grad_clip_norm=1.0,\n            patience=5,\n            min_delta=0.005,\n            max_new_tokens=64,\n        ):\n            self.model = model\n            self.train_loader = train_loader\n            self.val_loader = val_loader\n            self.tokenizer = tokenizer\n            self.device = device\n            self.output_dir = output_dir\n            self.num_epochs = num_epochs\n            self.grad_clip_norm = grad_clip_norm\n            self.patience 
= patience\n            self.min_delta = min_delta\n            self.max_new_tokens = max_new_tokens\n\n            os.makedirs(output_dir, exist_ok=True)\n\n            # Restore dropout after single-sample overfit test\n            dropout_rate = model.mt5.config.dropout_rate\n            for m in model.modules():\n                if isinstance(m, nn.Dropout):\n                    m.p = dropout_rate\n\n            # -----------------------------\n            # Freeze strategy: Stage A\n            # -----------------------------\n            # Freeze full encoder\n            for p in model.encoder.parameters():\n                p.requires_grad = False\n\n            # Train decoder cross-attention + output-side params\n            for name, p in model.mt5.named_parameters():\n                p.requires_grad = (\n                    (\"EncDecAttention\" in name) or\n                    (\"lm_head\" in name) or\n                    (\"shared\" in name) or\n                    (\"layer_norm\" in name)\n                )\n\n            # Train encoder->decoder projection if present\n            if model.enc_to_dec_proj is not None:\n                for p in model.enc_to_dec_proj.parameters():\n                    p.requires_grad = True\n\n            trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n            total = sum(p.numel() for p in model.parameters())\n            print(f\"Trainable params: {trainable:,} / {total:,}\")\n\n            # -----------------------------\n            # Optimizer with param groups\n            # -----------------------------\n            bridge_params = []\n            decay_params = []\n            no_decay_params = []\n\n            for name, p in model.named_parameters():\n                if not p.requires_grad:\n                    continue\n\n                if \"enc_to_dec_proj\" in name or \"EncDecAttention\" in name:\n                    bridge_params.append(p)\n                elif any(x in name 
for x in [\"bias\", \"LayerNorm.weight\", \"layer_norm.weight\", \"shared\"]):\n                    no_decay_params.append(p)\n                else:\n                    decay_params.append(p)\n\n            optimizer_groups = []\n            if bridge_params:\n                optimizer_groups.append({\n                    \"params\": bridge_params,\n                    \"lr\": 2e-4,\n                    \"weight_decay\": 0.01,\n                })\n            if decay_params:\n                optimizer_groups.append({\n                    \"params\": decay_params,\n                    \"lr\": learning_rate,\n                    \"weight_decay\": 0.01,\n                })\n            if no_decay_params:\n                optimizer_groups.append({\n                    \"params\": no_decay_params,\n                    \"lr\": learning_rate,\n                    \"weight_decay\": 0.0,\n                })\n\n            self.optimizer = torch.optim.AdamW(optimizer_groups)\n\n            # CER is the control signal, not val_loss\n            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n                self.optimizer,\n                mode=\"min\",\n                factor=0.5,\n                patience=2,\n                min_lr=1e-6,\n            )\n\n            self.best_cer = float(\"inf\")\n            self.best_wer = float(\"inf\")\n            self.best_checkpoint_path = os.path.join(self.output_dir, \"best_model.pt\")\n            self.training_log = []\n            self.bad_epochs = 0\n\n        # ============================================================\n        # TRAIN\n        # ============================================================\n        def _train_one_epoch(self):\n            self.model.train()\n            total_loss = 0.0\n\n            for batch in self.train_loader:\n                pixel_values = batch[\"pixel_values\"].to(self.device)\n                labels = batch[\"labels\"].to(self.device)\n\n                
self.optimizer.zero_grad()\n\n                outputs = self.model(pixel_values=pixel_values, labels=labels)\n                loss = outputs.loss\n                loss.backward()\n\n                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)\n                self.optimizer.step()\n\n                total_loss += loss.item()\n\n            return total_loss / len(self.train_loader)\n\n        # ============================================================\n        # VALIDATE\n        # ============================================================\n        def _validate(self, epoch):\n            self.model.eval()\n            total_loss = 0.0\n            all_predictions = []\n            all_ground_truths = []\n\n            with torch.no_grad():\n                for batch in self.val_loader:\n                    pixel_values = batch[\"pixel_values\"].to(self.device)\n                    labels = batch[\"labels\"].to(self.device)\n\n                    outputs = self.model(pixel_values=pixel_values, labels=labels)\n                    total_loss += outputs.loss.item()\n\n                    generated_ids = self.model.generate(\n                        pixel_values=pixel_values,\n                        max_new_tokens=self.max_new_tokens,\n                        num_beams=1,\n                        do_sample=False,\n                    )\n\n                    clean_labels = labels.clone()\n                    clean_labels[clean_labels == -100] = self.tokenizer.pad_token_id\n\n                    pred_texts = self.tokenizer.batch_decode(\n                        generated_ids, skip_special_tokens=True\n                    )\n                    gt_texts = self.tokenizer.batch_decode(\n                        clean_labels, skip_special_tokens=True\n                    )\n\n                    pred_texts = [normalize_text(x) for x in pred_texts]\n                    gt_texts = [normalize_text(x) for x in gt_texts]\n\n                   
 all_predictions.extend(pred_texts)\n                    all_ground_truths.extend(gt_texts)\n\n            avg_val_loss = total_loss / len(self.val_loader)\n            cer = self._compute_cer(all_predictions, all_ground_truths)\n            wer = compute_wer_score(all_ground_truths, all_predictions)\n            exact_match = sum(p == g for p, g in zip(all_predictions, all_ground_truths)) / max(1, len(all_predictions))\n\n            return avg_val_loss, cer, wer, exact_match, all_predictions, all_ground_truths\n\n        # ============================================================\n        # CER\n        # ============================================================\n        def _compute_cer(self, predictions, ground_truths):\n            total_edits = 0\n            total_chars = 0\n\n            for pred, gt in zip(predictions, ground_truths):\n                total_edits += self._edit_distance(pred, gt)\n                total_chars += len(gt)\n\n            return 0.0 if total_chars == 0 else total_edits / total_chars\n\n        def _edit_distance(self, s1, s2):\n            m, n = len(s1), len(s2)\n            dp = list(range(n + 1))\n            for i in range(1, m + 1):\n                prev = dp[0]\n                dp[0] = i\n                for j in range(1, n + 1):\n                    temp = dp[j]\n                    if s1[i - 1] == s2[j - 1]:\n                        dp[j] = prev\n                    else:\n                        dp[j] = 1 + min(prev, dp[j], dp[j - 1])\n                    prev = temp\n            return dp[n]\n\n        # ============================================================\n        # CHECKPOINT\n        # ============================================================\n        def _save_checkpoint(self, epoch, is_best=False):\n            checkpoint = {\n                \"epoch\": epoch,\n                \"model_state\": self.model.state_dict(),\n                \"optimizer_state\": self.optimizer.state_dict(),\n             
   \"scheduler_state\": self.scheduler.state_dict(),\n                \"best_cer\": self.best_cer,\n                \"best_wer\": self.best_wer,\n                \"training_log\": self.training_log,\n            }\n\n            if is_best:\n                torch.save(checkpoint, self.best_checkpoint_path)\n                print(f\"  ✅ Best model saved → {self.best_checkpoint_path}\")\n\n            if epoch % 5 == 0:\n                periodic_path = os.path.join(self.output_dir, f\"checkpoint_epoch_{epoch}.pt\")\n                torch.save(checkpoint, periodic_path)\n                print(f\"  💾 Periodic checkpoint saved → {periodic_path}\")\n\n        # ============================================================\n        # MAIN LOOP\n        # ============================================================\n        def train(self):\n            print(\"=\" * 60)\n            print(f\"Starting training for {self.num_epochs} epochs\")\n            print(\"Generation metrics computed every epoch\")\n            print(\"Primary control metric: CER\")\n            print(\"Target CER for PoC: < 0.50\")\n            print(\"=\" * 60)\n\n            for epoch in range(1, self.num_epochs + 1):\n                avg_train_loss = self._train_one_epoch()\n                avg_val_loss, cer, wer, exact_match, predictions, ground_truths = self._validate(epoch)\n\n                # Scheduler follows CER, not val_loss\n                self.scheduler.step(cer)\n                current_lr = self.optimizer.param_groups[0][\"lr\"]\n\n                log_entry = {\n                    \"epoch\": epoch,\n                    \"train_loss\": avg_train_loss,\n                    \"val_loss\": avg_val_loss,\n                    \"cer\": cer,\n                    \"wer\": wer,\n                    \"exact_match\": exact_match,\n                    \"lr\": current_lr,\n                }\n                self.training_log.append(log_entry)\n\n                print(\n                    f\"Epoch 
{epoch:>2}/{self.num_epochs} | \"\n                    f\"Train Loss: {avg_train_loss:.4f} | \"\n                    f\"Val Loss: {avg_val_loss:.4f} | \"\n                    f\"CER: {cer:.4f} | \"\n                    f\"WER: {wer:.4f} | \"\n                    f\"EM: {exact_match:.3f} | \"\n                    f\"LR: {current_lr:.2e}\"\n                )\n\n                # Qualitative preview\n                print(\"  Sample predictions:\")\n                indices = random.sample(range(len(predictions)), min(3, len(predictions)))\n                for i in indices:\n                    print(f\"    GT:   '{ground_truths[i]}'\")\n                    print(f\"    Pred: '{predictions[i]}'\")\n\n                improved = cer < (self.best_cer - self.min_delta)\n\n                if improved:\n                    self.best_cer = cer\n                    self.best_wer = wer\n                    self.bad_epochs = 0\n                    self._save_checkpoint(epoch, is_best=True)\n                    print(f\"  🎯 New best CER: {cer:.4f}\")\n                    if cer < 0.50:\n                        print(\"  ✅ TARGET REACHED: CER < 0.50\")\n                else:\n                    self.bad_epochs += 1\n                    print(f\"  No CER improvement. 
Patience: {self.bad_epochs}/{self.patience}\")\n\n                # Periodic save; skipped when this epoch was already saved as best,\n                # because that call also writes the every-5-epochs checkpoint\n                if not improved:\n                    self._save_checkpoint(epoch, is_best=False)\n\n                if self.bad_epochs >= self.patience:\n                    print(\"  ⏹️ Early stopping triggered.\")\n                    break\n\n            print(\"\\n\" + \"=\" * 60)\n            print(\"Training complete.\")\n            print(f\"Best CER: {self.best_cer:.4f}\")\n            print(f\"Best WER: {self.best_wer:.4f}\")\n\n            if os.path.exists(self.best_checkpoint_path):\n                checkpoint = torch.load(self.best_checkpoint_path, map_location=self.device)\n                self.model.load_state_dict(checkpoint[\"model_state\"])\n                print(\"✅ Restored best checkpoint into model.\")\n\n            return self.training_log\n\n        # ============================================================\n        # OPTIONAL RESUME\n        # ============================================================\n        def load_checkpoint(self, checkpoint_path):\n            checkpoint = torch.load(checkpoint_path, map_location=self.device)\n            self.model.load_state_dict(checkpoint[\"model_state\"])\n            self.optimizer.load_state_dict(checkpoint[\"optimizer_state\"])\n            self.scheduler.load_state_dict(checkpoint[\"scheduler_state\"])\n            self.best_cer = checkpoint[\"best_cer\"]\n            self.best_wer = checkpoint[\"best_wer\"]\n            self.training_log = checkpoint.get(\"training_log\", [])\n            start_epoch = checkpoint[\"epoch\"] + 1\n            print(f\"✅ Resumed from epoch {checkpoint['epoch']} | Best CER so far: {self.best_cer:.4f}\")\n            return start_epoch\n\n\n## Replace Cell 15 with this\n\n\n    trainer = HindiOCRTrainer(\n        model=model,\n        train_loader=train_loader,\n        val_loader=val_loader,\n        tokenizer=tokenizer,\n        device=device,\n        output_dir=\"/content/drive/MyDrive/trocr-checkpoints\",\n        num_epochs=20,\n        
learning_rate=1e-4,\n        grad_clip_norm=1.0,\n        patience=5,\n        min_delta=0.005,\n        max_new_tokens=64,\n    )\n\n    training_log = trainer.train()\n\n\n## What changed, and why\n\nThe smallest important changes are these:\n\n  * **CER/WER now run every epoch**, so the best OCR checkpoint can no longer fall between two 5-epoch evaluations and be missed.\n  * **Scheduler now follows CER**, not validation loss.\n  * **Early stopping** stops the run once CER stops improving.\n  * **Parameter groups** give the bridge and cross-attention a slightly higher LR, while excluding norm and bias-style parameters from weight decay in the usual Hugging Face pattern. (GitHub)\n  * **Text normalization** is applied before metrics, which matters for Unicode-heavy scripts like Devanagari.\n  * The class still keeps your current **frozen encoder + decoder-side adaptation** strategy, which is the right first stage for a warm-started vision-text hybrid where decoder-side cross-attention is the newly learned bridge. (Hugging Face)\n\n\n\n## One small check before you rerun\n\nYour dataset cell still uses `max_length=64` for labels. Before the next run, quickly inspect how many targets are hitting that cap. If a noticeable fraction of your Hindi lines are truncated at 64 tokens, raise it first. 
Otherwise the trainer can improve while the metric ceiling stays artificially low.\n\nUse this once, anywhere after dataset creation:\n\n\n    lengths = []\n    for i in range(len(train_dataset)):\n        ids = train_dataset[i][\"labels\"]\n        valid = (ids != -100).sum().item()\n        lengths.append(valid)\n\n    print(\"max label tokens:\", max(lengths))\n    print(\"p95 label tokens:\", sorted(lengths)[int(0.95 * len(lengths))])\n    print(\"num hitting 64:\", sum(x >= 64 for x in lengths))\n\n\n## What I would expect after this patch\n\nIf the pipeline is healthy, you should see this pattern:\n\n  * training loss goes down,\n  * validation CER improves in the first several epochs,\n  * the best checkpoint appears **before** the final epoch,\n  * and early stopping restores that checkpoint instead of letting later degradation define the run.\n\n\n\nIf after this patch CER still stays around `1.0` or worse, I would not jump to “different tokenizer” or “different architecture.” I would first test one of these two moves:\n\n  * keep the same freeze strategy and move to **1000+ line crops** , or\n  * keep the same data and switch to **LoRA-style adaptation** on the decoder-side trainable blocks, which is a common low-resource adaptation path in recent Indic OCR work. (ACL Anthology)\n\n",
  "title": "Looking for guidance. Trying to create a model with TrOCR's encoder + Google's mT5 multilingual decoder but model fails to overfit on a single data sample"
}