{
"$type": "site.standard.document",
"bskyPostRef": {
"cid": "bafyreigvtfipekvd5lvbmj2c7sx3sxsn7hvsbil3o67zt2d537pf2jvv3m",
"uri": "at://did:plc:pgryn3ephfd2xgft23qokfzt/app.bsky.feed.post/3mjadt5y7l532"
},
"path": "/t/guide-how-i-debugged-t5-fine-tuning-for-a-medical-diagnosis-task/165573#post_2",
"publishedAt": "2026-04-11T16:29:24.000Z",
"site": "https://discuss.huggingface.co",
"tags": [
"KatharinaJacoby (Katharina) · GitHub"
],
"textContent": "here is the English Version:\nLLM Fine-tuning Debugging Guide: Systematic Problem Solving in Practice\n\n**A complete walkthrough from the first problem to a working medical LLM**\n\n* * *\n\n## Project Goal\n\nDevelop a medical LLM for diagnostic support using T5 fine-tuning.\n\n* * *\n\n## Initial Situation\n\n### Original Code (functional but limited)\n\n\n import pandas as pd\n import transformers\n import torch\n from transformers import T5Tokenizer, T5ForConditionalGeneration\n from datasets import Dataset\n from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments\n\n data = [\n {\"input\": \"Symptoms: Fever, cough. CRP: 67. Imaging: Infiltrate basal right. What is the most likely diagnosis?\", \"output\": \"Pneumonia\"},\n {\"input\": \"Symptoms: Dyspnea, left leg swelling. D-Dimer elevated. What is the most likely diagnosis?\", \"output\": \"Pulmonary embolism\"},\n {\"input\": \"Symptoms: Fatigue, pallor. Hb: low. What is the most likely diagnosis?\", \"output\": \"Anemia\"},\n {\"input\": \"Symptoms: Chest pain, high troponin, EKG ST-elevation. What is the most likely diagnosis?\", \"output\": \"Myocardial infarction\"},\n {\"input\": \"Symptoms: Polyuria, polydipsia, blood glucose 320 mg/dl. 
What is the most likely diagnosis?\", \"output\": \"Diabetes mellitus\"}\n ]\n data = pd.DataFrame(data)\n tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n\n def tokenize(example):\n     input_enc = tokenizer(example[\"input\"], truncation=True, padding=\"max_length\", max_length=128)\n     output_enc = tokenizer(example[\"output\"], truncation=True, padding=\"max_length\", max_length=32)\n     input_enc[\"labels\"] = output_enc[\"input_ids\"]\n     return input_enc\n\n dataset = Dataset.from_pandas(data)\n tokenized_dataset = dataset.map(tokenize)\n model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n\n training_args = TrainingArguments(\n     output_dir=\"./results\",\n     per_device_train_batch_size=2,\n     num_train_epochs=20,\n     logging_steps=1,\n     save_strategy=\"no\",\n     report_to=\"none\"\n )\n\n data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n trainer = Trainer(\n     model=model,\n     args=training_args,\n     train_dataset=tokenized_dataset,\n     tokenizer=tokenizer,\n     data_collator=data_collator\n )\n trainer.train()\n\n def predict(prompt):\n     inputs = tokenizer(prompt, return_tensors=\"pt\", padding=True, truncation=True).input_ids\n     outputs = model.generate(inputs, max_length=32)\n     return tokenizer.decode(outputs[0], skip_special_tokens=True)\n\n test_prompt = \"Symptoms: Shortness of breath, fever, CRP 90, X-ray: Infiltrate right. What is the most likely diagnosis?\"\n print(\"Answer:\", predict(test_prompt))\n\n\n### Initial Results (problematic but functional)\n\n * **Output:** `\"Pneumonia. 
DD: Pneumonia, Pneumonia\"` (repetitive)\n * **Loss:** 8.78 → 0.43 (very good)\n * **Problem:** Repetitive/incorrect differential diagnoses\n\n\n\n* * *\n\n## Problem Phase 1: Structural Improvement Leads to “True” Bug\n\n### Attempt: Implementing Extended Features\n\n**Goal:** 100 examples, validation split, better output structure\n**Changes:**\n\n * Dataset expanded to 100 examples\n * Structured DD output: `\"Diagnosis: X | DD: Y, Z, W\"`\n * Train/validation split (80/20)\n * `as_target_tokenizer()` → `text_target` (to fix the deprecation warning)\n * `tokenizer` → `processing_class` parameter\n\n\n\n### Problem: “True” Bug\n\n\n # Expected: \"Pneumonia. DD: Bronchitis, Pleuritis\"\n # Actual: \"True\"\n\n\n**Symptoms:**\n\n * All outputs only `\"True\"`\n * Model behaves like a binary classifier\n * Missing keys warning: `embed_tokens.weight`, `lm_head.weight`\n\n\n\n* * *\n\n## Debugging Phase 1: Systematic Problem Identification\n\n### Step 1: Parameter Instability Hypothesis\n\n**Observation:** Multiple deprecated/new parameters changed simultaneously\n\n * `evaluation_strategy` → TypeError\n * `processing_class` vs `tokenizer`\n * `text_target` vs `as_target_tokenizer()`\n\n\n\n**Hypothesis:** New parameters are unstable; old parameters work better\n\n### Step 2: Stepwise Rollback\n\n**Strategy:** Change one variable at a time\n\n#### Test 1: `as_target_tokenizer()` Fix\n\n\n # Revert to deprecated but functional method\n with tokenizer.as_target_tokenizer():\n     output_enc = tokenizer(example[\"output\"], ...)\n\n\n**Result:** `\"rmelkinese\"` (corrupt, but no longer “True”)\n\n#### Test 2: Original vs Fix Comparison\n\n**Result:** Both versions produce `\"rmelkinese\"` → Problem lies elsewhere\n\n* * *\n\n## Debugging Phase 2: Fresh Environment Strategy\n\n### Step 3: Clean Slate Approach\n\n**Decision:** Fresh notebook, back to functional baseline\n**Baseline Test (5 examples, original code):**\n\n\n # Minimal test for root cause isolation\n data = [original 5 examples without 
DD]\n\n\n**Result:** `\"What is the most likely diagnosis?\"` (input echo)\n\n* * *\n\n## Debugging Phase 3: Pipeline Diagnosis\n\n### Step 4: Labels Debug\n\n**Check:** Are labels correctly tokenized?\n\n\n print(\"Sample tokenized data:\")\n print(f\"Labels: {tokenized_dataset[0]['labels'][:10]}\")\n print(f\"Decoded Labels: {tokenizer.decode(tokenized_dataset[0]['labels'])}\")\n\n\n**Result:** Labels perfect: `\"Pneumonia</s><pad>...\"`\n\n### Step 5: Attention Mask Debug\n\n**Check:** Is the attention mask set correctly?\n\n\n inputs = tokenizer(prompt, return_tensors=\"pt\", padding=True, truncation=True, max_length=128)\n print(f\"Attention mask: {inputs.attention_mask}\")\n print(f\"Attention mask sum: {inputs.attention_mask[0].sum()}\")\n\n\n**Result:** Attention mask correct: 36/36 tokens attended\n\n### Step 6: EOS/PAD Token Debug\n\n**Check:** Is token handling correct?\n\n\n print(f\"PAD token: '{tokenizer.pad_token}' -> ID: {tokenizer.pad_token_id}\")\n print(f\"EOS token: '{tokenizer.eos_token}' -> ID: {tokenizer.eos_token_id}\")\n\n\n**Result:** Token setup correct, but generation produces input echo\n\n* * *\n\n## Problem Phase 2: DataCollator Crash\n\n### Step 7: Label-Training-Pipeline Debug\n\n**Deeper Test:** What happens during training?\n**CRASH:**\n\n\n ValueError: Unable to create tensor... 
Perhaps your features (`input` in this case) have excessive nesting\n\n\n### Root Cause: String Features in Dataset\n\n**Problem:** DataCollator cannot convert the raw string columns into tensors\n\n\n # Simplified view of the dataset columns (illustrative, not the literal `features` object)\n column_types = {\n     \"input\": \"string\",        # ❌ DataCollator crash\n     \"output\": \"string\",       # ❌ DataCollator crash\n     \"input_ids\": \"list[int]\", # ✅ OK (padded and tensorized by the collator)\n     \"labels\": \"list[int]\"     # ✅ OK\n }\n\n\n### Fix: Remove String Features\n\n\n tokenized_dataset = tokenized_dataset.remove_columns([\"input\", \"output\"])\n\n\n**Result:** Training runs, but output still incorrect\n\n* * *\n\n## Debugging Phase 4: T5-Specific Problems\n\n### Step 8: T5 Training Mode Check\n\n**Check:** Does T5 understand our task?\n**Discovery:** T5 has task-specific parameters:\n\n\n print(model.config.task_specific_params)\n # Abridged output:\n # {'summarization': {'prefix': 'summarize: ', ...},\n #  'translation_en_to_de': {'prefix': 'translate English to German: ', ...},\n #  ...}\n\n\n**Problem:** Without a task prefix, T5 doesn’t know what to do!\n\n### Step 9: Task Prefix Implementation\n\n\n def tokenize_with_task_prefix(example):\n     task_prefixed_input = f\"medical diagnosis: {example['input']}\"\n     input_enc = tokenizer(task_prefixed_input, truncation=True, padding=\"max_length\", max_length=128)\n     output_enc = tokenizer(example[\"output\"], truncation=True, padding=\"max_length\", max_length=32)\n     input_enc[\"labels\"] = output_enc[\"input_ids\"]\n     return input_enc\n\n\n**Result:** Input echo stops, but only empty outputs\n\n* * *\n\n## Problem Phase 3: PAD Token Loop\n\n### Step 10: Generation Mechanism Debug\n\n**Problem:** Model generates only PAD tokens `[0,0,0,...]`\n**Deep Debug:**\n\n\n # Raw token analysis\n outputs = model.generate(inputs, max_length=32, do_sample=False)\n print(f\"Raw tokens: {outputs[0]}\")\n # Result: [0, 0, 0, 0, 0, 0, ...]\n\n\n### Hypothesis: Training Volume vs Decoder Mechanism\n\n**Discussion:**\n\n * Are 10 epochs too few for task prefix learning?\n * Or is the decoder-start mechanism broken?\n\n\n\n### 
Step 11: A/B Test Strategy\n\n**Test 1:** Continue Training (+20 epochs)\n**Test 2:** Fresh Training (30 epochs from scratch)\n\n#### Continue Training Result:\n\n * **Loss:** 2.0 → 0.15-0.30\n * **Output:** `\"Morbus Morbus Morbus...\"` (medical terms, but repetitive)\n\n\n\n#### Fresh Training Result:\n\n * **Loss:** 10.1 → 0.30-0.85\n * **Output:** `\"\"` (empty, PAD tokens)\n\n\n\n**Conclusion:** In this A/B test, continuing training worked better than training from scratch!\n\n* * *\n\n## Breakthrough Phase: Generation Parameter Optimization\n\n### Step 12: Improved Generation Parameters\n\n**Problem:** Repetitive output (`\"Morbus Morbus Morbus...\"`)\n**Solution:** Advanced generation parameters\n\n\n def predict_improved(prompt):\n     prefixed_prompt = f\"medical diagnosis: {prompt}\"\n     # Keep inputs on the model's device (Trainer moves the model to a GPU when one is available)\n     inputs = tokenizer(prefixed_prompt, return_tensors=\"pt\", padding=True, truncation=True).to(model.device)\n\n     outputs = model.generate(\n         input_ids=inputs.input_ids,\n         attention_mask=inputs.attention_mask,\n         max_new_tokens=32,\n         repetition_penalty=2.0, # ← Anti-repetition\n         num_beams=4, # ← Better quality\n         early_stopping=True, # ← End beam search once num_beams candidates are finished\n         eos_token_id=tokenizer.eos_token_id\n     )\n     return tokenizer.decode(outputs[0], skip_special_tokens=True)\n\n\n### Breakthrough Results:\n\n * **Input:** `\"Symptoms: Shortness of breath, fever, CRP 90...\"`\n * **Output:** `\"Shortness of breath, fever, CRP 90, X-ray\"`\n\n**Analysis:** Model extracts relevant medical information, but no diagnosis yet!\n\n\n\n* * *\n\n## Final Success Phase: Scale & Training Optimization\n\n### Step 13: Dataset & Training Scale-Up\n\n**Strategy:** More data + intensive training\n**Scaling:**\n\n * **25 → 160 examples** (6x more data)\n * **30 → 40 epochs** (more training)\n * **19 medical specialties** covered\n\n\n\n**Optimized Training Parameters:**\n\n\n training_args = TrainingArguments(\n output_dir=\"./results\",\n per_device_train_batch_size=4, # Larger batches\n num_train_epochs=40, # More epochs\n learning_rate=3e-4, # Optimized LR\n warmup_steps=50, # 
Warmup for stability\n logging_steps=10,\n save_strategy=\"no\",\n report_to=\"none\"\n )\n\n\n### Final Training Results:\n\n * **Loss:** 9.9 → **0.009** (Outstanding!)\n * **160 examples** successfully trained\n * **40 epochs** with perfect convergence\n\n\n\n* * *\n\n## SUCCESS: Functional Medical LLM\n\n### Final Test Results (100% Success Rate):\n\nTest | Input | Generated | Expected | Status\n---|---|---|---|---\n1 | Fever, cough, infiltrate | **Pneumonia** | Pneumonia | ✅\n2 | Chest pain, troponin, ST-elevation | **Myocardial infarction** | Myocardial infarction | ✅\n3 | Polyuria, blood glucose 320 mg/dl | **Diabetes mellitus** | Diabetes mellitus | ✅\n4 | Tremor, rigidity, bradykinesia | **Parkinson’s disease** | Parkinson’s disease | ✅\n5 | Headache, meningismus | **Meningitis** | Meningitis | ✅\n\n* * *\n\n## Debugging Steps Summary\n\n### Systematic Problem Identification\n\n 1. **Parameter Instability Analysis**\n * Cross-pattern recognition between various deprecated warnings\n * Isolation of individual parameter changes\n 2. **Pipeline Component Test**\n * Labels tokenization\n * Attention mask\n * EOS/PAD token handling\n * DataCollator → **FIXED**\n 3. **T5-Specific Requirements**\n * Task prefix requirement identified\n * Encoder-decoder pipeline understood\n 4. **Generation Mechanism Optimization**\n * Parameter tuning for anti-repetition\n * Beam search for better quality\n 5. **Scale & Training Optimization**\n * Dataset size as a critical factor\n * Training volume for complex tasks\n\n\n\n* * *\n\n## Final Code Solution\n\n\n import pandas as pd\n import transformers\n import torch\n from transformers import T5Tokenizer, T5ForConditionalGeneration\n from datasets import Dataset\n from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments\n\n # LARGE DATASET: 160 medical examples\n data = [\n # ... 
[160 examples from 19 specialties]\n ]\n\n tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n\n # T5 TASK PREFIX (critical for T5 performance)\n def tokenize_with_task_prefix(example):\n     task_prefixed_input = f\"medical diagnosis: {example['input']}\"\n     input_enc = tokenizer(task_prefixed_input, truncation=True, padding=\"max_length\", max_length=128)\n     output_enc = tokenizer(example[\"output\"], truncation=True, padding=\"max_length\", max_length=32)\n     input_enc[\"labels\"] = output_enc[\"input_ids\"]\n     return input_enc\n\n # data is a list of dicts, so wrap it in a DataFrame first (as in the original code)\n dataset = Dataset.from_pandas(pd.DataFrame(data))\n tokenized_dataset = dataset.map(tokenize_with_task_prefix)\n\n # DATACOLLATOR FIX: Remove string features\n tokenized_dataset = tokenized_dataset.remove_columns([\"input\", \"output\"])\n\n model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n\n # OPTIMIZED TRAINING PARAMETERS\n training_args = TrainingArguments(\n     output_dir=\"./results\",\n     per_device_train_batch_size=4,\n     num_train_epochs=40,\n     learning_rate=3e-4,\n     warmup_steps=50,\n     logging_steps=10,\n     save_strategy=\"no\",\n     report_to=\"none\"\n )\n\n data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n trainer = Trainer(\n     model=model,\n     args=training_args,\n     train_dataset=tokenized_dataset,\n     tokenizer=tokenizer,\n     data_collator=data_collator\n )\n trainer.train()\n\n # OPTIMIZED PREDICTION FUNCTION\n def predict_medical_diagnosis(prompt):\n     prefixed_prompt = f\"medical diagnosis: {prompt}\"\n     # Keep inputs on the model's device (Trainer moves the model to a GPU when one is available)\n     inputs = tokenizer(prefixed_prompt, return_tensors=\"pt\", padding=True, truncation=True).to(model.device)\n\n     outputs = model.generate(\n         input_ids=inputs.input_ids,\n         attention_mask=inputs.attention_mask,\n         max_new_tokens=32,\n         repetition_penalty=2.0, # Anti-repetition\n         num_beams=4, # Better quality\n         early_stopping=True, # End beam search once num_beams candidates are finished\n         eos_token_id=tokenizer.eos_token_id\n     )\n     return tokenizer.decode(outputs[0], skip_special_tokens=True)\n\n # TEST\n test_prompt = \"Symptoms: Shortness of breath, fever, CRP 90, X-ray: Infiltrate right. 
What is the most likely diagnosis?\"\n result = predict_medical_diagnosis(test_prompt)\n print(f\"Diagnosis: {result}\") # Output: \"Pneumonia\"\n\n\n* * *\n\n## Critical Success Factors\n\n### Must-Have Components:\n\n 1. **T5 Task Prefix:** `\"medical diagnosis: \"`, essential for T5 understanding\n 2. **DataCollator Fix:** Remove string features\n 3. **Sufficient Data:** 100+ examples for complex mappings\n 4. **Advanced Generation:** Repetition penalty, beam search, early stopping\n 5. **Training Volume:** 40+ epochs for task learning\n\n\n\n### Common Pitfalls:\n\n 1. **Deprecated Parameters:** New APIs are not always more stable\n 2. **Fresh vs Continue:** Continuing training can beat training from scratch\n 3. **Cache/Memory Issues:** A fresh environment solves many problems\n 4. **Generation Parameters:** Default parameters are often insufficient\n 5. **Dataset Size:** Datasets that are too small lead to overfitting/repetition\n\n\n\n* * *\n\n## Debugging Strategies (Lessons Learned)\n\n### 1. Systematic Isolation\n\n * **Change one variable at a time**\n * **Start from a working baseline**\n * **Forward debugging, not backward guessing**\n\n\n\n### 2. Pipeline-Oriented Diagnosis\n\n\n Input ✅ → Tokenization ✅ → Attention ✅ → Training ❌ → Generation ❌ → Output ❌\n\n\n**Test each step individually**\n\n### 3. Fresh Environment as a Debugging Tool\n\n * **Eliminate cache/memory issues**\n * **Enable clean state for reproducible tests**\n * **Allow controlled experiments**\n\n\n\n### 4. Recognize Parameter Instability\n\n * **Take deprecated warnings seriously**\n * **Cross-pattern recognition between different errors**\n * **Choose conservative parameters when in doubt**\n\n\n\n### 5. Understand Model-Specific Requirements\n\n * **T5 needs task prefix for new tasks**\n * **Encoder-decoder models have special requirements**\n * **Generation parameters are critical for output quality**\n\n\n\n* * *\n\n## Final Insights\n\n### What Worked:\n\n 1. 
**Emergency medicine debugging principles → ML engineering**\n 2. **Systematic differential diagnosis → Bug isolation**\n 3. **“Better safe than sorry” → Conservative development**\n 4. **Fresh environment strategy → Clean testing**\n 5. **Cross-pattern recognition → Root cause analysis**\n\n\n\n### Performance Metrics:\n\n * **Training Loss:** 9.9 → 0.009 (99.9% improvement)\n * **Test Accuracy:** 100% on 5 different medical cases\n * **Specialty Coverage:** 19 medical specialties\n * **Debugging Time:** ~3 hours of systematic analysis\n\n\n\n* * *\n\n## Next Development Steps\n\n### Possible Extensions:\n\n 1. **Add differential diagnoses**\n 2. **Implement confidence scoring**\n 3. **Validation set for overfitting prevention**\n 4. **Larger model (T5-base/large) for complex cases**\n 5. **Real-world medical data integration**\n\n\n\n### Deployment Considerations:\n\n * **Model versioning for different specialties**\n * **API wrapper for clinical integration**\n * **Safety measures for medical applications**\n * **Continuous learning from new cases**\n\n\n\n* * *\n\n## Key Takeaways for ML Engineering\n\n### 1. Debugging is a systematic process\n\n**Don’t guess, test methodically**\n\n### 2. Domain Knowledge + Technical Skills = Success\n\n**Medical expertise + ML engineering = Powerful combination**\n\n### 3. Fresh Environment is a Powerful Tool\n\n**“Turn it off and on again” works for ML too**\n\n### 4. Conservative Parameter Choice Pays Off\n\n**Old, stable parameters > new, unstable parameters**\n\n### 5. 
Model-Specific Requirements are Critical\n\n**T5, BERT, and GPT each have different best practices**\n\n* * *\n\n## Project Success\n\n**From a non-functional “True” bug to a medical LLM that scored 100% on a small test set, through systematic debugging steps.**\n\n**Proof: Systematic approach + domain expertise + technical implementation = Successful ML solution**\n\n**Final loss: 0.009; accuracy on our small, hand-crafted test set was 100%, likely due to the limited dataset and clear-cut labels (e.g., Pneumonia, Myocardial infarction). This should not be interpreted as clinical performance.**\n\n* * *\n\n*This document shows how real ML problems are solved in practice: not by luck or intuition, but by systematic analysis, methodical testing, and step-by-step problem solving.*\n\n*For more info, check out my GitHub repos:* KatharinaJacoby (Katharina) · GitHub",
"title": "[Guide] How I debugged T5 fine-tuning for a medical diagnosis task"
}