Custom batches in sentence-transformers for MultipleNegativesRankingLoss
I am using the sentence-transformers library to fine-tune a model that generates embeddings for postal addresses, so that embeddings of the same address written in different ways end up close to each other.
However, addresses that differ only in a small part (e.g. the street number or the name of the city) must have sufficiently different embeddings. This is not the case when I fine-tune the all-mpnet-base-v2 model with CosineSimilarityLoss (or similar losses).
Therefore, I am trying to use MultipleNegativesRankingLoss. As far as I understand, this loss is computed over the whole minibatch, not just over individual pairs of sentences/addresses. It enforces not only that the two sentences/addresses in a given pair have similar embeddings, but also treats the sentences/addresses from the other pairs in the same batch as negatives (which is exactly what I need).
So I prepared a training set that is already partitioned into batches of 256 pairs each, taking care to put pairs that must be treated as hard negatives of each other into the same batch, even when they are very similar.
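To make the "whole minibatch" behaviour concrete, here is a simplified, one-direction sketch of an in-batch-negatives ranking loss in NumPy. It is not the library's code: it assumes cosine similarity and a scale of 20 (which I believe are the library defaults), and it only covers the anchor-to-positive direction. Row `i` of `positives` is the target for row `i` of `anchors`, and every other row in the batch acts as a negative.

```python
import numpy as np


def in_batch_negatives_loss(anchors: np.ndarray, positives: np.ndarray, scale: float = 20.0) -> float:
    """Simplified one-direction sketch of an in-batch-negatives ranking loss.

    anchors, positives: arrays of shape (batch_size, dim). Row i of `positives`
    is the positive for row i of `anchors`; all other rows act as negatives.
    """
    # L2-normalise so that the dot product equals cosine similarity
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    # (batch_size, batch_size) matrix of scaled cosine similarities
    scores = scale * (a @ p.T)
    # Numerically stable log-softmax over each row
    scores = scores - scores.max(axis=1, keepdims=True)
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    # Cross-entropy with the diagonal (the matching positive) as the target class
    return float(-np.mean(np.diag(log_probs)))


# Three clearly different "addresses" vs. a batch where two of them are near-duplicates
easy = np.eye(3)
hard = np.array([[1.0, 0.0, 0.0], [0.99, 0.14, 0.0], [0.0, 0.0, 1.0]])
print(in_batch_negatives_loss(easy, easy), in_batch_negatives_loss(hard, hard))
```

The near-duplicate rows produce a much larger loss because each one is a strong negative for the other, which is why the batch composition matters for this loss.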
batches: list[tuple[tuple[str, str], ...]] = [  # each inner tuple is one batch of 256 (anchor, positive) pairs
    (
        (batch1_anchor1, batch1_positive1),  # ('Blue Street, 1, New York', 'Blue Street 1 - New York'),
        (batch1_anchor2, batch1_positive2),  # ('Blue Street, 11, New York', 'Blue Street 11 - New York'),
        (batch1_anchor3, batch1_positive3),
        ...
    ),
    (
        (batch2_anchor1, batch2_positive1),
        (batch2_anchor2, batch2_positive2),
        (batch2_anchor3, batch2_positive3),
        ...
    ),
    ...
]
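One way to think about this layout is as a single flat table in which the pairs of each batch occupy a contiguous block of rows. The helper below is hypothetical (not part of sentence-transformers): it flattens the nested structure into `anchor`/`positive` columns, a dict shape that `datasets.Dataset.from_dict` can consume. Keeping rows contiguous is necessary but not sufficient: if the trainer shuffles individual rows, the hand-built batches are still broken up.

```python
from typing import Dict, List, Sequence, Tuple


def flatten_batches(batches: Sequence[Sequence[Tuple[str, str]]]) -> Dict[str, List[str]]:
    """Hypothetical helper: flatten pre-built batches of (anchor, positive) pairs
    into column lists, keeping the pairs of each batch in contiguous rows."""
    columns: Dict[str, List[str]] = {"anchor": [], "positive": []}
    for batch in batches:
        for anchor, positive in batch:
            columns["anchor"].append(anchor)
            columns["positive"].append(positive)
    return columns


example = [
    (("Blue Street, 1, New York", "Blue Street 1 - New York"),
     ("Blue Street, 11, New York", "Blue Street 11 - New York")),
    (("a3", "p3"), ("a4", "p4")),
]
cols = flatten_batches(example)
# With a fixed batch size B, row i belongs to hand-built batch i // B
print(cols["anchor"])
```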
My question is: how do I preserve this batch structure when loading the training data into the trainer? The SentenceTransformerTrainer class only accepts a `datasets.Dataset`, so I see no way to keep my batches intact.
from sentence_transformers import SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss

loss_fn = MultipleNegativesRankingLoss(
    model,
    directions=('query_to_doc', 'query_to_query', 'doc_to_query', 'doc_to_doc'),
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=???,  # here I can pass a datasets.Dataset, not a torch.utils.data.DataLoader or equivalent
    loss=loss_fn,
)
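The missing piece is essentially a batch sampler that yields fixed, contiguous index blocks instead of shuffled rows. The class below is a hypothetical sketch: it has the shape that `torch.utils.data.DataLoader` accepts for its `batch_sampler` argument (any iterable yielding lists of indices), and it only shuffles the order of whole batches, never their contents. It assumes the dataset was flattened so that each hand-built batch occupies `batch_size` consecutive rows. How to plug such a sampler into SentenceTransformerTrainer differs between library versions (e.g. through the training arguments' batch-sampler option or by overriding the trainer's batch-sampler method), so treat that wiring as something to check against the installed version.

```python
import random
from typing import Iterator, List, Optional


class ContiguousBatchSampler:
    """Hypothetical batch sampler: each yielded batch is a contiguous block of
    `batch_size` row indices, so hand-built batches are never split or mixed.
    Only the order in which whole batches are visited is optionally shuffled."""

    def __init__(self, num_rows: int, batch_size: int,
                 shuffle_batches: bool = True, seed: Optional[int] = None):
        if num_rows % batch_size != 0:
            raise ValueError("num_rows must be a multiple of batch_size so no batch is split")
        self.num_rows = num_rows
        self.batch_size = batch_size
        self.shuffle_batches = shuffle_batches
        self.seed = seed
        self.epoch = 0

    def __len__(self) -> int:
        return self.num_rows // self.batch_size

    def __iter__(self) -> Iterator[List[int]]:
        order = list(range(len(self)))
        if self.shuffle_batches:
            # New permutation of batch order each epoch, reproducible given the seed
            rng = random.Random(None if self.seed is None else self.seed + self.epoch)
            rng.shuffle(order)
        self.epoch += 1
        for b in order:
            start = b * self.batch_size
            yield list(range(start, start + self.batch_size))


sampler = ContiguousBatchSampler(num_rows=8, batch_size=4, seed=0)
print(list(sampler))
```

Shuffling the batch order keeps some randomness across epochs while leaving the carefully chosen hard negatives together in the same batch.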