Training

We start from a pre-trained model (trained on a large, general-domain text corpus) and fine-tune it (i.e., continue training) on a specialized dataset — in our case, medical terminology. As the pre-trained model we use NLLB (facebook/nllb-200-distilled-600M). The training script is shown below:

import torch
from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
import evaluate
import numpy as np
import os

# Allocator options are placed in the environment here, before any CUDA
# tensors are created, so the CUDA caching allocator picks them up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()

# Permit TF32 math for CUDA matmuls and cuDNN, then request "high"
# float32 matmul precision.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend
torch.set_float32_matmul_precision("high")

# --- Configuration ---
# Pre-trained checkpoint to fine-tune. A larger drop-in alternative:
#   MODEL_CHECKPOINT = "facebook/nllb-200-1.3B"
MODEL_CHECKPOINT = "facebook/nllb-200-distilled-600M"

# NLLB language codes, written as <language>_<script>.
SOURCE_LANG = "eng_Latn"  # English
TARGET_LANG = "fra_Latn"  # French

# Line-aligned parallel corpus: line i of SRC_FILE_PATH is the translation
# of line i of TGT_FILE_PATH. Point these at your own files; an alternative
# configuration for the SNOMED-augmented corpus is:
#   SRC_FILE_PATH = "data/medterms_plus_snomed.en"
#   TGT_FILE_PATH = "data/medterms_plus_snomed.fr"
#   OUTPUT_DIR = "./nllb-medical-en-fr_SNOMED"
SRC_FILE_PATH = "data/initial_unique_english.txt"
TGT_FILE_PATH = "data/initial_unique_french.txt"

# Directory for training checkpoints and the final model/tokenizer.
OUTPUT_DIR = "./nllb-medical-en-fr_INITIAL"

def load_data_from_files(src_path, tgt_path):
    """
    Read two line-aligned files and convert them into a Hugging Face Dataset.

    Line *i* of ``src_path`` is paired with line *i* of ``tgt_path``. Every
    line is stripped of surrounding whitespace. A pair in which either side
    is empty is skipped as a whole, so the alignment of the remaining pairs
    is preserved. Blank lines at the end of either file are ignored.

    Each example has the form
    ``{"translation": {SOURCE_LANG: src_text, TARGET_LANG: tgt_text}}``.

    Raises:
        ValueError: if the two files contain a different number of lines
            (ignoring trailing blank lines), or if no non-empty pair remains.
    """
    with open(src_path, "r", encoding="utf-8") as f_src, \
         open(tgt_path, "r", encoding="utf-8") as f_tgt:
        src_lines = [line.strip() for line in f_src]
        tgt_lines = [line.strip() for line in f_tgt]

    # Extra newlines at end-of-file should not count as a line-count mismatch.
    while src_lines and not src_lines[-1]:
        src_lines.pop()
    while tgt_lines and not tgt_lines[-1]:
        tgt_lines.pop()

    if len(src_lines) != len(tgt_lines):
        raise ValueError(f"Mismatch in line counts! Source: {len(src_lines)}, Target: {len(tgt_lines)}")

    # Filter blank lines pairwise. Filtering each file independently would
    # silently shift every later pair whenever blank lines occur at different
    # positions in the two files.
    pairs = [(s, t) for s, t in zip(src_lines, tgt_lines) if s and t]

    skipped = len(src_lines) - len(pairs)
    if skipped:
        print(f"Skipped {skipped} line pair(s) with an empty source or target side.")
    if not pairs:
        raise ValueError("No non-empty line pairs found in the input files.")

    data = [{"translation": {SOURCE_LANG: s, TARGET_LANG: t}} for s, t in pairs]
    return Dataset.from_list(data)

