{
  "$type": "site.standard.document",
  "bskyPostRef": {
    "cid": "bafyreibic5onyrx6vpnla27cuvbxjw43gsxrjec2x7zbxhtglcwv4n6jt4",
    "uri": "at://did:plc:pgryn3ephfd2xgft23qokfzt/app.bsky.feed.post/3mkmggxo4gn42"
  },
  "path": "/t/custom-batches-in-sentence-transformers-for-multiplenegativesrankingloss/175587#post_2",
  "publishedAt": "2026-04-29T05:08:08.000Z",
  "site": "https://discuss.huggingface.co",
  "tags": [
    "@tomaarsen",
    "datasets.Dataset",
    "SentenceTransformerTrainer",
    "SentenceTransformerTrainingArguments(batch_sampler=...)",
    "MultipleNegativesRankingLoss",
    "CosineSimilarityLoss",
    "SentenceTransformerTrainingArguments.batch_sampler",
    "samplers reference",
    "DefaultBatchSampler",
    "training overview",
    "CachedMultipleNegativesRankingLoss"
  ],
  "textContent": "Hmm… Apparently, it’s possible using the custom batch sampler… but I’m not entirely sure if that’s true.\njust in case, @tomaarsen\n\n* * *\n\n# Custom batches in `sentence-transformers` for `MultipleNegativesRankingLoss`\n\n## Short answer\n\nUse a normal datasets.Dataset, but do **not** rely on the default batching behavior of SentenceTransformerTrainer.\n\nInstead:\n\n  1. Flatten your pre-built batches into a single `datasets.Dataset`.\n  2. Keep the rows ordered so that rows `0..255` are your first curated batch, rows `256..511` are your second curated batch, and so on.\n  3. Pass a custom batch sampler through SentenceTransformerTrainingArguments(batch_sampler=...).\n  4. Make the custom sampler yield exactly the row-index groups you want.\n\n\n\nThe key point is:\n\n\n    train_dataset = storage format\n    batch_sampler = batching policy\n    loss = uses the resulting batch as the contrastive pool\n\n\nSo your `train_dataset=???` should be a flattened `datasets.Dataset`, and the batch-preserving logic should live in `args.batch_sampler`.\n\n* * *\n\n## Why this is the right abstraction\n\nYour understanding of MultipleNegativesRankingLoss is correct.\n\n`MultipleNegativesRankingLoss` is an **in-batch contrastive loss**. For each anchor in a batch, the matching positive should be closer than the other candidate positives or documents in the same batch.\n\nFor example, suppose one minibatch contains:\n\n\n    [\n        (\"Blue Street, 1, New York\",  \"Blue Street 1 - New York\"),\n        (\"Blue Street, 11, New York\", \"Blue Street 11 - New York\"),\n    ]\n\n\nThen, for the first row, the model is trained to make:\n\n\n    \"Blue Street, 1, New York\"\n\n\ncloser to:\n\n\n    \"Blue Street 1 - New York\"\n\n\nthan to:\n\n\n    \"Blue Street 11 - New York\"\n\n\nThat is exactly the behavior you want.\n\nWith address matching, the hard part is not only learning that two variants of the same address are close. 
The harder part is learning that extremely similar-looking addresses may still be different real-world entities:\n\n\n    Blue Street 1, New York       ==  Blue St. 1, NYC\n    Blue Street 1, New York       !=  Blue Street 11, New York\n    Blue Street 1, New York       !=  Blue Street 1, Newark\n    Blue Street 1 Apt 2           !=  Blue Street 1 Apt 20\n\n\nSo your curated batch structure is not incidental. It is part of the supervision.\n\n* * *\n\n## Why `CosineSimilarityLoss` is usually weaker here\n\nCosineSimilarityLoss is pairwise: it sees one pair at a time and pushes that pair's cosine similarity toward a target score.\n\nThat works well when your labels are naturally pairwise:\n\n\n    pair A similarity = 0.95\n    pair B similarity = 0.10\n    pair C similarity = 0.60\n\n\nBut your real task is closer to ranking:\n\n\n    Given:\n        \"Blue Street, 1, New York\"\n\n    Rank this highest:\n        \"Blue Street 1 - New York\"\n\n    Rank these lower:\n        \"Blue Street 11 - New York\"\n        \"Blue Street 1 - Newark\"\n        \"Blue Avenue 1 - New York\"\n        \"Blue Street 1 Apt 2 - New York\"\n\n\nThat is why MultipleNegativesRankingLoss is a better fit. It turns each minibatch into a local retrieval problem.\n\nIn this sense, your task is not just sentence similarity. It is closer to:\n\n\n    postal-address entity resolution\n    +\n    dense retrieval\n    +\n    hard-negative metric learning\n\n\n* * *\n\n## Why not pass a `DataLoader`?\n\nThe newer SentenceTransformerTrainer API expects a dataset, not a user-supplied PyTorch `DataLoader`.\n\nThat does not mean you cannot control batches.\n\nThe control point is SentenceTransformerTrainingArguments.batch_sampler, documented in the samplers reference. 
The sampler docs explain that a custom batch sampler can be supplied by subclassing DefaultBatchSampler or by passing a function that returns a `DefaultBatchSampler` instance.\n\nSo the correct structure is:\n\n\n    trainer = SentenceTransformerTrainer(\n        model=model,\n        args=args,                 # contains batch_sampler\n        train_dataset=train_dataset,\n        loss=loss_fn,\n    )\n\n\nnot:\n\n\n    trainer = SentenceTransformerTrainer(\n        model=model,\n        train_dataloader=my_dataloader,  # not the intended API\n    )\n\n\n* * *\n\n## Step 1: flatten your pre-built batches\n\nAssume your current data looks conceptually like this:\n\n\n    batches = [\n        [\n            (batch1_anchor1, batch1_positive1),\n            (batch1_anchor2, batch1_positive2),\n            ...\n        ],\n        [\n            (batch2_anchor1, batch2_positive1),\n            (batch2_anchor2, batch2_positive2),\n            ...\n        ],\n    ]\n\n\nFlatten it while preserving order:\n\n\n    from datasets import Dataset\n\n    BATCH_SIZE = 256\n\n    flat_anchors: list[str] = []\n    flat_positives: list[str] = []\n\n    for batch_idx, batch in enumerate(batches):\n        if len(batch) != BATCH_SIZE:\n            raise ValueError(\n                f\"Batch {batch_idx} has {len(batch)} pairs, expected {BATCH_SIZE}.\"\n            )\n\n        for anchor, positive in batch:\n            flat_anchors.append(anchor)\n            flat_positives.append(positive)\n\n    train_dataset = Dataset.from_dict(\n        {\n            \"anchor\": flat_anchors,\n            \"positive\": flat_positives,\n        }\n    )\n\n    # Keep column order explicit.\n    train_dataset = train_dataset.select_columns([\"anchor\", \"positive\"])\n\n\nNow the dataset rows have this structure:\n\n\n    rows 0..255      = curated batch 0\n    rows 256..511    = curated batch 1\n    rows 512..767    = curated batch 2\n    ...\n\n\nThe dataset itself is flat, but your 
precomputed batch structure is preserved by row position.\n\n* * *\n\n## Important: do not pass metadata columns directly to the trainer\n\nSentence Transformers training datasets are column-order-sensitive. In the training overview, non-label columns are treated as model inputs.\n\nSo this is safe:\n\n\n    train_dataset = train_dataset.select_columns([\"anchor\", \"positive\"])\n\n\nThis is risky if passed directly to the trainer:\n\n\n    Dataset.from_dict(\n        {\n            \"anchor\": anchors,\n            \"positive\": positives,\n            \"batch_id\": batch_ids,\n            \"canonical_address_id\": canonical_ids,\n        }\n    )\n\n\nbecause `batch_id` and `canonical_address_id` are metadata, not text inputs.\n\nKeep metadata during preprocessing and validation, but remove it before training:\n\n\n    train_dataset = full_dataset.select_columns([\"anchor\", \"positive\"])\n\n\n* * *\n\n## Step 2: define a custom batch sampler\n\nThis sampler yields contiguous blocks of indices:\n\n\n    [0, 1, 2, ..., 255]\n    [256, 257, 258, ..., 511]\n    [512, 513, 514, ..., 767]\n    ...\n\n\nIt can shuffle the **order of whole batches** between epochs, but it never mixes examples across your curated batches.\n\n\n    from collections.abc import Iterator\n\n    import torch\n    from datasets import Dataset\n    from sentence_transformers.sampler import DefaultBatchSampler\n\n\n    class ExactPreBatchedSampler(DefaultBatchSampler):\n        \"\"\"\n        Preserves precomputed contiguous batches.\n\n        Assumption:\n            Rows 0..255      are curated batch 0\n            Rows 256..511    are curated batch 1\n            Rows 512..767    are curated batch 2\n            ...\n\n        The sampler may shuffle the order of whole batches, but it never mixes\n        rows from different precomputed batches.\n        \"\"\"\n\n        def __init__(\n            self,\n            dataset: Dataset,\n            batch_size: int,\n            drop_last: 
bool,\n            valid_label_columns: list[str] | None = None,\n            generator: torch.Generator | None = None,\n            seed: int = 0,\n            shuffle_batches: bool = True,\n        ) -> None:\n            super().__init__(\n                dataset=dataset,\n                batch_size=batch_size,\n                drop_last=drop_last,\n                valid_label_columns=valid_label_columns,\n                generator=generator,\n                seed=seed,\n            )\n            self.dataset = dataset\n            self.shuffle_batches = shuffle_batches\n\n            if self.batch_size <= 0:\n                raise ValueError(f\"batch_size must be positive, got {self.batch_size}.\")\n\n            if len(self.dataset) < self.batch_size:\n                raise ValueError(\n                    f\"Dataset has {len(self.dataset)} rows, \"\n                    f\"but batch_size={self.batch_size}.\"\n                )\n\n        def __iter__(self) -> Iterator[list[int]]:\n            # DefaultBatchSampler provides epoch handling via SetEpochMixin.\n            if self.generator is not None and self.seed is not None:\n                self.generator.manual_seed(self.seed + self.epoch)\n\n            n_full_batches = len(self.dataset) // self.batch_size\n            remainder_start = n_full_batches * self.batch_size\n\n            batch_ids = torch.arange(n_full_batches)\n\n            if self.shuffle_batches:\n                batch_ids = batch_ids[\n                    torch.randperm(n_full_batches, generator=self.generator)\n                ]\n\n            for batch_id in batch_ids.tolist():\n                start = batch_id * self.batch_size\n                end = start + self.batch_size\n                yield list(range(start, end))\n\n            if not self.drop_last and remainder_start < len(self.dataset):\n                yield list(range(remainder_start, len(self.dataset)))\n\n        def __len__(self) -> int:\n            n_full_batches = 
len(self.dataset) // self.batch_size\n            has_remainder = len(self.dataset) % self.batch_size != 0\n            return n_full_batches + int(has_remainder and not self.drop_last)\n\n\n* * *\n\n## Step 3: pass the sampler through `SentenceTransformerTrainingArguments`\n\nUse a small factory function. This is convenient because the trainer constructs the sampler internally and supplies arguments such as `dataset`, `batch_size`, `drop_last`, `generator`, and `seed`.\n\n\n    def exact_prebatched_sampler_factory(\n        dataset: Dataset,\n        batch_size: int,\n        drop_last: bool,\n        valid_label_columns: list[str] | None = None,\n        generator: torch.Generator | None = None,\n        seed: int = 0,\n    ):\n        if batch_size != BATCH_SIZE:\n            raise ValueError(\n                f\"Expected batch_size={BATCH_SIZE}, got {batch_size}. \"\n                \"Use per_device_train_batch_size=256.\"\n            )\n\n        return ExactPreBatchedSampler(\n            dataset=dataset,\n            batch_size=batch_size,\n            drop_last=drop_last,\n            valid_label_columns=valid_label_columns,\n            generator=generator,\n            seed=seed,\n            shuffle_batches=True,\n        )\n\n\nThen configure the trainer:\n\n\n    from sentence_transformers import (\n        SentenceTransformer,\n        SentenceTransformerTrainer,\n        SentenceTransformerTrainingArguments,\n        losses,\n    )\n\n    BATCH_SIZE = 256\n\n    model = SentenceTransformer(\"sentence-transformers/all-mpnet-base-v2\")\n\n    loss_fn = losses.MultipleNegativesRankingLoss(\n        model,\n        directions=(\n            \"query_to_doc\",\n            \"query_to_query\",\n            \"doc_to_query\",\n            \"doc_to_doc\",\n        ),\n        partition_mode=\"joint\",\n    )\n\n    args = SentenceTransformerTrainingArguments(\n        output_dir=\"models/address-mpnet-mnrl\",\n\n        # Must match your curated batch size.\n 
       per_device_train_batch_size=BATCH_SIZE,\n\n        # Usually safest if all curated batches are exactly size 256.\n        dataloader_drop_last=True,\n\n        # Critical part: preserve your precomputed batches.\n        batch_sampler=exact_prebatched_sampler_factory,\n\n        # Usual training settings. Tune these for your dataset.\n        num_train_epochs=1,\n        learning_rate=2e-5,\n        warmup_ratio=0.1,\n\n        # Use bf16 if your hardware supports it. Otherwise use fp16 or fp32.\n        bf16=True,\n        fp16=False,\n\n        logging_steps=50,\n        save_steps=500,\n        save_total_limit=2,\n    )\n\n    trainer = SentenceTransformerTrainer(\n        model=model,\n        args=args,\n        train_dataset=train_dataset,\n        loss=loss_fn,\n    )\n\n    trainer.train()\n\n\nThat is the core solution.\n\n* * *\n\n## Complete minimal example\n\n\n    from collections.abc import Iterator\n\n    import torch\n    from datasets import Dataset\n    from sentence_transformers import (\n        SentenceTransformer,\n        SentenceTransformerTrainer,\n        SentenceTransformerTrainingArguments,\n        losses,\n    )\n    from sentence_transformers.sampler import DefaultBatchSampler\n\n\n    BATCH_SIZE = 256\n\n\n    class ExactPreBatchedSampler(DefaultBatchSampler):\n        def __init__(\n            self,\n            dataset: Dataset,\n            batch_size: int,\n            drop_last: bool,\n            valid_label_columns: list[str] | None = None,\n            generator: torch.Generator | None = None,\n            seed: int = 0,\n            shuffle_batches: bool = True,\n        ) -> None:\n            super().__init__(\n                dataset=dataset,\n                batch_size=batch_size,\n                drop_last=drop_last,\n                valid_label_columns=valid_label_columns,\n                generator=generator,\n                seed=seed,\n            )\n            self.dataset = dataset\n            
self.shuffle_batches = shuffle_batches\n\n        def __iter__(self) -> Iterator[list[int]]:\n            if self.generator is not None and self.seed is not None:\n                self.generator.manual_seed(self.seed + self.epoch)\n\n            n_full_batches = len(self.dataset) // self.batch_size\n            remainder_start = n_full_batches * self.batch_size\n\n            batch_ids = torch.arange(n_full_batches)\n\n            if self.shuffle_batches:\n                batch_ids = batch_ids[\n                    torch.randperm(n_full_batches, generator=self.generator)\n                ]\n\n            for batch_id in batch_ids.tolist():\n                start = batch_id * self.batch_size\n                end = start + self.batch_size\n                yield list(range(start, end))\n\n            if not self.drop_last and remainder_start < len(self.dataset):\n                yield list(range(remainder_start, len(self.dataset)))\n\n        def __len__(self) -> int:\n            n_full_batches = len(self.dataset) // self.batch_size\n            has_remainder = len(self.dataset) % self.batch_size != 0\n            return n_full_batches + int(has_remainder and not self.drop_last)\n\n\n    def exact_prebatched_sampler_factory(\n        dataset: Dataset,\n        batch_size: int,\n        drop_last: bool,\n        valid_label_columns: list[str] | None = None,\n        generator: torch.Generator | None = None,\n        seed: int = 0,\n    ):\n        if batch_size != BATCH_SIZE:\n            raise ValueError(\n                f\"Expected batch_size={BATCH_SIZE}, got {batch_size}. 
\"\n                \"Use per_device_train_batch_size=256.\"\n            )\n\n        return ExactPreBatchedSampler(\n            dataset=dataset,\n            batch_size=batch_size,\n            drop_last=drop_last,\n            valid_label_columns=valid_label_columns,\n            generator=generator,\n            seed=seed,\n            shuffle_batches=True,\n        )\n\n\n    flat_anchors: list[str] = []\n    flat_positives: list[str] = []\n\n    for batch_idx, batch in enumerate(batches):\n        if len(batch) != BATCH_SIZE:\n            raise ValueError(\n                f\"Batch {batch_idx} has {len(batch)} pairs, expected {BATCH_SIZE}.\"\n            )\n\n        for anchor, positive in batch:\n            flat_anchors.append(anchor)\n            flat_positives.append(positive)\n\n    train_dataset = Dataset.from_dict(\n        {\n            \"anchor\": flat_anchors,\n            \"positive\": flat_positives,\n        }\n    ).select_columns([\"anchor\", \"positive\"])\n\n    model = SentenceTransformer(\"sentence-transformers/all-mpnet-base-v2\")\n\n    loss_fn = losses.MultipleNegativesRankingLoss(\n        model,\n        directions=(\n            \"query_to_doc\",\n            \"query_to_query\",\n            \"doc_to_query\",\n            \"doc_to_doc\",\n        ),\n        partition_mode=\"joint\",\n    )\n\n    args = SentenceTransformerTrainingArguments(\n        output_dir=\"models/address-mpnet-mnrl\",\n        per_device_train_batch_size=BATCH_SIZE,\n        dataloader_drop_last=True,\n        batch_sampler=exact_prebatched_sampler_factory,\n        num_train_epochs=1,\n        learning_rate=2e-5,\n        warmup_ratio=0.1,\n        bf16=True,\n        fp16=False,\n        logging_steps=50,\n        save_steps=500,\n        save_total_limit=2,\n    )\n\n    trainer = SentenceTransformerTrainer(\n        model=model,\n        args=args,\n        train_dataset=train_dataset,\n        loss=loss_fn,\n    )\n\n    trainer.train()\n\n\n* * *\n\n## 
Verify the sampler before training\n\nBefore launching a long training job, inspect a few batches.\n\n\n    sampler = exact_prebatched_sampler_factory(\n        dataset=train_dataset,\n        batch_size=BATCH_SIZE,\n        drop_last=True,\n        generator=torch.Generator().manual_seed(42),\n        seed=42,\n    )\n\n    for batch_number, indices in zip(range(5), sampler):\n        print(\n            \"batch_number:\",\n            batch_number,\n            \"first_index:\",\n            indices[0],\n            \"last_index:\",\n            indices[-1],\n            \"size:\",\n            len(indices),\n        )\n\n        print(\"first anchor:\", train_dataset[indices[0]][\"anchor\"])\n        print(\"first positive:\", train_dataset[indices[0]][\"positive\"])\n        print()\n\n\nExpected shape:\n\n\n    batch_number: 0 first_index: 512 last_index: 767 size: 256\n    batch_number: 1 first_index: 0   last_index: 255 size: 256\n    batch_number: 2 first_index: 256 last_index: 511 size: 256\n\n\nThe order may differ because whole batches are shuffled, but each yielded batch should still be a contiguous block of 256 rows.\n\n* * *\n\n## Address-specific warning: false negatives\n\nThe main risk with `MultipleNegativesRankingLoss` is **false negatives**.\n\nIn MNRL, other positives in the same batch are treated as negatives for the current anchor. 
That is useful only if those other positives are truly different addresses.\n\nThis is good:\n\n\n    anchor:\n        Blue Street, 1, New York\n\n    positive:\n        Blue Street 1 - New York\n\n    in-batch negative:\n        Blue Street 11 - New York\n\n\nThis is dangerous:\n\n\n    anchor:\n        Blue Street, 1, New York\n\n    positive:\n        Blue Street 1 - New York\n\n    in-batch negative from another row:\n        1 Blue St., NYC\n\n\nbecause `1 Blue St., NYC` may be the same real address.\n\nSo your batch builder should enforce a rule like:\n\n\n    No two different rows in the same MNRL batch may refer to the same canonical address.\n\n\nExact string deduplication is not enough. Prefer deduplication by one or more of:\n\n\n    canonical address ID\n    delivery point ID\n    authoritative normalized address\n    geocoder result ID\n    parcel/building/unit ID\n    high-confidence rooftop coordinate\n\n\n* * *\n\n## Be careful with all four `directions`\n\nYou used:\n\n\n    directions=(\n        \"query_to_doc\",\n        \"query_to_query\",\n        \"doc_to_query\",\n        \"doc_to_doc\",\n    )\n\n\nThat is supported by the current MultipleNegativesRankingLoss API and can provide a stronger signal.\n\nHowever, it also makes batch cleanliness stricter.\n\nWith only `query_to_doc`, the main requirement is:\n\n\n    anchor_i should not match positive_j for i != j\n\n\nWith `query_to_query`, you also need:\n\n\n    anchor_i should not be equivalent to anchor_j\n\n\nWith `doc_to_doc`, you also need:\n\n\n    positive_i should not be equivalent to positive_j\n\n\nFor all four directions, validate:\n\n\n    anchor_i is not equivalent to anchor_j\n    positive_i is not equivalent to positive_j\n    anchor_i is not equivalent to positive_j for i != j\n    positive_i is not equivalent to anchor_j for i != j\n\n\nIf your canonical-address checks are not strong yet, consider starting with a simpler loss configuration:\n\n\n    loss_fn = 
losses.MultipleNegativesRankingLoss(\n        model,\n        directions=(\"query_to_doc\", \"doc_to_query\"),\n        partition_mode=\"per_direction\",\n    )\n\n\nThen compare against the all-four-direction version.\n\n* * *\n\n## Consider explicit hard negatives\n\nYour current format is:\n\n\n    (anchor, positive)\n\n\nThat is valid for MNRL.\n\nBut if you already know specific hard negatives, you can also use:\n\n\n    (anchor, positive, negative_1, negative_2, negative_3)\n\n\n`MultipleNegativesRankingLoss` supports pairs, triplets, and n-tuples:\n\n\n    train_dataset = Dataset.from_dict(\n        {\n            \"anchor\": [\n                \"Blue Street, 1, New York\",\n                \"Blue Street, 11, New York\",\n            ],\n            \"positive\": [\n                \"Blue Street 1 - New York\",\n                \"Blue Street 11 - New York\",\n            ],\n            \"negative_1\": [\n                \"Blue Street 11 - New York\",\n                \"Blue Street 1 - New York\",\n            ],\n            \"negative_2\": [\n                \"Blue Street 1 - Newark\",\n                \"Blue Street 11 - Newark\",\n            ],\n        }\n    ).select_columns([\"anchor\", \"positive\", \"negative_1\", \"negative_2\"])\n\n\nThis gives the loss both:\n\n\n    explicit hard negatives attached to each row\n    +\n    in-batch hard negatives created by your curated batch\n\n\nFor a first implementation, I would start with `(anchor, positive)` and curated batches. 
After that works, add explicit hard-negative columns and compare.\n\n* * *\n\n## Do not use gradient accumulation as a substitute for batch size\n\nThis is a common mistake.\n\nFor MNRL, what matters is the **contrastive batch size**: the number of examples visible to the loss at the same time.\n\nThese are not equivalent:\n\n\n    per_device_train_batch_size = 32\n    gradient_accumulation_steps = 8\n\n\nand:\n\n\n    per_device_train_batch_size = 256\n\n\nThe first setup steps the optimizer once per 256 examples, but each MNRL softmax still sees only 32 examples at once.\n\nIf batch size 256 does not fit in GPU memory, use CachedMultipleNegativesRankingLoss:\n\n\n    loss_fn = losses.CachedMultipleNegativesRankingLoss(\n        model,\n        mini_batch_size=32,\n        directions=(\n            \"query_to_doc\",\n            \"query_to_query\",\n            \"doc_to_query\",\n            \"doc_to_doc\",\n        ),\n        partition_mode=\"joint\",\n    )\n\n\nKeep:\n\n\n    per_device_train_batch_size = 256\n\n\nThe distinction is:\n\n\n    per_device_train_batch_size = contrastive batch size\n    mini_batch_size             = internal memory chunk size\n\n\nThat matters because your curated group of 256 addresses is the semantic training unit.\n\n* * *\n\n## Multi-GPU caution\n\n`MultipleNegativesRankingLoss` has a `gather_across_devices` option. It can increase the effective negative pool across devices, but it also changes the effective contrastive batch.\n\nFor your case, exact batch composition matters. 
I would first validate everything on one GPU:\n\n\n    single GPU\n    per_device_train_batch_size = 256\n    gather_across_devices = False\n\n\nThen move to distributed training only after you have verified what examples are actually visible to each loss computation.\n\n* * *\n\n## Practical address-data advice\n\n### Good positives\n\nUse formatting variants of the same real address:\n\n\n    Blue Street, 1, New York\n    Blue Street 1 - New York\n\n    1 Blue St., NYC\n    Blue Street 1, New York, NY\n\n    Apt 2, 1 Blue Street, New York\n    1 Blue St Apartment 2, NYC\n\n\n### Good hard negatives\n\nUse addresses that differ in identity-critical components:\n\n\n    same street + same city + different house number\n    same street + same house number + different city\n    same building + different apartment/unit\n    same street + different postal code\n    same house number + similar street name\n    same city + changed street suffix\n    changed directional: North Main St vs South Main St\n\n\nExamples:\n\n\n    Blue Street 1, New York\n    Blue Street 11, New York\n\n    Blue Street 1, New York\n    Blue Street 1, Newark\n\n    Blue Street 1 Apt 2, New York\n    Blue Street 1 Apt 20, New York\n\n    North Main Street 10\n    South Main Street 10\n\n\n### Good batch design\n\nBuild each batch as a “confusion neighborhood”:\n\n\n    Batch theme:\n        same street-name stem (\"Blue\") + same or nearby city\n\n    Rows:\n        Blue Street 1, New York        ↔ 1 Blue St, NYC\n        Blue Street 11, New York       ↔ 11 Blue St, NYC\n        Blue Street 1 Apt 2, New York  ↔ Apt 2, 1 Blue St, NYC\n        Blue Street 1 Apt 20, New York ↔ Apt 20, 1 Blue St, NYC\n        Blue Avenue 1, New York        ↔ 1 Blue Ave, NYC\n        Blue Street 1, Newark          ↔ 1 Blue St, Newark\n\n\nThis is much better than random batching, because random negatives are often too easy.\n\n* * *\n\n## Recommended preprocessing validation\n\nKeep metadata before training:\n\n\n    training_rows = [\n  
      {\n            \"batch_id\": 0,\n            \"anchor\": \"Blue Street, 1, New York\",\n            \"positive\": \"Blue Street 1 - New York\",\n            \"canonical_address_id\": \"addr_001\",\n        },\n        {\n            \"batch_id\": 0,\n            \"anchor\": \"Blue Street, 11, New York\",\n            \"positive\": \"Blue Street 11 - New York\",\n            \"canonical_address_id\": \"addr_002\",\n        },\n    ]\n\n\nValidate each batch:\n\n\n    def validate_precomputed_batches(batches_with_ids: list[list[dict]]) -> None:\n        for batch_idx, batch in enumerate(batches_with_ids):\n            if len(batch) != BATCH_SIZE:\n                raise ValueError(\n                    f\"Batch {batch_idx} has {len(batch)} rows, expected {BATCH_SIZE}.\"\n                )\n\n            canonical_ids = [row[\"canonical_address_id\"] for row in batch]\n\n            if len(canonical_ids) != len(set(canonical_ids)):\n                raise ValueError(\n                    f\"Batch {batch_idx} contains duplicate canonical address IDs. 
\"\n                    \"This creates false negatives for MNRL.\"\n                )\n\n            for row_idx, row in enumerate(batch):\n                if not row[\"anchor\"] or not row[\"positive\"]:\n                    raise ValueError(\n                        f\"Batch {batch_idx}, row {row_idx} has empty text.\"\n                    )\n\n\nThen build the trainer dataset with only text columns:\n\n\n    train_dataset = Dataset.from_dict(\n        {\n            \"anchor\": [row[\"anchor\"] for batch in batches_with_ids for row in batch],\n            \"positive\": [row[\"positive\"] for batch in batches_with_ids for row in batch],\n        }\n    ).select_columns([\"anchor\", \"positive\"])\n\n\nThis gives you metadata safety during preprocessing and a clean text-only dataset during training.\n\n* * *\n\n## Evaluation: do not rely only on average cosine similarity\n\nFor address embeddings, generic similarity evaluation is not enough.\n\nUse at least these four evaluation types.\n\n### 1. Same-address invariance\n\nPairs that should be close:\n\n\n    Blue Street 1, New York\n    1 Blue St., NYC\n\n\nMeasure:\n\n\n    positive cosine distribution\n\n\n### 2. Hard-negative separation\n\nPairs that should not be too close:\n\n\n    Blue Street 1, New York\n    Blue Street 11, New York\n\n\nSlice by component type:\n\n\n    house number changed\n    city changed\n    unit changed\n    postal code changed\n    street suffix changed\n    directional changed\n\n\n### 3. Triplet accuracy\n\nTriplets:\n\n\n    anchor:\n        Blue Street 1, New York\n\n    positive:\n        1 Blue St., NYC\n\n    negative:\n        Blue Street 11, New York\n\n\nMeasure:\n\n\n    score(anchor, positive) > score(anchor, negative)\n\n\nAlso measure the margin:\n\n\n    score(anchor, positive) - score(anchor, negative)\n\n\n### 4. 
Retrieval\n\nCorpus:\n\n\n    all canonical addresses\n\n\nQuery:\n\n\n    raw or noisy address variant\n\n\nMeasure:\n\n\n    Recall@1\n    Recall@5\n    Recall@10\n    MRR\n    nDCG\n\n\nFor a production address resolver, retrieval metrics are usually the most realistic. The main question is whether the correct canonical address appears in the top candidates.\n\n* * *\n\n## Production design recommendation\n\nFor high-precision address matching, I would not rely only on one embedding model.\n\nUse a two-stage architecture:\n\n\n    raw address\n        ↓\n    bi-encoder embedding model\n        ↓\n    top 10 / top 50 canonical candidates\n        ↓\n    cross-encoder or structured verifier\n        ↓\n    same / different / uncertain\n\n\nThe bi-encoder is fast and good for candidate retrieval. A cross-encoder or verifier can compare two addresses jointly and pay close attention to exact differences:\n\n\n    1 vs 11\n    NYC vs Newark\n    Apt 2 vs Apt 20\n    Street vs Avenue\n    North Main vs South Main\n\n\nFor address matching, this second stage is often what prevents high-cost false positives.\n\n* * *\n\n## Practical experiment order\n\n### Experiment 1: curated batches with simple MNRL\n\n\n    loss_fn = losses.MultipleNegativesRankingLoss(\n        model,\n        directions=(\"query_to_doc\",),\n    )\n\n\nThis is the simplest baseline.\n\n### Experiment 2: symmetric MNRL\n\n\n    loss_fn = losses.MultipleNegativesRankingLoss(\n        model,\n        directions=(\"query_to_doc\", \"doc_to_query\"),\n        partition_mode=\"per_direction\",\n    )\n\n\nThis is a good next step because address equivalence is symmetric.\n\n### Experiment 3: all four directions\n\n\n    loss_fn = losses.MultipleNegativesRankingLoss(\n        model,\n        directions=(\n            \"query_to_doc\",\n            \"query_to_query\",\n            \"doc_to_query\",\n            \"doc_to_doc\",\n        ),\n        partition_mode=\"joint\",\n    )\n\n\nUse this only if your 
batch validation is strong and evaluation improves.\n\n### Experiment 4: explicit hard negatives\n\nUse columns like:\n\n\n    anchor\n    positive\n    negative_1\n    negative_2\n    negative_3\n\n\nStill preserve your curated batch structure.\n\n### Experiment 5: CachedMNRL\n\nIf memory is limiting, switch from MNRL to CachedMNRL while preserving `per_device_train_batch_size=256`.\n\n* * *\n\n## Common mistakes to avoid\n\n### Mistake 1: putting `batch_id` into the training dataset\n\nKeep metadata outside the trainer dataset unless you customize the collator/loss path.\n\n### Mistake 2: using random batches\n\nRandom negatives are mostly too easy. Your task needs hard negatives.\n\n### Mistake 3: trusting exact-string deduplication\n\nDifferent strings can still be the same address.\n\n### Mistake 4: using gradient accumulation to simulate batch size 256\n\nIt does not create an MNRL negative pool of 256.\n\n### Mistake 5: training all four directions without strong false-negative checks\n\n`query_to_query` and `doc_to_doc` make false negatives more damaging.\n\n* * *\n\n## Final answer\n\nCreate a normal flattened `datasets.Dataset` with only the text columns needed by the loss. Preserve your handcrafted batch structure by row order. Then pass a custom `DefaultBatchSampler` or sampler factory through `SentenceTransformerTrainingArguments(batch_sampler=...)` so the trainer’s internal dataloader yields exactly your curated groups of 256 pairs.\n\nIn short:\n\n\n    Flatten your batches.\n    Keep rows ordered by curated batch.\n    Use a custom batch sampler.\n    Set per_device_train_batch_size=256.\n    Validate that no batch contains two equivalent canonical addresses.\n    Use CachedMNRL if memory is too tight.\n    Evaluate with retrieval and hard-negative address tests.\n",
  "title": "Custom batches in sentence-transformers for MultipleNegativesRankingLoss"
}