Would this concept model work?
Looks like there might be a slight bug?
I reviewed the actual code path, not the README. I also sanity-checked a tiny forward/backward path locally.
The short verdict is:
This is a real model implementation, not a fake scaffold. It can plausibly train into a coherent prototype. I would not launch the 1B run unchanged. The main reasons are not “ternary is impossible” or “MDLM is wrong.” The core architecture is aligned with the literature: masked diffusion language models are viable, block diffusion is a real semi-autoregressive extension with KV caching, ternary-from-scratch has precedent, and BitNet a4.8-style hybrid activation handling is the right direction. The fragile zone is still low-bit attention/activation behavior, especially when stacked on diffusion. (arXiv)
Final judgment
If you asked me, “Would this codebase probably produce a coherent 1B masked-diffusion model if I spend the compute?”, my answer is:
Probably yes, after a few fixes.
If you asked me, “Would this exact codebase, as-is, cleanly validate the whole concept and be easy to trust at 1B/40B?”, my answer is:
No. It has a solid core plus several correctness and interpretation issues.
What is solid
These parts are good enough that I would keep them.
1. The core modeling choice is valid
The model is a bidirectional denoiser with absorbing-state masking and a per-sample noise level t. That is the right family for MDLM-style training. MDLM specifically showed that simple masked discrete diffusion can be much stronger than older diffusion-for-text setups and can support efficient samplers. (arXiv)
2. The ternary-weight implementation is conceptually sound
The code keeps latent full-precision weights and uses STE ternary quantization in the forward pass. That is the standard kind of construction you would expect from BitNet-style training. The overall idea of native ternary weights is supported by BitNet b1.58. (arXiv)
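To make the construction concrete, here is a minimal sketch of STE ternary quantization. It assumes absmean scaling in the style described for BitNet b1.58; the function name `ternary_ste` is hypothetical, and whether your repo uses the same scale or a per-channel variant is an assumption on my part.

```python
import torch


def ternary_ste(weight, eps=1e-5):
    """Absmean ternary quantization with a straight-through estimator (sketch)."""
    # One scale per tensor: the mean absolute value of the latent weights.
    scale = weight.abs().mean().clamp(min=eps)
    # Round to the nearest of {-1, 0, +1}, then rescale.
    w_q = torch.round(weight / scale).clamp(-1, 1) * scale
    # Forward pass sees w_q; backward pass treats quantization as identity.
    return weight + (w_q - weight).detach()


latent = torch.randn(4, 6, requires_grad=True)  # full-precision latent weights
w = ternary_ste(latent)
w.sum().backward()  # gradient flows straight through to the latent weights
```

The latent tensor is what the optimizer updates; the ternary values exist only inside the forward pass.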
3. The A8 → A4 schedule is the right instinct
This is one of the best choices in the repo. BitNet a4.8 is not “all 4-bit everywhere from the first step.” It is selective and hybrid. Your code is directionally aligned with that. (arXiv)
4. The block sampler has the right basic idea
Your block sampler uses committed context plus an ephemeral current block. That is a sensible prototype for block diffusion. The public block-diffusion work explicitly motivates arbitrary-length generation, KV caching, and parallel token sampling in exactly this general direction. (arXiv)
Must-fix before a 1B run
These are the items I would treat as hard blockers or near-blockers.
1. MaskDiffusionLoss can return NaN
This is the most important correctness bug I found.
The loss sets all non-supervised positions to ignore_index and then calls F.cross_entropy. If there are zero supervised positions in a batch, PyTorch returns NaN. I verified this locally with a tiny smoke test.
Why this matters:
- with very long sequences, it is rare that no positions are masked,
- but it is still a real edge case,
- and thinking-token exclusion makes it easier for “all masked positions are excluded from loss” to happen on small or special batches.
Fix:
- before calling cross_entropy, check: if not mask_flat.any(): return logits.new_zeros(()).
Without this fix, rare NaNs can poison long training runs.
2. The variable-length curriculum is mostly canceled by the dataloader
Your data-prep script creates variable-length chunks. But StreamingJsonlDataset re-tokenizes the stored "text" and appends everything into one token buffer, then emits fixed max_length chunks.
So end to end, the effective training stream is mostly fixed-length re-chunked windows, not the intended weighted length distribution.
Why this matters:
- your experiments become harder to interpret,
- your training is less like the intended curriculum than you think,
- if you believe shorter and longer contexts are both important, the current pipeline largely throws that away.
Fix:
- store tokenized chunks directly, or
- keep one JSONL line = one training example, do not flatten the entire corpus back into a global rolling token buffer.
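The effect is easy to reproduce in isolation. The sketch below uses plain Python lists of dummy token IDs; `rolling_buffer_chunks` and `one_example_per_line` are hypothetical stand-ins for the two loader behaviors, not the actual `StreamingJsonlDataset` code.

```python
from collections import Counter


def rolling_buffer_chunks(examples, max_length):
    """Concatenate everything into one buffer, then emit fixed max_length windows."""
    buffer = []
    for tokens in examples:
        buffer.extend(tokens)
        while len(buffer) >= max_length:
            yield buffer[:max_length]
            buffer = buffer[max_length:]
    if buffer:
        yield buffer  # trailing partial chunk


def one_example_per_line(examples, max_length):
    """Keep each stored example as its own training example (truncated if too long)."""
    for tokens in examples:
        yield tokens[:max_length]


# Hypothetical curriculum: 4 short, 2 medium, 1 long example.
examples = [[0] * 64] * 4 + [[0] * 256] * 2 + [[0] * 1024]
rolled = Counter(len(c) for c in rolling_buffer_chunks(examples, max_length=512))
kept = Counter(len(c) for c in one_example_per_line(examples, max_length=512))
```

With the rolling buffer, the short and medium examples disappear into full windows; only the one-example-per-line path keeps the intended length mix.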
3. attention_mask is built, then ignored
The collator returns attention_mask, but the training loop only uses batch["input_ids"]. The model forward path also has no attention-mask input.
Today this is partly hidden because the dataset mostly emits full-length chunks. But partial chunks still exist, and if you later restore true variable-length batching, this becomes a serious issue.
Why this matters:
- padded positions can enter the corruption process,
- padded positions can contribute to attention,
- pad token is set to eos_token if missing, so the model can learn from EOS-padding artifacts.
Fix:
- propagate attention_mask into masking and loss,
- exclude padded positions from apply_mask,
- exclude them from supervised loss,
- ideally add real attention masking if you want genuine variable-length batches.
4. BlockwiseDiffusionSampler.generate() is wrong for num_samples > 1
This is a real logic bug.
The block sampler accumulates all_generated and block_texts from block_ids[0], then uses that same shared buffer when returning results for all samples. So if num_samples > 1, the returned outputs are effectively copies of sample 0.
Fix:
- keep all_generated per sample, not once globally,
- keep block_texts per sample too.
If you only ever sample one output at a time, this does not hurt you. But it is still a bug.
5. generate_sample() can sample special tokens and silently turn them into the last vocabulary token
In the training monitor sampler, you sample from the full logit tensor, not just the normal vocabulary slice, and then at the end clamp IDs into [0, vocab_size - 1].
That means:
- if the model samples the mask token or think token,
- the code silently converts it to the last normal token ID.
This does not corrupt training directly. It corrupts your qualitative monitoring and makes samples less trustworthy.
Fix:
- slice logits to :vocab_size before sampling, like your other samplers already do.
Should-fix
These are not guaranteed failures, but they weaken the experiment.
1. Thinking tokens are under-supervised
In code terms, think positions are excluded from the direct supervised loss. They only receive gradient indirectly through answer quality.
That can work as an experimental latent-variable trick. But it is weak supervision.
My expectation:
- maybe helpful,
- maybe ignored,
- maybe unstable if you over-interpret it as “reasoning.”
For a first serious 1B run, I would either:
- disable thinking tokens, or
- keep only the simplest global-prefix version and remove per-block thinking.
2. Per-block thinking at inference does not match training
Training prepends one think prefix to the sequence. Block sampling can prepend think tokens before every block.
That is a train-test mismatch.
It may still “work” in the loose sense that the model produces something. But if the feature matters at all, this mismatch makes the result harder to trust.
3. The KV quantization scheme is simpler than the strongest public guidance
Your active cache path uses a simple per-head absmax quantizer for both keys and values.
KIVI’s main conclusion is that keys and values do not want the same treatment: keys work better with per-channel quantization, values with per-token quantization. So your cache may still work, but it is not using the best-supported asymmetry yet. (arXiv)
4. The full-sequence denoiser’s KV cache buys almost nothing
In the non-block sampler, you reset the KV cache every denoising step and re-run the full sequence. That is logically correct because the mask pattern changes every step, but it also means the cache is not giving you a real inference win there.
That is not a bug. It just means:
- KV cache matters mainly for your block sampler,
- not for the full denoiser.
5. The default run is not actually 40B tokens
The training config computes to about 30.1B tokens, not 40B.
That is not a correctness problem. It is a planning problem. If you want a 40B-token run, your step count needs to change.
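The arithmetic is simple enough to check before launch. This sketch uses placeholder values, not your actual config: `sequences_per_step` and `seq_len` below are assumptions chosen only to illustrate the calculation.

```python
import math


def tokens_seen(steps, sequences_per_step, seq_len):
    """Total tokens consumed after `steps` optimizer steps."""
    return steps * sequences_per_step * seq_len


def steps_for_budget(target_tokens, sequences_per_step, seq_len):
    """Smallest step count whose token total reaches the target."""
    return math.ceil(target_tokens / (sequences_per_step * seq_len))


# Placeholder values, NOT the repo's config.
sequences_per_step = 256  # micro-batch x grad accumulation x data-parallel ranks
seq_len = 2048
needed = steps_for_budget(40_000_000_000, sequences_per_step, seq_len)
```

Plugging in the real batch geometry tells you how many steps a 40B-token run actually needs.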
Fine for now
These are not where I would spend time first.
1. No causal mask in attention
Correct for diffusion.
2. Latent full-precision weights with STE
Standard for this kind of research implementation.
3. MoE code
Not the current concern because it is off by default.
4. RoPE offset handling in block mode
Directionally correct and useful for committed-context generation.
What I think will happen if you run it unchanged
Most likely:
- it does train,
- it gives coherent outputs,
- the ternary core is not the main reason it fails,
- the A8 → A4 schedule probably helps rather than hurts,
- but the final result is harder to interpret because the data pipeline and thinking-token behavior are not clean.
The most likely disappointments are:
- weaker-than-expected gains from low-bit activations,
- unclear value from thinking tokens,
- KV-cache quality below what you would hope from the best papers,
- and results that are noisier than they need to be because of the data path and edge-case loss behavior. Recent work on 4-bit attention explicitly says attention is the main obstacle because of heavy-tailed activations and precision-mismatch instability, which matches where I would expect your run to be most fragile. (arXiv)
What I think will happen if you fix the blockers
Then I think the code has a real chance to produce a meaningful 1B prototype.
Not “state of the art.” Not “obviously better than a same-budget AR baseline.” But a real prototype that demonstrates:
- masked diffusion training,
- ternary-weight viability,
- staged hybrid activation quantization,
- and blockwise semi-autoregressive generation.
That is a legitimate target. MDLM supports the masked-diffusion backbone, block diffusion supports the blockwise generation idea, BitNet b1.58 supports native ternary weights, and BitNet a4.8 supports the general hybrid A8/A4 direction plus low-bit KV as an inference concept. (arXiv)
My recommendation
Before spending serious compute on 1B, I would do exactly this:
1. Fix MaskDiffusionLoss for the zero-supervised-token case.
2. Fix multi-sample block generation.
3. Fix the qualitative sampler so it cannot turn special tokens into fake normal tokens.
4. Decide whether you want:
- real variable-length training, then preserve it end to end and use masks properly, or
- fixed-length training, then simplify the pipeline and stop pretending otherwise.
5. Disable thinking tokens for the first real 1B run.
6. Treat the current KV cache as a prototype cache, not a final serving recipe.
Bottom line
Strict code-review answer:
The codebase is structurally real and probably trainable. It is not clean enough yet for an unquestioned 1B run. The main risks are correctness and experiment-interpretation risks, not “the whole concept is impossible.”
Here is the ranked patch plan I would use for your repo.
The ordering is based on one question only: what most reduces the chance of wasting a 1B / 40B-token run. The literature says your backbone choice is plausible: MDLM-style masked diffusion is a real language-modeling family, block diffusion is a real semi-autoregressive extension with KV reuse, ternary-from-scratch has precedent, and BitNet a4.8 supports the general idea of staged hybrid low-bit activations. The main fragility zone remains low-bit attention/activation behavior, not the existence of the overall concept. (arXiv)
Tier 0: patch before any serious 1B run
1. Make MaskDiffusionLoss safe when there are zero supervised positions
Files: bitdiffusion/diffusion.py
Why this is first
This is the only issue I found that can directly produce a silent training poison. I locally verified that your current loss returns NaN when every position is ignored.
In your code, the loss:
- flattens logits and targets,
- masks out non-supervised positions,
- writes ignore_index into all other targets,
- then calls F.cross_entropy(...).
If all positions are ignored, PyTorch returns NaN.
Patch
Add a guard right before cross_entropy:
if not mask_flat.any():
    return logits.new_zeros(())
Why it matters for your concept
Diffusion training already has noisier supervision than plain next-token prediction because the supervised set changes each batch. Block diffusion adds more schedule sensitivity, and low-bit training leaves less numerical slack. A rare NaN is much more dangerous in this regime than in a boring baseline. The block-diffusion paper explicitly highlights variance control and noise scheduling as first-class engineering concerns, and Attn-QAT shows that low-bit attention is already the main stability bottleneck. (arXiv)
Minimal test
- unit test with is_masked = torch.zeros(...)
- assert loss is finite and exactly zero
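A minimal sketch of that test, assuming a hypothetical `masked_diffusion_loss` function that stands in for `MaskDiffusionLoss` (the real signature may differ). It departs from the two-line patch above in one deliberate way: the guard returns `logits_flat.sum() * 0.0` rather than `logits.new_zeros(())`. A freshly allocated zero has no grad_fn, so calling `backward()` on it alone raises a RuntimeError. If your loop skips the backward pass on empty batches, the simpler form is fine.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # PyTorch's default ignore_index for cross_entropy


def masked_diffusion_loss(logits, targets, is_masked):
    """Hypothetical stand-in: cross-entropy over supervised (masked) positions only."""
    vocab = logits.size(-1)
    logits_flat = logits.reshape(-1, vocab)
    mask_flat = is_masked.reshape(-1)
    # Non-supervised positions get ignore_index so cross_entropy skips them.
    targets_flat = targets.reshape(-1).masked_fill(~mask_flat, IGNORE_INDEX)
    if not mask_flat.any():
        # Zero that stays attached to the graph, so backward() still works.
        return logits_flat.sum() * 0.0
    return F.cross_entropy(logits_flat, targets_flat, ignore_index=IGNORE_INDEX)


logits = torch.randn(2, 8, 16, requires_grad=True)
targets = torch.randint(0, 16, (2, 8))
no_supervision = torch.zeros(2, 8, dtype=torch.bool)

# Unguarded call: the mean over zero supervised positions is 0/0.
unguarded = F.cross_entropy(
    logits.reshape(-1, 16),
    torch.full((16,), IGNORE_INDEX, dtype=torch.long),
    ignore_index=IGNORE_INDEX,
)
guarded = masked_diffusion_loss(logits, targets, no_supervision)
guarded.backward()  # succeeds and produces all-zero gradients
```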
2. Fix multi-sample block generation
Files: bitdiffusion/sample.py
What is wrong
In BlockwiseDiffusionSampler.generate(), all_generated and block_texts are single shared Python lists, but the method returns one result per sample. The code collects tokens from block_ids[0] only, then reuses that same accumulated sequence for every sample.
So num_samples > 1 is currently wrong.
Patch
Change:
all_generated: list[int] = []
block_texts: list[str] = []
to per-sample structures, for example:
all_generated = [[] for _ in range(num_samples)]
block_texts = [[] for _ in range(num_samples)]
Then collect and decode per sample.
Why it matters
This does not break single-sample runs. But it makes batched sampling misleading, which is bad for evaluating diversity and sampler correctness. Since MDLM and block diffusion are often judged partly on generation behavior, broken multi-sample output makes the model look more deterministic or cleaner than it really is. (arXiv)
Minimal test
- run num_samples=2 with a fixed seed and temperature > 0
- assert outputs are independently tracked
- assert internal block text lists differ when token traces differ
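The per-sample bookkeeping can be sketched without the model at all. Here `collect_blocks` and `toy_decode` are hypothetical helpers, and nested lists stand in for the `[num_samples, block_len]` tensors the real sampler would produce.

```python
def collect_blocks(blocks, num_samples, decode):
    """Accumulate tokens and decoded text separately for every sample."""
    all_generated = [[] for _ in range(num_samples)]
    block_texts = [[] for _ in range(num_samples)]
    for block_ids in blocks:
        for i in range(num_samples):
            tokens = list(block_ids[i])  # index by sample i, not block_ids[0]
            all_generated[i].extend(tokens)
            block_texts[i].append(decode(tokens))
    return all_generated, block_texts


def toy_decode(tokens):
    """Placeholder for tokenizer.decode."""
    return " ".join(str(t) for t in tokens)


# Two blocks, two samples per block, with distinct token traces per sample.
blocks = [
    [[1, 2], [7, 8]],
    [[3, 4], [9, 10]],
]
generated, texts = collect_blocks(blocks, num_samples=2, decode=toy_decode)
```

A regression test then only needs to assert that `generated[0]` and `generated[1]` differ whenever the per-sample token traces differ.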
3. Fix generate_sample() so it cannot sample special tokens and silently map them to normal tokens
Files: bitdiffusion/train.py
What is wrong
Your qualitative monitor sampler samples from the full output vocabulary, then later clamps token IDs to vocab_size - 1. If the model samples the mask token or think token, that special token gets silently turned into the last normal vocabulary token.
So your training samples can look cleaner or stranger for the wrong reason.
Patch
Change:
probs = torch.softmax(logits / temperature, dim=-1)
to:
probs = torch.softmax(logits[:, :, :model.config.vocab_size] / temperature, dim=-1)
Do not rely on post-hoc clamping.
Why it matters
This does not directly affect training, but it absolutely affects whether you trust your monitoring. In diffusion models, qualitative inspection is important because loss curves alone do not tell the whole story about generation quality. (arXiv)
Minimal test
- force logits to favor mask token
- assert sampler never returns an out-of-range or silently remapped normal token
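A sketch of that test, assuming special tokens occupy the IDs at and above `vocab_size` (which is what slicing to `:vocab_size` implies). `sample_normal_tokens` is a hypothetical helper, not the existing `generate_sample()`.

```python
import torch


def sample_normal_tokens(logits, vocab_size, temperature=1.0):
    """logits: [batch, seq, vocab_size + num_special]; returns IDs in [0, vocab_size)."""
    # Drop special-token logits before the softmax so they can never be drawn.
    normal_logits = logits[:, :, :vocab_size]
    probs = torch.softmax(normal_logits / temperature, dim=-1)
    flat = probs.reshape(-1, vocab_size)
    ids = torch.multinomial(flat, num_samples=1)
    return ids.reshape(logits.shape[:2])


vocab_size, num_special = 10, 2
logits = torch.zeros(2, 5, vocab_size + num_special)
logits[:, :, vocab_size] = 50.0  # force the (assumed) mask-token logit to dominate
ids = sample_normal_tokens(logits, vocab_size)
```

Because the softmax is taken after slicing, the remaining probabilities renormalize over normal tokens and no clamp is needed afterwards.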
4. Decide whether you want true variable-length training or fixed-length training, then make the code match
Files: prepare_hf_jsonl.py, bitdiffusion/data.py, bitdiffusion/train.py
What is wrong
Your prep script creates a variable-length curriculum. Then the dataset loader re-tokenizes each "text" field, concatenates everything into a rolling token buffer, and emits fixed max_length chunks. So the end-to-end training stream is mostly fixed-length again.
Patch choice A: keep variable-length training
- store tokenized examples directly
- keep one JSONL example = one training example
- use attention_mask throughout masking and loss
- do not re-flatten the corpus into a global rolling token buffer
Patch choice B: admit fixed-length training
- simplify prep
- stop generating variable-length chunks upstream
- keep fixed-length windows deliberately
My recommendation
For your first 1B run, choose B unless variable-length behavior is central to your research question. Fixed-length training is simpler and easier to debug.
Why it matters
Block diffusion papers emphasize variance and schedule quality. If your intended curriculum is being erased by the loader, you do not really know what you trained. Clean experimental semantics matter more here than fancy preprocessing. (arXiv)
Minimal test
- inspect a batch length histogram after collation
- confirm it matches what you think the loader is doing
Tier 1: fix before spending the full 40B tokens
5. Propagate attention_mask into corruption and loss, or remove padding entirely
Files: bitdiffusion/data.py, bitdiffusion/train.py, bitdiffusion/diffusion.py, bitdiffusion/model.py
What is wrong
The collator builds attention_mask. The training loop ignores it. The model forward path also ignores it.
Right now this is partly masked by your fixed-length behavior. But the moment you preserve variable lengths, padded positions become real positions for:
- masking,
- attention,
- loss bookkeeping.
And because pad defaults to EOS if the tokenizer lacks a pad token, the model can learn EOS-padding artifacts.
Patch
At minimum:
- exclude padded positions from apply_mask
- exclude padded positions from MaskDiffusionLoss
If you later restore true variable-length batching:
- also pass an attention mask into attention
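A minimal sketch of the first two items, assuming `attention_mask` is 1 for real tokens and 0 for padding. The signatures of `apply_mask` and `pad_aware_loss` here are hypothetical, and the loss is a plain unweighted mean over supervised positions rather than whatever weighting your `MaskDiffusionLoss` applies.

```python
import torch
import torch.nn.functional as F


def apply_mask(input_ids, attention_mask, t, mask_token_id):
    """Absorbing-state corruption that never touches padded positions."""
    # t: [batch] per-sample noise level; each real token is masked with probability t.
    rand = torch.rand(input_ids.shape, device=input_ids.device)
    is_masked = (rand < t[:, None]) & attention_mask.bool()
    corrupted = torch.where(
        is_masked, torch.full_like(input_ids, mask_token_id), input_ids
    )
    return corrupted, is_masked


def pad_aware_loss(logits, targets, is_masked):
    """Mean cross-entropy over masked, non-padded positions only."""
    # cross_entropy expects [batch, classes, seq] for sequence targets.
    per_token = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")
    weights = is_masked.float()
    # clamp keeps the result finite when nothing is supervised.
    return (per_token * weights).sum() / weights.sum().clamp(min=1.0)


input_ids = torch.tensor([[5, 6, 7, 0, 0], [1, 2, 3, 4, 5]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
t = torch.tensor([1.0, 1.0])  # mask everything that is allowed to be masked
corrupted, is_masked = apply_mask(input_ids, attention_mask, t, mask_token_id=99)
loss = pad_aware_loss(torch.randn(2, 5, 100), input_ids, is_masked)
```

Since `is_masked` already excludes padding, the loss needs no separate pad handling as long as it only supervises `is_masked` positions.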
Why it matters
This is less urgent than the NaN fix because your current loader mostly emits full chunks. But once you want honest variable-length behavior, this becomes a correctness issue, not a cleanup. (arXiv)
6. Disable thinking tokens for the first serious 1B baseline
Files: bitdiffusion/diffusion.py, bitdiffusion/train.py, bitdiffusion/sample.py
Why
This is the weakest-supervised subsystem in the code.
The code explicitly excludes thinking positions from direct supervised loss and expects them to become useful only through downstream answer gradients. That is possible in principle, but it is a weak signal. Also, training prepends one think prefix to the whole sequence, while the block sampler can prepend think tokens before every block. That is a train-test mismatch.
Patch
For the baseline 1B run:
- set N_think = 0
- set think_prob = 0
- keep the code, but remove it from the main experiment
Then add it back only after the baseline works.
Why it matters
Your core concept does not need thinking tokens to be valid. MDLM, block diffusion, ternary weights, and hybrid A8/A4 already make a complete research story. Thinking tokens add ambiguity without adding much confidence. (arXiv)
7. Keep the current KV cache labeled as a prototype, and do not overfit conclusions to it
Files: bitdiffusion/quantization.py, bitdiffusion/sample.py
What is happening
Your active cache path uses a simple per-head absmax scheme for both keys and values. That is fine for a prototype, but it is simpler than the best-supported KV-cache quantization approaches.
KIVI’s main result is that keys and values want different treatment: keys per-channel, values per-token. Your current path does not do that. (arXiv)
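To show what the axis choice means in practice, here is a sketch that compares a shared per-head scale with per-channel scales for keys and per-token scales for values. It is a simplification: it uses symmetric absmax quantize-dequantize, whereas KIVI itself uses asymmetric low-bit quantization with grouping. All function names are hypothetical, and the tensor layout `[batch, heads, tokens, head_dim]` is an assumption about your cache.

```python
import torch


def absmax_fake_quant(x, dim, bits=8):
    """Symmetric absmax quantize-dequantize with one scale per slice along `dim`."""
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / qmax
    return torch.round(x / scale).clamp(-qmax, qmax) * scale


def quantize_per_head(x, bits=8):
    # One scale for the whole [tokens, head_dim] slab of each head.
    return absmax_fake_quant(x, dim=(-2, -1), bits=bits)


def quantize_keys_per_channel(k, bits=8):
    # Reduce over tokens: each channel gets its own scale.
    return absmax_fake_quant(k, dim=-2, bits=bits)


def quantize_values_per_token(v, bits=8):
    # Reduce over channels: each token gets its own scale.
    return absmax_fake_quant(v, dim=-1, bits=bits)


torch.manual_seed(0)
k = torch.randn(1, 2, 16, 8)
k[..., 3] *= 50.0  # synthetic outlier channel, to illustrate the key-side problem
v_q = quantize_values_per_token(torch.randn(1, 2, 16, 8), bits=4)
err_head = (quantize_per_head(k, bits=4) - k).abs().mean()
err_channel = (quantize_keys_per_channel(k, bits=4) - k).abs().mean()
```

With a single outlier channel, the shared per-head scale rounds most of the other channels to zero, while per-channel scales keep them resolvable.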
Patch
Do one of these:
- leave the current cache as-is, but call it a prototype cache and benchmark it honestly
- or implement asymmetric K/V quantization closer to KIVI
My recommendation
For the first 1B run, keep it simple and prototype-level. Do not burn time rewriting the cache before the base model is proven.
Why it matters
KV cache is mostly an inference feature in your code, not a training feature. So this is not a blocker for pretraining. It is a blocker for making strong claims about deployment efficiency or quality retention. KIVI shows the asymmetry matters. (arXiv)
8. Add one explicit ablation checkpoint before the A8 → A4 switch
Files: bitdiffusion/train.py
Patch
Save:
- one checkpoint right before the activation-mode switch
- one checkpoint shortly after entering A4 mode
Also log:
- masked-token accuracy
- answer-only loss
- fraction of masked positions per batch
- gradient norm
- activation mode
Why it matters
Current low-bit attention work says 4-bit attention is the main obstacle because of heavy-tailed activations and precision mismatch. If your run degrades, you want to know whether the break started:
- before A4,
- exactly at A4,
- or long after. (arXiv)
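The checkpoint and logging logic above can be sketched as a few small helpers. The names, the `switch_step` value, and the `after_offset` value are all placeholders; your schedule may define the switch differently.

```python
def activation_mode(step, switch_step):
    """A8 before the switch, A4 from switch_step onward (placeholder schedule)."""
    return "A8" if step < switch_step else "A4"


def should_save_ablation_checkpoint(step, switch_step, after_offset):
    """True on the last A8 step and again `after_offset` steps into A4."""
    return step in (switch_step - 1, switch_step + after_offset)


def log_record(step, switch_step, masked_acc, answer_loss, masked_frac, grad_norm):
    """Bundle the metrics listed above so every log line carries the activation mode."""
    return {
        "step": step,
        "activation_mode": activation_mode(step, switch_step),
        "masked_token_accuracy": masked_acc,
        "answer_only_loss": answer_loss,
        "masked_fraction": masked_frac,
        "grad_norm": grad_norm,
    }


switch_step, after_offset = 1000, 50  # placeholder values
saved = [
    s for s in range(2000)
    if should_save_ablation_checkpoint(s, switch_step, after_offset)
]
```

Tagging each record with the activation mode makes it trivial to split curves at the switch afterwards.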
Tier 2: worth fixing, but not before the first scaled run
9. Separate the “full denoiser” and “block sampler” evaluation stories
Files: bitdiffusion/sample.py
Why
Your full denoiser resets the KV cache every denoising step, which is logically correct because the full masked pattern changes every step. That means KV cache does not buy much there. The real cache benefit is in the block sampler.
Patch
Report them separately:
- full diffusion sampling quality
- blockwise generation quality and speed
- KV cache effect only inside the blockwise path
Why it matters
It makes your conclusions cleaner and more aligned with what block diffusion is actually buying. (arXiv)
10. Add smoke tests for the exact failure cases above
Files: tests/
Add these tests
- MaskDiffusionLoss zero-supervision returns finite zero
- generate_sample() never samples special IDs into normal tokens
- BlockwiseDiffusionSampler.generate(num_samples>1) returns independent per-sample outputs
- data loader preserves intended length behavior
- one tiny forward/backward pass on CPU works
Why it matters
Your code is already close enough to useful that small regressions matter. At 1B scale, simple tests are much cheaper than one wasted launch.
Tier 3: optional improvements after the baseline works
11. If you want better KV behavior, move toward asymmetric quantization
This is where I would spend time after the base 1B model works. KIVI gives a strong hint that asymmetry between keys and values is not cosmetic. (arXiv)
12. If you want stronger A4 confidence, add more attention-specific diagnostics
Attn-QAT makes it very clear that the hard part is not generic quantization. It is attention numerics. That suggests logging:
- attention score range
- softmax entropy
- per-head activation max
- fraction of saturated quantized values during A4 mode (arXiv)
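These diagnostics are cheap to compute. The sketch below assumes you can get at the pre-softmax attention scores and the activations entering the quantizer. The function names are hypothetical, and "saturated" is defined here as landing on the outermost level of a symmetric absmax grid, which may not match how your quantizer clips.

```python
import torch


def attention_diagnostics(scores):
    """scores: pre-softmax attention logits [batch, heads, q_len, k_len]."""
    probs = torch.softmax(scores, dim=-1)
    # Entropy per query row; low values mean very peaked attention.
    entropy = -(probs * torch.log(probs.clamp(min=1e-12))).sum(dim=-1)
    return {
        "score_min": scores.min().item(),
        "score_max": scores.max().item(),
        "softmax_entropy": entropy.mean().item(),
    }


def per_head_abs_max(x):
    """x: [batch, heads, tokens, head_dim] -> one absolute max per head."""
    return x.abs().amax(dim=(0, 2, 3))


def saturated_fraction(x, bits=4):
    """Fraction of values that round to the outermost quantization level."""
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-8) / qmax
    q = torch.round(x / scale).clamp(-qmax, qmax)
    return (q.abs() == qmax).float().mean().item()


stats = attention_diagnostics(torch.zeros(1, 2, 3, 4))
frac = saturated_fraction(torch.tensor([-1.0, 0.0, 0.1, 1.0]))
```

Logging these per layer around the A4 switch should make it clearer whether a regression starts in attention numerics or elsewhere.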
13. If you want thinking tokens back, add a real training signal
Do this only after the plain model works. Right now they are more of a research hypothesis than a dependable subsystem.
The patch order I would actually execute
This week
- Fix MaskDiffusionLoss NaN case.
- Fix block sampler multi-sample bug.
- Fix generate_sample() vocabulary slicing.
- Decide fixed-length vs variable-length training and simplify accordingly.
Before the 1B launch
- Disable thinking tokens for baseline.
- Add mask-aware loss/corruption if you keep any variable-length batching.
- Add checkpoints around the A4 transition.
- Add the small tests.
After the baseline run
- Improve KV asymmetry.
- Add attention-specific A4 diagnostics.
- Reintroduce thinking tokens only as an ablation.
How I think it works after this patch plan
If you apply Tier 0 and Tier 1, I think the model has a real chance to do what you want in the limited sense that matters first:
- train a 1B masked-diffusion model,
- keep the ternary-weight core,
- switch into hybrid low-bit activation late,
- produce coherent blockwise-generated text,
- and give you a trustworthy baseline for later KV and thinking-token experiments.
That expectation is supported by the literature around MDLM, block diffusion, BitNet b1.58, and BitNet a4.8. (arXiv)
If you do not apply the top patches, I think the likely failure mode is not dramatic collapse. It is worse: an expensive run that “sort of works,” but leaves you unsure whether the weak points came from diffusion, low-bit attention, your data path, or the under-supervised thinking-token mechanism. Recent 4-bit attention results are exactly why that distinction matters. (arXiv)