def main():
    """Fine-tune MODEL_CHECKPOINT on the parallel corpus and save the result.

    The steps are:
    1. Load the line-aligned corpus and split it 90/10.
    2. Tokenize it.
    3. Train with ``Seq2SeqTrainer`` for the configured number of epochs.
    4. Write the final model and tokenizer to ``OUTPUT_DIR``.
    """
    # 1. Load and split data (90% train, 10% held-out "test" split)
    print("Loading data...")
    dataset = load_data_from_files(SRC_FILE_PATH, TGT_FILE_PATH)
    # Fixed seed so the train/validation split is reproducible across runs.
    dataset = dataset.train_test_split(test_size=0.1, seed=42)

    # 2. Initialize the tokenizer with the source/target NLLB language codes
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_CHECKPOINT,
        src_lang=SOURCE_LANG,
        tgt_lang=TARGET_LANG,
    )

    # 3. Preprocessing
    max_input_length = 128
    max_target_length = 128

    def preprocess_function(examples):
        """Tokenize a batch of translation pairs; target ids go into ``labels``."""
        inputs = [ex[SOURCE_LANG] for ex in examples["translation"]]
        targets = [ex[TARGET_LANG] for ex in examples["translation"]]

        model_inputs = tokenizer(
            inputs,
            max_length=max_input_length,
            truncation=True,
        )

        # `text_target=` tokenizes in target-language mode. It replaces the
        # deprecated `with tokenizer.as_target_tokenizer():` context manager.
        labels = tokenizer(
            text_target=targets,
            max_length=max_target_length,
            truncation=True,
        )

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    print("Tokenizing data...")
    tokenized_datasets = dataset.map(preprocess_function, batched=True)

    # 4. Load model
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT, attn_implementation="sdpa")
    # The KV cache is not needed for teacher-forced training; it is re-enabled
    # before saving.
    model.config.use_cache = False

    # 5. Metric (SacreBLEU) -- only used if evaluation is run
    metric = evaluate.load("sacrebleu")

    def compute_metrics(eval_preds):
        """Decode generated ids and references; return ``{"bleu": corpus SacreBLEU}``."""
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]

        # Generated ids can contain -100 where the Trainer padded predictions
        # of different lengths; the tokenizer cannot decode it.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

        # -100 marks ignored label positions; map it to the pad token for decoding.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # SacreBLEU expects a list of reference lists.
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [[label.strip()] for label in decoded_labels]

        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        return {"bleu": result["score"]}

    # 6. Training arguments
    args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="no",               # no periodic evaluation during training
        save_strategy="steps",
        save_steps=1000,
        logging_steps=100,
        learning_rate=3e-5,

        # Small per-step batch for memory stability; gradient accumulation
        # gives an effective batch of 12 * 8 = 96 sequences per optimizer
        # step (per device).
        per_device_train_batch_size=12,
        gradient_accumulation_steps=8,
        per_device_eval_batch_size=4,

        # Settings chosen for Windows / RTX 3090
        dataloader_num_workers=0,         # DO NOT use > 0 on Windows for this task
        group_by_length=False,            # disabled to prevent memory spikes
        fp16=True,
        # If evaluation is run, generate token ids so compute_metrics can decode
        # them (without this it would receive raw logits and fail).
        predict_with_generate=True,
        generation_max_length=max_target_length,
        weight_decay=0.01,
        num_train_epochs=1,
        optim="adamw_torch_fused",
    )

    # 7. Data collator: pads batches dynamically; label padding uses -100
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,  # Ignore padding in loss calculation
    )

    # 8. Trainer
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=`; kept here for compatibility with older versions.
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # 9. Train
    print("Starting training...")
    trainer.train()

    # 10. Save the final model. Restore the KV cache flag first so the saved
    # config does not disable caching at inference time.
    model.config.use_cache = True
    print(f"Saving model to {OUTPUT_DIR}")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Loss and Stopping Criteria

The script relies on the default sequence-to-sequence training objective: Seq2SeqTrainer uses the loss the model itself computes when labels are supplied, with no custom loss:

  • Token-level cross-entropy loss over decoder outputs.

  • Padding positions in the labels are filled with -100 by DataCollatorForSeq2Seq (label_pad_token_id=-100), so they are excluded from the loss.

  • No label smoothing is applied: label_smoothing_factor is left at its default value (0.0) in Seq2SeqTrainingArguments.

Training stop condition in this configuration:

  • Training runs for a fixed duration of one epoch (num_train_epochs=1).

  • No early stopping callback is used.

  • Checkpoints are saved every 1000 optimizer steps (save_steps=1000). Checkpointing does not affect when training stops, which is determined solely by the epoch limit.