Is this a common/reasonable recipe for full finetuning Qwen3.5-4B?
Recommendations from Chat Qwen AI:
There are a few critical things to address here before you start burning GPU hours. The biggest one is a potential showstopper regarding the model name, followed by a crucial detail about loss masking that often gets overlooked with SFTTrainer.
Here is the sanity check breakdown:
1. The Model Name (Critical)
Qwen/Qwen3.5-4B did not exist publicly as of late 2024.
At that time, the latest public release was the Qwen2.5 series, available in 0.5B, 1.5B, 3B, 7B, 14B, 32B, and 72B sizes. There was no 4B variant and no 3.5 version.
Action: Double-check your model string. You likely mean `Qwen/Qwen2.5-3B-Instruct` or `Qwen/Qwen2.5-7B-Instruct`.
Risk: If you proceed with a non-existent model path, your script will fail immediately. If you are using a private or leaked checkpoint, ignore this, but the advice below assumes its architecture matches Qwen2.5.
2. Base vs. Instruct
Recommendation: Use the -Instruct version.
Why: You are building a “Legal Assistant.” You want the model to retain its general chat capabilities, safety alignments, and instruction-following structure while adapting to the legal domain.
Base Model: If you use the Base model, you are teaching it how to chat and how to be a lawyer simultaneously. With only 1 epoch and a chat dataset, you risk the model not learning the chat format well enough, resulting in outputs that don’t stop correctly or ignore system prompts.
Instruct Model: It already knows how to chat. You are just shifting its knowledge distribution to Brazilian Law. This is much more sample-efficient.
3. Loss Masking (The Biggest Technical Risk)
By default, SFTTrainer does not mask user tokens.
A causal language model computes loss on every token in the sequence unless its labels are masked. If you feed it a conversation:
[User]: What is the law? [Assistant]: The law is…
The model will then be trained to predict the user's prompt as well as the assistant's response. For an assistant, the prompt tokens are noise and can degrade performance.
You must mask the User/System tokens (set labels to -100) so loss is only calculated on Assistant tokens.
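To make the `-100` convention concrete, here is a minimal pure-Python sketch of how a masked token-level loss is averaged: positions whose label is `-100` contribute nothing, which mirrors the default `ignore_index=-100` of PyTorch's `CrossEntropyLoss`. The per-token loss values are made up for illustration, and the sketch ignores the one-position label shift that Hugging Face causal LMs apply internally.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy skips by default


def masked_mean_loss(token_losses, labels):
    """Average per-token losses over positions whose label is not IGNORE_INDEX."""
    kept = [loss for loss, label in zip(token_losses, labels) if label != IGNORE_INDEX]
    if not kept:
        raise ValueError("every position is masked; nothing to train on")
    return sum(kept) / len(kept)


# Toy sequence: 3 prompt tokens (masked) followed by 2 assistant tokens.
input_ids = [11, 12, 13, 21, 22]
labels = [IGNORE_INDEX] * 3 + input_ids[3:]
token_losses = [4.0, 4.0, 4.0, 1.0, 3.0]  # illustrative values only

print(masked_mean_loss(token_losses, labels))  # 2.0: only assistant tokens count
print(sum(token_losses) / len(token_losses))   # 3.2: the unmasked average includes the prompt
```

Without masking, the prompt tokens dominate the average here, which is exactly the noise described above.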
SFTTrainer will apply the chat template to a `messages` column, but whether it also masks non-assistant tokens depends on your TRL version and configuration. Unless you have confirmed otherwise, assume it computes loss on the whole sequence.
Fix: Preprocess your dataset to generate `input_ids` and `labels` manually before passing it to the trainer, or use a custom data collator. Here is a pattern for Qwen's ChatML chat template:
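```python
# Assistant-only masking for ChatML-style templates (Qwen2.5).
# Assumes `tokenizer` and MAX_LENGTH are defined, and that rendering messages[:k]
# yields a text prefix of rendering messages[:k+1] (true for Qwen2.5's template;
# verify for other templates, e.g. ones that rewrite earlier turns).
def preprocess(examples):
    def encode(msgs, add_generation_prompt=False):
        text = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=add_generation_prompt
        )
        # The template already inserts the special tokens, so add none here.
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for msgs in examples["messages"]:
        input_ids = encode(msgs)
        labels = [-100] * len(input_ids)  # start fully masked
        for k, msg in enumerate(msgs):
            if msg["role"] != "assistant":
                continue
            start = len(encode(msgs[:k], add_generation_prompt=True))  # after "<|im_start|>assistant\n"
            end = len(encode(msgs[: k + 1]))  # through "<|im_end|>\n"
            labels[start:end] = input_ids[start:end]  # unmask this reply
        # Truncate both together; consider dropping rows whose labels end up all -100.
        input_ids, labels = input_ids[:MAX_LENGTH], labels[:MAX_LENGTH]
        batch["input_ids"].append(input_ids)
        batch["attention_mask"].append([1] * len(input_ids))
        batch["labels"].append(labels)
    return batch
```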
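One way to locate assistant tokens is by prefix lengths: render the conversation up to the generation prompt that precedes an assistant turn, render it again through the end of that turn, and the tokens between the two lengths are the reply. The toy sketch below checks that idea without downloading a tokenizer. `render_chatml` is a simplified stand-in for Qwen2.5's template (no default system prompt, no tool handling) and each character acts as one "token"; both are assumptions for illustration only.

```python
IGNORE_INDEX = -100


def render_chatml(messages, add_generation_prompt=False):
    """Toy ChatML renderer (simplified stand-in for Qwen2.5's chat template)."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        text += "<|im_start|>assistant\n"
    return text


def encode(text):
    """Toy 'tokenizer': one token id per character."""
    return [ord(c) for c in text]


def build_labels(messages):
    """Supervise only assistant turns, locating them by prefix lengths."""
    input_ids = encode(render_chatml(messages))
    labels = [IGNORE_INDEX] * len(input_ids)
    for k, msg in enumerate(messages):
        if msg["role"] != "assistant":
            continue
        start = len(encode(render_chatml(messages[:k], add_generation_prompt=True)))
        end = len(encode(render_chatml(messages[: k + 1])))
        labels[start:end] = input_ids[start:end]
    return input_ids, labels


conversation = [
    {"role": "system", "content": "You are a legal assistant."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Thanks"},
    {"role": "assistant", "content": "Glad to help."},
]
ids, labels = build_labels(conversation)
supervised = "".join(chr(t) for t, l in zip(ids, labels) if l != IGNORE_INDEX)
print(repr(supervised))  # 'Hello!<|im_end|>\nGlad to help.<|im_end|>\n'
```

Only the two replies and their closing `<|im_end|>` are supervised, so the model still learns when to stop. With a real BPE tokenizer the same trick also assumes the token boundary right after `<|im_start|>assistant\n` is stable; decoding the supervised spans is a quick way to confirm it.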
Easier Way with TRL:
If you are on a recent TRL release, check its documentation for built-in options before writing your own masking, and make sure you are not training on user input. Two routes are worth confirming for your installed version: a prompt/completion dataset format, where SFTTrainer computes loss on the completion only by default, and an `assistant_only_loss` option in `SFTConfig` for conversational data, which only works if the chat template marks assistant spans. Otherwise, assume that SFTTrainer with `messages` may train on every token, depending on the version.
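For the prompt/completion route, a conversation has to be split into a prompt part and a completion part. Below is a minimal sketch, assuming each row has a `messages` list and that the last assistant reply is the one worth supervising; `to_prompt_completion` is a hypothetical helper, and the conversational prompt/completion layout (lists of messages in both fields) should be checked against TRL's dataset-format docs for your version.

```python
def to_prompt_completion(example):
    """Split a conversation at its final assistant turn (hypothetical helper).

    Earlier turns, including earlier assistant replies, become prompt context and
    are not trained on; any turns after the final assistant reply are dropped.
    """
    messages = example["messages"]
    assistant_turns = [i for i, m in enumerate(messages) if m["role"] == "assistant"]
    if not assistant_turns:
        raise ValueError("conversation has no assistant turn")
    last = assistant_turns[-1]
    return {"prompt": messages[:last], "completion": [messages[last]]}


example = {
    "messages": [
        {"role": "system", "content": "You are a legal assistant."},
        {"role": "user", "content": "Summarize the clause."},
        {"role": "assistant", "content": "Here is a short summary."},
    ]
}
row = to_prompt_completion(example)
print([m["role"] for m in row["prompt"]])      # ['system', 'user']
print([m["role"] for m in row["completion"]])  # ['assistant']
```

With the `datasets` library this would be applied as `train_ds.map(to_prompt_completion, remove_columns=["messages"])`. The trade-off is that earlier assistant replies in multi-turn data are never supervised; if those matter, per-turn masking is the better fit.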
Best Practice for Qwen:
Use a custom map function (or data collator) that sets labels to -100 for every token that does not belong to an assistant turn. If you skip this, your legal assistant may learn to generate text that imitates user queries.
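A cheap way to confirm the masking worked is to pull out the contiguous runs of unmasked labels and decode them; each run should be an assistant reply. A small sketch (the helper name `supervised_spans` is hypothetical):

```python
IGNORE_INDEX = -100


def supervised_spans(input_ids, labels, ignore_index=IGNORE_INDEX):
    """Return the contiguous runs of token ids that contribute to the loss."""
    spans, current = [], []
    for tok, lab in zip(input_ids, labels):
        if lab == ignore_index:
            if current:  # a masked token closes the current run
                spans.append(current)
                current = []
        else:
            current.append(tok)
    if current:
        spans.append(current)
    return spans


# Toy check: two assistant turns separated by a masked user turn.
ids = [1, 2, 3, 4, 5, 6, 7, 8]
labs = [-100, -100, 3, 4, -100, -100, 7, 8]
print(supervised_spans(ids, labs))  # [[3, 4], [7, 8]]
```

On real data, something like `for span in supervised_spans(row["input_ids"], row["labels"]): print(tokenizer.decode(span))` should print only assistant replies, each ending in `<|im_end|>`. A span containing user text means the masking is off.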
4. Hyperparameter Sanity Check
Your hyperparameters are generally very sane for Full FT, with one caveat on sequence length.
* LR (`1e-5`): Perfect for full FT. (LoRA would typically use `2e-4` or higher, but full FT needs to be more conservative to avoid catastrophic forgetting.)
* Epochs (`1`): Correct. Full FT on instruct models should usually be 1 epoch to preserve general capabilities.
* Weight decay (`0.1`): Standard.
* Batch size: An effective batch size of 64 (4 * 4 * 4, i.e. per-device batch × accumulation steps × GPUs) is good.
* Max length (`1024`): Too short for legal work.
  * Legal texts are long. Contracts, rulings, and statutes often exceed 1024 tokens.
  * If you truncate at 1024, you might cut off the actual legal reasoning or the answer.
  * **Recommendation:** If VRAM allows, push this to **2048** or **4096**. Qwen2.5 supports up to 32k natively. If you must stick to 1024 for the sanity check, be aware your real deployment will need longer context.
* Precision: `bf16` is the right choice. Do not use `fp16` if you have Ampere (A100/A10) or Hopper (H100) GPUs; fp16's much narrower exponent range makes overflows and loss spikes more likely during full fine-tuning.
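To sanity-check what these settings mean for a single-epoch run, here is a back-of-envelope sketch. The dataset size (20,000 conversations) is hypothetical, `num_gpus=4` is inferred from the 4 * 4 * 4 product, and `cosine_lr` reproduces the shape of a cosine schedule with linear warmup decaying to zero. The Trainer's exact step count can differ by a step depending on how batches are sharded and rounded.

```python
import math


def run_summary(num_examples, per_device_batch, grad_accum, num_gpus,
                warmup_ratio, max_length):
    """Approximate step counts for one epoch (no packing)."""
    effective_batch = per_device_batch * grad_accum * num_gpus
    total_steps = math.ceil(num_examples / effective_batch)
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    max_tokens_per_step = effective_batch * max_length  # upper bound, before masking
    return effective_batch, total_steps, warmup_steps, max_tokens_per_step


def cosine_lr(step, base_lr, warmup_steps, total_steps):
    """Linear warmup to base_lr, then half-cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


eff, total, warmup, tokens = run_summary(20_000, 4, 4, 4, 0.03, 2048)
print(eff, total, warmup, tokens)               # 64 313 10 131072
print(cosine_lr(warmup, 1e-5, warmup, total))   # 1e-05: peak LR right after warmup
print(cosine_lr(total, 1e-5, warmup, total))    # 0.0 at the end of the run
```

With only a few hundred optimizer steps, a 3% warmup is about ten steps, so the 1e-5 peak is reached almost immediately.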
5. Qwen-Specific Gotchas
* BOS token: Qwen2.5 tokenizers do not add a beginning-of-sequence (BOS) token; each sample should begin with the chat template's own special tokens.
  * Ensure `tokenizer.apply_chat_template(..., add_generation_prompt=False)` is used for training samples.
  * Check that your tokenized input starts with the correct special tokens (usually `<|im_start|>` for Qwen2.5), with no stray BOS prepended by other tooling.
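* `trust_remote_code=True`: Harmless, but not required for Qwen2.5. The Qwen2 architecture has been built into transformers since v4.37, so the flag only matters for older Qwen releases that ship custom modeling code.
* Optimizer: `adamw_torch_fused` is a good choice; fusing the AdamW update into fewer CUDA kernels makes the optimizer step faster. Ensure your PyTorch version supports it (2.0+).
* Gradient checkpointing: You have this on. It is essential for full FT to save VRAM.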
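The template checks above can be automated on a handful of rendered samples before launching the job. A minimal sketch, assuming Qwen2.5-style ChatML markers; `check_chatml_training_text` is a hypothetical helper, not an official tool:

```python
START, END = "<|im_start|>", "<|im_end|>"


def check_chatml_training_text(text):
    """Return a list of problems in a rendered training sample (empty if none)."""
    problems = []
    if not text.startswith(START):
        problems.append("does not start with <|im_start|> (stray BOS or prefix?)")
    if text.count(START) != text.count(END):
        problems.append("unbalanced <|im_start|>/<|im_end|> markers")
    stripped = text.rstrip("\n")
    if stripped.endswith(START + "assistant"):
        problems.append("ends with a generation prompt; use add_generation_prompt=False")
    elif not stripped.endswith(END):
        problems.append("does not end with <|im_end|>")
    return problems


good = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n"
print(check_chatml_training_text(good))          # []
print(check_chatml_training_text("<s>" + good))  # one problem: missing leading <|im_start|>
```

In practice you would feed it strings from `tokenizer.apply_chat_template(msgs, tokenize=False)`; `tokenizer.convert_ids_to_tokens(ids[:3])` gives the same check at the token level.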
6. Revised Code Snippet (Key Fixes)
Here is the adjusted setup focusing on the masking and model name issues.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq
from trl import SFTConfig, SFTTrainer

# 1. VERIFY MODEL NAME
# Assuming Qwen2.5-3B-Instruct, as Qwen3.5-4B was not public as of late 2024
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
MAX_LENGTH = 2048  # Increased for legal context

# trust_remote_code is not needed for Qwen2.5 on recent transformers (the
# architecture is built in); it is kept only because the original recipe had it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:  # Qwen2.5 already defines one; this is a fallback
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # right padding for training; use "left" for batched generation

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # load weights in bf16
    low_cpu_mem_usage=True,
)
model.config.use_cache = False  # KV cache is unused in training and conflicts with gradient checkpointing

# 2. PREPROCESSING FOR MASKING
# Whether SFTTrainer masks user tokens for a 'messages' column depends on the TRL
# version and config, so preprocess into input_ids/labels with -100 masking.
def format_dataset(examples):
    def encode(msgs, add_generation_prompt=False):
        text = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=add_generation_prompt
        )
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for msgs in examples["messages"]:
        input_ids = encode(msgs)
        labels = [-100] * len(input_ids)  # mask everything by default
        for k, msg in enumerate(msgs):
            if msg["role"] == "assistant":
                # The reply spans from the end of the generation prompt to the end of this turn.
                start = len(encode(msgs[:k], add_generation_prompt=True))
                end = len(encode(msgs[: k + 1]))
                labels[start:end] = input_ids[start:end]
        batch["input_ids"].append(input_ids[:MAX_LENGTH])
        batch["attention_mask"].append([1] * len(batch["input_ids"][-1]))
        batch["labels"].append(labels[:MAX_LENGTH])
    return batch

# Apply mapping (train_ds / eval_ds are assumed to be datasets.Dataset objects
# with a "messages" column, loaded elsewhere)
train_ds = train_ds.map(format_dataset, batched=True, remove_columns=["messages"])
eval_ds = eval_ds.map(format_dataset, batched=True, remove_columns=["messages"])

args = SFTConfig(
    output_dir="output",
    num_train_epochs=1,
    learning_rate=1e-5,
    weight_decay=0.1,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    optim="adamw_torch_fused",  # as in the original recipe
    bf16=True,  # Explicitly set
    tf32=True,  # requires an Ampere or newer GPU
    gradient_checkpointing=True,
    max_length=MAX_LENGTH,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    # Important: the dataset is already tokenized, so SFTTrainer should not re-process it
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=True,  # keeps only input_ids, attention_mask, labels
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    processing_class=tokenizer,  # this argument was named `tokenizer` in older TRL releases
    # Pads labels with -100 so the precomputed masks survive batching; a plain causal-LM
    # collator would rebuild labels from input_ids and discard them. Verify against your TRL version.
    data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True, label_pad_token_id=-100),
)
```
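Because the preprocessing produces variable-length `input_ids` and `labels`, every batch has to be padded, and the padding rule for labels matters. `transformers.DataCollatorForSeq2Seq` pads labels with its `label_pad_token_id` (default `-100`) independently of `input_ids`. The toy function below (hypothetical, pure Python, right padding only) shows why that is safer than masking every position whose id equals the pad id, which is what `DataCollatorForLanguageModeling` with `mlm=False` does: if `pad_token == eos_token`, that approach would also mask the real end-of-turn token.

```python
IGNORE_INDEX = -100


def pad_batch(features, pad_token_id):
    """Right-pad pre-tokenized examples: ids with pad_token_id, mask with 0, labels with -100."""
    width = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        gap = width - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * gap)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * gap)
        batch["labels"].append(f["labels"] + [IGNORE_INDEX] * gap)
    return batch


EOS = 2  # toy id standing in for the EOS token, reused here as the pad token
features = [
    {"input_ids": [5, 6, 7, EOS], "labels": [-100, -100, 7, EOS]},
    {"input_ids": [5, 8, EOS], "labels": [-100, 8, EOS]},
]
print(pad_batch(features, pad_token_id=EOS)["labels"])
# [[-100, -100, 7, 2], [-100, 8, 2, -100]]
```

In the second row the real EOS stays supervised while its padding copy is masked, so the model still learns to stop even though the pad and EOS ids coincide.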
Summary of Recommendations
* Fix Model Name: Confirm you are using `Qwen2.5-3B-Instruct` (or 7B).
* Implement Masking: Do not train on user tokens. If you can't implement the offset masking logic easily, consider a library like Unsloth, which handles Qwen fine-tuning and masking efficiently, or write a custom map function that sets `labels` to `-100` for non-assistant turns.
* Increase Context: 1024 is likely too short for legal documents. Try 2048.
* Keep Hyperparams: Your LR, WD, and scheduler choices are solid for full FT.
If you fix the model name and ensure user-token masking, this is a very standard and reasonable recipe for domain adaptation.