# calculate_similarity.py
# Compute the semantic similarity of two texts with a Chinese BERT model:
# encode both texts in one batch, take the [CLS] vector of each, and return
# their cosine similarity (a float in [-1, 1]).
import sys

import torch
from torch.nn.functional import cosine_similarity, normalize
from transformers import BertTokenizer, BertModel

# Load the pretrained Chinese BERT tokenizer and encoder once at import time.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese')
model.eval()  # inference only; disables dropout


def calculate_text_similarity(text1, text2):
    # Tokenize both texts together; padding aligns their lengths and
    # truncation caps them at the model's maximum sequence length.
    encoded_inputs = tokenizer([text1, text2], padding=True, truncation=True,
                               return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoded_inputs)
    last_hidden_states = outputs.last_hidden_state   # shape: (2, seq_len, hidden)
    cls_vectors = last_hidden_states[:, 0, :]        # [CLS] vector of each text
    # L2-normalize so the dot product equals the cosine. cosine_similarity
    # normalizes internally anyway, so this mainly documents intent.
    cls_vectors_normalized = normalize(cls_vectors, p=2, dim=1)
    similarity = cosine_similarity(cls_vectors_normalized[0:1],
                                   cls_vectors_normalized[1:2]).item()
    return similarity


if __name__ == "__main__":
    if len(sys.argv) != 3:
        sys.exit("usage: python calculate_similarity.py <text1> <text2>")
    text1 = sys.argv[1]
    text2 = sys.argv[2]
    print(calculate_text_similarity(text1, text2))
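
# Usage sketch (the example strings below are hypothetical; since the model is
# bert-base-chinese, Chinese input is the intended case):
#
#   python calculate_similarity.py "今天天气很好" "今天天气不错"
#
# Design note: the raw [CLS] vector of an un-fine-tuned BERT is a fairly weak
# sentence representation; mean-pooling the token embeddings is a common
# alternative. A minimal sketch, assuming the same tokenizer/model outputs as
# above (mean_pooled_vectors is an illustrative helper, not part of any API):
def mean_pooled_vectors(encoded_inputs, last_hidden_states):
    # Zero out padding positions, then average the remaining token embeddings.
    mask = encoded_inputs["attention_mask"].unsqueeze(-1)   # (batch, seq, 1)
    summed = (last_hidden_states * mask).sum(dim=1)         # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1)                   # (batch, 1)
    return summed / counts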