"""Compute the semantic similarity of two Chinese sentences using BERT CLS embeddings."""
  1. # calculate_similarity.py
  2. import sys
  3. import torch
  4. from transformers import BertTokenizer, BertModel
  5. from torch.nn.functional import cosine_similarity, normalize
  6. tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
  7. model = BertModel.from_pretrained('bert-base-chinese')
  8. def calculate_text_similarity(text1, text2):
  9. encoded_inputs = tokenizer([text1, text2], padding=True, truncation=True, return_tensors="pt")
  10. with torch.no_grad():
  11. outputs = model(**encoded_inputs)
  12. last_hidden_states = outputs.last_hidden_state
  13. cls_vectors = last_hidden_states[:, 0, :]
  14. cls_vectors_normalized = normalize(cls_vectors, p=2, dim=1)
  15. similarity = cosine_similarity(cls_vectors_normalized[0:1], cls_vectors_normalized[1:2]).item()
  16. return similarity
  17. if __name__ == "__main__":
  18. text1 = sys.argv[1]
  19. text2 = sys.argv[2]
  20. print(calculate_text_similarity(text1, text2))