
How do I fuse LoRA adapters into Gemma 4 E4B after training?

Hugging Face Forums [Unofficial] May 31, 2026

I managed to do this with Qwen3.5 and Gemma-4-E2B, but with Gemma 4 E4B I get nothing but errors after many AI-assisted attempts: more than one million tokens with DeepSeek, many scripts, and I even retrained the model with both mlx-vlm and mlx-lm. My latest adapters were produced with this:

```python
cmd = [
    VENV, "-m", "mlx_vlm.lora",
    "--model-path", MODEL,
    "--dataset", DATA,
    "--batch-size", "1",
    "--iters", "300",
    "--learning-rate", "1e-5",
    "--grad-checkpoint",
    "--gradient-accumulation-steps", "4",
    "--steps-per-save", "50",
    "--output-path", OUTPUT,
    "--lora-rank", "8",
    "--lora-alpha", "16",
    "--train-on-completions",
    "--assistant-id", "4368",
]
```

And this is the fusion script:

```python
#!/usr/bin/env python3
"""Fuse LoRA adapters into the Gemma 4 E4B 8-bit model.

Run DIRECTLY on the Mini, not over SSH.
"""
import json
import shutil
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from mlx_vlm import load
from mlx_vlm.trainer.lora import LoRaLayer

BASE = Path("/Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit")
ADAPTER = Path("/Volumes/ssd./ssd_gemma4/adaptadores_v2")
SALIDA = Path("/Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit_ssd_fused")  # output dir

# adapter_config.json: overwrites the file saved by training so that load()
# rebuilds the LoRA layers with the rank/alpha used for training
with open(ADAPTER / "adapter_config.json", "w") as f:
    json.dump({"rank": 8, "alpha": 16.0, "dropout": 0.0}, f)

print("Loading model with adapters...")
model, processor = load(str(BASE), adapter_path=str(ADAPTER))

print("Fusing LoRA layers...")
to_update = {}
for name, module in model.named_modules():
    if not isinstance(module, LoRaLayer):
        continue
    orig = module.original_layer
    if isinstance(orig, nn.QuantizedLinear):
        # Unpack the 8-bit weight into a dense (out, in) matrix
        w = mx.dequantize(orig.weight, orig.scales, orig.biases,
                          orig.group_size, orig.bits)
    else:
        w = orig.weight
    # mlx-vlm's LoRaLayer is believed to keep its multiplier in `alpha`, not
    # `scale`; fall back to `scale` in case a version names it differently
    scaling = module.alpha if hasattr(module, "alpha") else module.scale
    lu = scaling * (module.A @ module.B)
    od, id_ = w.shape
    if lu.shape == (id_, od):  # A @ B is (in, out): transpose to match w
        wf = w + lu.T
    elif lu.shape == (od, id_):
        wf = w + lu
    else:
        print("  SKIP shape mismatch:", name, w.shape, lu.shape)
        continue
    # Keep the base dtype (the LoRA matrices may be float32)
    to_update[name] = wf.astype(w.dtype)
print("  %d layers fused" % len(to_update))
```
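
A quick way to sanity-check the merge math outside MLX is a small NumPy sketch. It assumes the LoRA layer computes `x @ W.T + alpha * (x @ A) @ B`, with `A` of shape `(in, rank)` and `B` of shape `(rank, out)`, which is how mlx-vlm's `LoRaLayer` is understood to work; under that assumption the fused weight is `W + alpha * (A @ B).T`. The helpers `lora_forward` and `fuse` are hypothetical. The last step shows why the multiplier matters: scaling the delta by `alpha / rank` (the Hugging Face PEFT convention) instead of `alpha` runs without errors but no longer reproduces the adapter's output.

```python
import numpy as np


def lora_forward(x, W, A, B, alpha):
    """Unfused LoRA (assumed form): base projection plus scaled low-rank update."""
    return x @ W.T + alpha * (x @ A) @ B


def fuse(W, A, B, alpha):
    """Fold the update into the weight; W is (out, in) and A @ B is (in, out)."""
    return W + alpha * (A @ B).T


rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 16, 12, 8, 16.0
W = rng.normal(size=(d_out, d_in))
A = rng.normal(size=(d_in, rank))
B = rng.normal(size=(rank, d_out)) * 0.01
x = rng.normal(size=(3, d_in))

# Correct merge: fused layer output equals the unfused LoRA output.
W_fused = fuse(W, A, B, alpha)
max_err = np.abs(lora_forward(x, W, A, B, alpha) - x @ W_fused.T).max()
print("fused shape:", W_fused.shape, "max abs error:", max_err)

# Same merge with the PEFT-style multiplier alpha / rank: runs fine, wrong result.
W_wrong = fuse(W, A, B, alpha / rank)
wrong_err = np.abs(lora_forward(x, W, A, B, alpha) - x @ W_wrong.T).max()
print("alpha/rank scaling, max abs error:", wrong_err)
```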

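```python
print("Loading original weights...")
all_w = {}
for sf in sorted(BASE.glob("model-*.safetensors")):
    all_w.update(mx.load(str(sf)))
print("  %d tensors in total" % len(all_w))

# Replace fused weights
for name, wf in to_update.items():
    mk = name + ".weight"
    if mk in all_w:
        all_w[mk] = wf
        # The fused weight is dense now, so drop its quantization tensors
        for s in [".scales", ".biases"]:
            all_w.pop(name + s, None)
    else:
        # Module name has no matching checkpoint key: this layer would be lost
        print("  MISSING KEY, fused weight not saved:", mk)
```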
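
The replacement loop only writes a fused weight when `<module name>.weight` exists in the on-disk shards, so any naming difference between the loaded model and the files means that layer never reaches the output. (Keys being renamed while the model is sanitized on load is one possible cause; that is an assumption here, not something verified for Gemma 4.) A plain-Python sketch of a check to run before saving; `report_unmatched` and the example key names are hypothetical:

```python
def report_unmatched(fused_names, checkpoint_keys):
    """Split fused module names into those whose '.weight' key exists in the
    checkpoint and those that would be dropped (hypothetical helper)."""
    keys = set(checkpoint_keys)
    matched = [n for n in fused_names if n + ".weight" in keys]
    missing = [n for n in fused_names if n + ".weight" not in keys]
    return matched, missing


# Hypothetical example: v_proj is stored under a different prefix on disk.
fused = [
    "language_model.model.layers.0.self_attn.q_proj",
    "language_model.model.layers.0.self_attn.v_proj",
]
on_disk = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "language_model.model.layers.0.self_attn.q_proj.scales",
    "model.language_model.layers.0.self_attn.v_proj.weight",
]
matched, missing = report_unmatched(fused, on_disk)
print("matched:", matched)
print("missing:", missing)
```

In the script it would be called as `report_unmatched(to_update, all_w)` once `all_w` is loaded; a non-empty `missing` list points at the adapters that are being lost.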

```python
# Check conv shapes against the originals and restore any that changed
# (all_w was read from these same files, so only tensors replaced above can differ)
print("Checking conv shapes...")
restored = 0
for sf in sorted(BASE.glob("model-*.safetensors")):
    orig = mx.load(str(sf))
    for k, v in orig.items():
        if len(v.shape) == 4 and k in all_w and all_w[k].shape != v.shape:
            print("  RESTORING %s: %s → %s" % (k, list(all_w[k].shape), list(v.shape)))
            all_w[k] = v
            restored += 1
print("  %d conv weights restored" % restored)

# Save
if SALIDA.exists():
    shutil.rmtree(SALIDA)
SALIDA.mkdir(parents=True)

print("Saving %d tensors..." % len(all_w))
mx.save_safetensors(str(SALIDA / "model.safetensors"), all_w)

# Copy configs
for f in BASE.glob("*"):
    if f.suffix in (".json", ".txt", ".md") or "tokenizer" in f.name:
        shutil.copy2(f, SALIDA)

# Remove index.json (the output is a single safetensors file)
idx = SALIDA / "model.safetensors.index.json"
if idx.exists():
    idx.unlink()

print("OK " + str(SALIDA))
```
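
The configs are copied unchanged from the 8-bit base, so `config.json` presumably still declares a `quantization` section, while each fused layer is saved as a dense weight without `.scales`/`.biases`. Whether the loader accepts that mix depends on how the installed mlx-vlm decides which layers to quantize, which is not verified here for Gemma 4. A plain-Python sketch (hypothetical helper and example data) that lists the fused layers in that state:

```python
import json


def fused_layers_missing_scales(config, saved_keys, fused_names):
    """List fused layers written dense while the config still declares
    quantization (hypothetical diagnostic helper)."""
    if "quantization" not in config:
        return []  # unquantized base: dense weights are what the loader expects
    keys = set(saved_keys)
    return sorted(n for n in fused_names if n + ".scales" not in keys)


# Hypothetical example data shaped like an 8-bit MLX checkpoint.
config = json.loads('{"quantization": {"group_size": 64, "bits": 8}}')
saved = [
    "layers.0.q_proj.weight",  # fused: dense, scales/biases popped
    "layers.0.k_proj.weight",  # untouched: still quantized
    "layers.0.k_proj.scales",
    "layers.0.k_proj.biases",
]
print(fused_layers_missing_scales(config, saved, ["layers.0.q_proj"]))
```

If the list is non-empty and loading fails, one untested option is to re-quantize each fused weight with `mx.quantize(wf, group_size=..., bits=...)`, which in affine mode returns the packed weight, scales and biases, and to store those under the original keys instead of popping them.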

```python
# Final check
print("Checking conv weights in output...")
d = mx.load(str(SALIDA / "model.safetensors"))

# Collect the original shapes once instead of reloading every shard per key
orig_shapes = {}
for sf in sorted(BASE.glob("model-*.safetensors")):
    orig_shapes.update({k: v.shape for k, v in mx.load(str(sf)).items()})

ok = True
for k, v in d.items():
    if "conv" in k.lower() and len(v.shape) == 4:
        print("  %s: %s" % (k, list(v.shape)))
        # Check against original
        if k in orig_shapes and v.shape != orig_shapes[k]:
            print("  ERROR: should be %s" % list(orig_shapes[k]))
            ok = False
if ok:
    print("ALL conv weights have correct shapes!")
```
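
The final check only looks at 4-D conv tensors. A broader plain-Python sketch (the `diff_shapes` helper and the example shapes are hypothetical) compares `{name: shape}` dicts for the whole checkpoint, which can be built as `{k: v.shape for k, v in ...}` from the base shards and from `d`. For a fused 8-bit layer you would expect its `.weight` to show up as changed (packed `uint32` with `in / 4` columns becoming dense with `in` columns) and its `.scales`/`.biases` as missing; anything else in the report is suspect.

```python
def diff_shapes(base, fused):
    """Compare two {tensor_name: shape} dicts and group the differences."""
    missing = sorted(set(base) - set(fused))
    added = sorted(set(fused) - set(base))
    changed = {
        k: (tuple(base[k]), tuple(fused[k]))
        for k in sorted(set(base) & set(fused))
        if tuple(base[k]) != tuple(fused[k])
    }
    return missing, added, changed


# Hypothetical shapes: 8-bit q_proj with in=2048, group_size=64, then fused.
base = {
    "q_proj.weight": (2048, 512),
    "q_proj.scales": (2048, 32),
    "q_proj.biases": (2048, 32),
    "conv.weight": (64, 3, 3, 1),
}
fused = {"q_proj.weight": (2048, 2048), "conv.weight": (64, 3, 3, 1)}
missing, added, changed = diff_shapes(base, fused)
print("missing:", missing)
print("added:", added)
print("changed:", changed)
```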
