import os

import numpy as np
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib as cdl
from tensorflow_hub import registry

# TFLite tensors report (scale, zero_point) == (0, 0) when they are not quantized.
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0

# Resolve the MobileBERT hub module to a local path and load its WordPiece vocabulary.
uri = "https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1"
vocab_file = os.path.join(registry.resolver(uri), 'assets', 'vocab.txt')
tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case=True)
def to_feature(text, label=None, label_list=None, max_seq_length=128,
               tokenizer=tokenizer):
  """Converts raw text into the (input_ids, input_mask, segment_ids) triple."""
  example = cdl.InputExample(guid=None,
                             text_a=text,
                             text_b=None,
                             label=label)
  feature = cdl.convert_single_example(
      0, example, label_list, max_seq_length, tokenizer)
  # Add a leading batch dimension: the exported classifier expects
  # (1, max_seq_length) inputs rather than flat (max_seq_length,) vectors.
  return (np.array([feature.input_ids]),
          np.array([feature.input_mask]),
          np.array([feature.segment_ids]))
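
# Example (illustrative): with the default max_seq_length of 128,
# to_feature("hello world") yields three int arrays of shape (1, 128):
#   input_ids   - WordPiece token ids for [CLS] hello world [SEP], zero-padded
#   input_mask  - 1 for real tokens, 0 for padding positions
#   segment_ids - all zeros, since there is only one text segment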
def _get_input_tensor(input_tensors, input_details, i):
  """Returns the tensor in `input_tensors` that maps to `input_details[i]`."""
  if isinstance(input_tensors, dict):
    # Match the dict entry whose key appears in the tensor's graph name.
    input_detail = input_details[i]
    for input_tensor_name, input_tensor in input_tensors.items():
      if input_tensor_name in input_detail['name']:
        return input_tensor
    raise ValueError('Input tensors don\'t contain a tensor that maps to the '
                     'input detail %s' % str(input_detail))
  else:
    return input_tensors[i]
class LiteRunner(object):
  """Loads a TFLite model and runs inference, (de)quantizing as needed."""

  def __init__(self, tflite_filepath):
    with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
      tflite_model = f.read()
    self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
    self.interpreter.allocate_tensors()
    # Cache the input/output tensor metadata (index, shape, dtype, quantization).
    self.input_details = self.interpreter.get_input_details()
    self.output_details = self.interpreter.get_output_details()
  def run(self, input_tensors):
    """Runs inference on `input_tensors` and returns the output tensor(s)."""
    if not isinstance(input_tensors, (list, tuple, dict)):
      input_tensors = [input_tensors]
    interpreter = self.interpreter

    # Resize the model inputs to match the shapes of the given tensors.
    for i, input_detail in enumerate(self.input_details):
      input_tensor = _get_input_tensor(input_tensors, self.input_details, i)
      interpreter.resize_tensor_input(input_detail['index'], input_tensor.shape)
    interpreter.allocate_tensors()

    # Feed the inputs to the interpreter, quantizing where required.
    for i, input_detail in enumerate(self.input_details):
      input_tensor = _get_input_tensor(input_tensors, self.input_details, i)
      if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
        # Quantize the input: real = (quantized - zero_point) * scale,
        # so quantized = real / scale + zero_point.
        scale, zero_point = input_detail['quantization']
        input_tensor = input_tensor / scale + zero_point
      # Cast to the dtype the model expects instead of hard-coding int32,
      # which would break float inputs.
      input_tensor = np.array(input_tensor, dtype=input_detail['dtype'])
      interpreter.set_tensor(input_detail['index'], input_tensor)

    interpreter.invoke()

    output_tensors = []
    for output_detail in self.output_details:
      output_tensor = interpreter.get_tensor(output_detail['index'])
      if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
        # Dequantize the output back to float.
        scale, zero_point = output_detail['quantization']
        output_tensor = output_tensor.astype(np.float32)
        output_tensor = (output_tensor - zero_point) * scale
      output_tensors.append(output_tensor)

    if len(output_tensors) == 1:
      return output_tensors[0]
    return output_tensors
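
# Sketch (illustrative): `run` also accepts a dict keyed by input tensor name,
# which `_get_input_tensor` matches against the names in `input_details`, e.g.
#   runner.run({'input_word_ids': ids, 'input_mask': mask, 'input_type_ids': seg})
# The key names above are assumptions; the actual names depend on the exported model.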
# Load the TFLite model once at import time instead of on every call.
_runner = LiteRunner('lite/model.tflite')

def predict(text):
  """Classifies `text` as 'Suicide' or 'Non-suicide'."""
  encoded = to_feature(text)
  preds = _runner.run(encoded)
  pred = np.argmax(preds)
  if pred == 0:
    return 'Non-suicide'
  return 'Suicide'
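
# --- Usage sketch ---
# Illustrative only; assumes a Model Maker-exported classifier exists at
# 'lite/model.tflite' with two output classes (0 = non-suicide, 1 = suicide).
if __name__ == '__main__':
  print(predict('Had a lovely walk in the park today.'))  # expected: 'Non-suicide'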