-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlex_rank.py
137 lines (105 loc) · 4.79 KB
/
lex_rank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from itertools import combinations
import networkx as nx
from sentence_transformers import SentenceTransformer, util
import numpy as np
from scipy.linalg import norm
from scipy.sparse.csgraph import connected_components
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm
import razdel
from pprint import pprint
class LexRank:
    """Extractive summarizer based on LexRank.

    Sentences are embedded with a SentenceTransformer model, a cosine
    similarity graph is built over them, and sentences are ranked by their
    stationary probability (degree centrality) in the row-normalized
    similarity (Markov) matrix. The top-ranked sentences form the summary.
    """

    def __init__(self, model_name: str):
        """Load the sentence encoder identified by `model_name`."""
        self.encoder = SentenceTransformer(model_name)

    def get_lexrank_summary(self, text: str, n_sentences_summary: int = 3) -> str:
        """Return an extractive summary of `text`.

        The summary is the `n_sentences_summary` most central sentences,
        joined with spaces, in descending order of centrality.
        """
        # Split the text into sentences (razdel handles Russian tokenization).
        sentences = [sentence.text for sentence in razdel.sentenize(text)]
        embeddings = self.encoder.encode(sentences, convert_to_tensor=True)
        # Pairwise cosine similarities between all sentence embeddings.
        cos_scores = util.pytorch_cos_sim(embeddings, embeddings).detach().cpu().numpy()
        centrality_scores = self.degree_centrality_scores(cos_scores)
        # Indices sorted by descending centrality.
        most_central_sentence_indices = np.argsort(-centrality_scores)
        return ' '.join(
            sentences[idx].strip()
            for idx in most_central_sentence_indices[:n_sentences_summary]
        )

    def get_summary(self, records, n_sentences_summary=3, show_full_text=True, only_blue=False):
        """Summarize `records` and optionally report BLEU against a reference.

        - dict with "text"/"summary" keys: summarize "text"; if `only_blue`,
          return the BLEU score, otherwise print a report (returns None).
        - str: summarize it and print the result (optionally the full text).
        - Any other type is silently ignored (returns None) — preserved from
          the original behavior.
        """
        references = []
        predictions = []
        if isinstance(records, dict):
            summary = records["summary"]
            references.append(summary)
            text = records["text"]
            predicted_summary = self.get_lexrank_summary(text, n_sentences_summary)
            predictions.append(predicted_summary)
            if only_blue:
                return self.calc_scores(references, predictions, text, only_blue)
            self.calc_scores(references, predictions, text, only_blue)
        elif isinstance(records, str):
            predicted_summary = self.get_lexrank_summary(records, n_sentences_summary)
            predictions.append(predicted_summary)
            if show_full_text:
                print('Полный текст:')
                pprint(records, width=150)
                print('-' * 150)
            # Typo fix: was "LextRank summary:", now consistent with calc_scores.
            print("LexRank summary:")
            pprint(predictions[-1], width=150)

    def degree_centrality_scores(self, similar_matrix, increase_power=True):
        """Return per-node centrality scores for a square similarity matrix.

        Scores are the stationary distribution of the row-normalized
        (Markov) version of `similar_matrix`, computed per connected
        component via the power method.
        """
        markow_matrix = self.create_markow_matrix(similar_matrix)
        return self.stationary_dist(markow_matrix, increase_power=increase_power)

    def power_method(self, transition_matrix, increase_power=True):
        """Approximate the stationary vector of `transition_matrix`.

        Iterates v <- T^t v starting from the all-ones vector until
        convergence. With `increase_power` the transition matrix is squared
        each step, which accelerates convergence for stochastic matrices.
        NOTE(review): the result is not normalized to sum to 1; callers only
        use relative magnitudes for ranking.
        """
        eigenvectors = np.ones(len(transition_matrix))
        if len(eigenvectors) == 1:
            return eigenvectors
        transition = transition_matrix.transpose()
        while True:
            eigenvectors_next = np.dot(transition, eigenvectors)
            if np.allclose(eigenvectors_next, eigenvectors):
                return eigenvectors_next
            eigenvectors = eigenvectors_next
            if increase_power:
                transition = np.dot(transition, transition)

    def connected_nodes(self, transition_matrix):
        """Return a list of index arrays, one per connected component."""
        _, labels = connected_components(transition_matrix)
        return [np.where(labels == tag)[0] for tag in np.unique(labels)]

    def stationary_dist(self, transition_matrix, increase_power=True):
        """Stationary distribution of `transition_matrix`, per component.

        Raises ValueError if the matrix is not square.
        """
        n_1, n_2 = transition_matrix.shape
        if n_1 != n_2:
            raise ValueError('\'transition_matrix\' should be square')
        distribution = np.zeros(n_1)
        # BUG FIX: the original reassigned `transition_matrix` to the first
        # group's submatrix inside this loop, so any second connected
        # component indexed out of bounds. Slice the full matrix each time.
        for group in self.connected_nodes(transition_matrix):
            submatrix = transition_matrix[np.ix_(group, group)]
            distribution[group] = self.power_method(
                submatrix, increase_power=increase_power)
        return distribution

    def create_markow_matrix(self, similar_matrix):
        """Row-normalize `similar_matrix` into a Markov (stochastic) matrix.

        Raises ValueError if the matrix is not square.
        """
        n_1, n_2 = similar_matrix.shape
        if n_1 != n_2:
            raise ValueError('\'similar_matrix\' should be square')
        row_sum = similar_matrix.sum(axis=1, keepdims=True)
        return similar_matrix / row_sum

    @staticmethod
    def calc_scores(references, predictions, text, only_blue=False):
        """Compute BLEU for predictions vs references; print a report unless
        `only_blue`, in which case return the BLEU score.

        NOTE(review): corpus_bleu expects token lists, but raw strings are
        passed here, so BLEU is effectively computed over characters —
        confirm this is intended before comparing scores across systems.
        """
        if only_blue:
            return corpus_bleu([[r] for r in references], predictions)
        print()
        print("Count:", len(predictions))
        print('Полный текст:')
        pprint(text, width=150)
        print('-' * 150)
        print("Исходное summary:")
        pprint(references[-1], width=150)
        print('-' * 150)
        print("LexRank summary:")
        pprint(predictions[-1], width=150)
        print('-' * 150)
        print("BLEU: ", corpus_bleu([[r] for r in references], predictions))
# Module is intended to be imported as a library; no CLI entry point yet.
if __name__ == '__main__':
    pass