# calculate_similarity.py
import sys

import torch
from torch.nn.functional import cosine_similarity
from transformers import BertTokenizer, BertModel

# Load the Chinese BERT tokenizer and encoder once at module level,
# so repeated calls reuse the same weights.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese')
model.eval()  # inference only; make sure dropout is disabled

def calculate_text_similarity(text1, text2):
    # Encode both texts as one batch, padded to a common length.
    encoded_inputs = tokenizer([text1, text2], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoded_inputs)
    # Take the [CLS] token embedding (position 0) as each sentence's vector.
    cls_vectors = outputs.last_hidden_state[:, 0, :]
    # cosine_similarity L2-normalizes its inputs internally, so the explicit
    # normalize() call from before was redundant and has been dropped.
    similarity = cosine_similarity(cls_vectors[0:1], cls_vectors[1:2]).item()
    return similarity

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print(f"Usage: python {sys.argv[0]} <text1> <text2>")
        sys.exit(1)
    print(calculate_text_similarity(sys.argv[1], sys.argv[2]))
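A minimal usage sketch, assuming the file above is saved as calculate_similarity.py and importable from the working directory; the example sentences are placeholders, not from the original:

# example_usage.py (hypothetical)
from calculate_similarity import calculate_text_similarity

score = calculate_text_similarity("今天天气很好", "今天的天气真不错")
print(f"similarity: {score:.4f}")  # cosine similarity of CLS vectors, roughly in [-1, 1]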