|
1 | 1 |
|
2 | 2 | from urllib.request import urlopen
|
| 3 | +from datetime import datetime |
3 | 4 |
|
4 |
| -import tensorflow.compat.v1 as tf |
| 5 | +import tensorflow as tf |
5 | 6 |
|
6 | 7 | from PIL import Image
|
7 | 8 | import numpy as np
|
8 |
| -# import scipy |
9 |
| -# from scipy import misc |
10 | 9 | import sys
|
11 |
| -import os |
12 | 10 |
|
# Files read by initialize(): the serialized graph and the label list
# (one label per line).
filename, labels_filename = 'model.pb', 'labels.txt'

# Tensor names looked up in the imported graph: the tensor that is fed with
# the image batch and the tensor whose value is fetched as the prediction.
input_node, output_layer = 'Placeholder:0', 'loss:0'

# Side length of the network input; stays 0 until initialize() derives it
# from the shape of `input_node`.
network_input_size = 0

# Filled by initialize(): the parsed graph definition and the labels.
graph_def = tf.compat.v1.GraphDef()
labels = []
|
24 | 21 |
|
25 |
| - |
def initialize():
    """Load the model graph and the labels into module-level state.

    Parses ``filename`` into the module-level ``graph_def``, imports it into
    the default graph, sets ``network_input_size`` from the shape of
    ``input_node`` and fills ``labels`` from ``labels_filename`` (one label
    per line, surrounding whitespace stripped).

    Raises:
        ValueError: if the input tensor does not have 4 dimensions, or its
            dimensions 1 and 2 differ.
    """
    global network_input_size

    print('Loading model...', end='')
    with open(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    # Retrieving 'network_input_size' from shape of 'input_node'
    with tf.compat.v1.Session() as sess:
        input_tensor_shape = sess.graph.get_tensor_by_name(input_node).shape.as_list()

    # Explicit raises rather than `assert`, so the checks still run when
    # Python is started with -O.
    if len(input_tensor_shape) != 4:
        raise ValueError(
            "Expected a 4-dimensional input tensor, got shape {}".format(input_tensor_shape))
    if input_tensor_shape[1] != input_tensor_shape[2]:
        raise ValueError(
            "Expected a square network input, got shape {}".format(input_tensor_shape))

    network_input_size = input_tensor_shape[1]
    print('Success!')

    print('Loading labels...', end='')
    with open(labels_filename, 'rt') as lf:
        # Fill the existing list in place so references obtained before
        # initialize() ran (e.g. via `from module import labels`) see the data.
        labels[:] = [line.strip() for line in lf]
    print(len(labels), 'found. Success!')
|
38 | 44 |
|
39 |
| - |
40 |
| -def crop_center(img, cropx, cropy): |
41 |
| - y, x, z = img.shape |
42 |
| - startx = x//2-(cropx//2) |
43 |
| - starty = y//2-(cropy//2) |
44 |
| - print('crop_center: ', x, 'x', y, 'to', cropx, 'x', cropy) |
def log_msg(msg):
    """Print *msg* on stdout, prefixed with ``datetime.now()`` and ``': '``."""
    print(datetime.now(), msg, sep=": ")
| 47 | + |
def _source_span(index, ratio, origin, size):
    """Map one target index to two neighbouring source indices and a weight.

    Returns ``(lo, hi, frac)``: the sample is ``(1 - frac) * src[lo] +
    frac * src[hi]``. Positions outside ``[0, size - 1]`` collapse onto the
    nearest edge index with ``frac == 0.0``.
    """
    # Pixel-centre mapping. divmod floors (unlike int(), which truncates
    # toward zero), so a negative position yields lo < 0 with frac in [0, 1)
    # and is clamped below instead of producing a negative blend weight.
    whole, frac = divmod((index + 0.5) * ratio - 0.5, 1.0)
    lo = int(whole) + origin
    if lo < 0:
        return 0, 0, 0.0
    if lo >= size - 1:
        return size - 1, size - 1, 0.0
    return lo, lo + 1, frac


def extract_bilinear_pixel(img, x, y, ratio, xOrigin, yOrigin):
    """Bilinearly sample *img* for target pixel ``(x, y)``.

    Args:
        img: array indexed as ``img[row, column]``.
        x, y: column and row of the pixel in the target image.
        ratio: source pixels per target pixel.
        xOrigin, yOrigin: offset added to the mapped column / row index.

    Returns:
        The interpolated value ``img`` holds at the mapped position; positions
        beyond the image border take the value of the nearest edge pixel.
    """
    x0, x1, xDelta = _source_span(x, ratio, xOrigin, img.shape[1])
    y0, y1, yDelta = _source_span(y, ratio, yOrigin, img.shape[0])

    # Get pixels in four corners
    bl = img[y0, x0]
    br = img[y0, x1]
    tl = img[y1, x0]
    tr = img[y1, x1]
    # Calculate interpolation: along x on both rows, then along y
    b = xDelta * br + (1. - xDelta) * bl
    t = xDelta * tr + (1. - xDelta) * tl
    return yDelta * t + (1. - yDelta) * b
| 89 | + |
def extract_and_resize(img, targetSize):
    """Resample the centred region of *img* to ``targetSize`` (rows, columns).

    One scale factor is used for both axes; the axis that is relatively
    longer than the target is offset so the sampled region is centred.
    Returns a float32 array of shape ``(targetSize[0], targetSize[1],
    img.shape[2])`` filled pixel by pixel with ``extract_bilinear_pixel``.
    """
    src_rows, src_cols = img.shape[0], img.shape[1]
    dst_rows, dst_cols = targetSize[0], targetSize[1]

    # Sign tells which axis of the source is relatively longer than the target.
    determinant = src_cols * dst_rows - src_rows * dst_cols

    xOrigin = yOrigin = 0
    if determinant < 0:
        ratio = float(src_cols) / float(dst_cols)
        yOrigin = int(0.5 * (src_rows - ratio * dst_rows))
    else:
        ratio = float(src_rows) / float(dst_rows)
        if determinant > 0:
            xOrigin = int(0.5 * (src_cols - ratio * dst_cols))

    resize_image = np.empty((dst_rows, dst_cols, img.shape[2]), dtype=np.float32)
    for row, col in np.ndindex(dst_rows, dst_cols):
        resize_image[row, col] = extract_bilinear_pixel(img, col, row, ratio, xOrigin, yOrigin)
    return resize_image
| 109 | + |
def extract_and_resize_to_256_square(image):
    """Log the source size, then return ``extract_and_resize(image, (256, 256))``."""
    side = 256
    height, width = image.shape[:2]
    log_msg("crop_center: {}x{} and resize to {}x{}".format(width, height, side, side))
    return extract_and_resize(image, (side, side))
| 114 | + |
def crop_center(img,cropx,cropy):
    """Return the centred ``cropy`` x ``cropx`` window of *img* (rows x columns).

    The window start is clamped at 0, so asking for more than the image
    holds yields a slice from the top/left edge.
    """
    height, width = img.shape[:2]
    left = max(0, width // 2 - cropx // 2)
    top = max(0, height // 2 - cropy // 2)
    log_msg("crop_center: {}x{} to {}x{}".format(width, height, cropx, cropy))
    rows = slice(top, top + cropy)
    cols = slice(left, left + cropx)
    return img[rows, cols]
|
46 | 121 |
|
def resize_down_to_1600_max_dim(image):
    """Shrink *image* so its larger side is 1600, keeping the aspect ratio.

    Args:
        image: object exposing ``size`` as ``(width, height)`` and
            ``resize(size, method)`` (a PIL image in this module).

    Returns:
        *image* itself when neither side exceeds 1600, otherwise the result
        of ``image.resize``.
    """
    max_dim = 1600
    w, h = image.size
    # `<=`: an image that is already exactly 1600 on its larger side needs
    # no resampling.
    if h <= max_dim and w <= max_dim:
        return image

    # Floor division can reach 0 for extreme aspect ratios; keep the short
    # side at least 1 pixel so the target size stays valid.
    if h > w:
        new_size = (max(1, max_dim * w // h), max_dim)
    else:
        new_size = (max_dim, max(1, max_dim * h // w))
    log_msg("resize: " + str(w) + "x" + str(h) + " to " + str(new_size[0]) + "x" + str(new_size[1]))

    # Pick the filter from the scale factor: bilinear down to half size,
    # bicubic for stronger reductions.
    if max(new_size) / max(image.size) >= 0.5:
        method = Image.BILINEAR
    else:
        method = Image.BICUBIC
    return image.resize(new_size, method)
47 | 134 |
|
def predict_url(imageUrl):
    """Open *imageUrl*, load it as a PIL image and return ``predict_image``'s result."""
    message = "Predicting from url: " + imageUrl
    log_msg(message)
    with urlopen(imageUrl) as response:
        # Predict while the response is still open, so the image data can
        # still be read from it.
        return predict_image(Image.open(response))
|
54 | 140 |
|
55 |
| - |
def convert_to_nparray(image):
    """Return *image* as a NumPy array with channels reordered RGB -> BGR.

    Channels 2, 1, 0 of the last axis are selected, in that order.
    """
    log_msg("Convert to numpy array")
    bgr_order = (2, 1, 0)
    pixels = np.array(image)
    return pixels[:, :, bgr_order]
| 146 | + |
def update_orientation(image):
    """Return *image* flipped/transposed according to its EXIF orientation.

    The image is returned unchanged when it has no ``_getexif`` method, no
    EXIF data, no orientation tag (0x0112), or a tag value that is not an
    integer in 1..8.
    """
    exif_orientation_tag = 0x0112
    if not hasattr(image, '_getexif'):
        return image

    exif = image._getexif()
    if exif is None or exif_orientation_tag not in exif:
        return image

    orientation = exif[exif_orientation_tag]
    log_msg('Image has EXIF Orientation: ' + str(orientation))
    # Ignore malformed values instead of applying an arbitrary transform
    # (e.g. 9 would otherwise trigger a transpose).
    if not isinstance(orientation, int) or not 1 <= orientation <= 8:
        return image

    # orientation is 1 based, shift to zero based and flip/transpose based on 0-based values
    orientation -= 1
    if orientation >= 4:
        image = image.transpose(Image.TRANSPOSE)
    if orientation in (2, 3, 6, 7):
        image = image.transpose(Image.FLIP_TOP_BOTTOM)
    if orientation in (1, 2, 5, 6):
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    return image
| 163 | + |
def predict_image(image):
    """Preprocess a PIL image, run the imported graph on it and format the result.

    Steps: apply the EXIF orientation, convert to RGB, cap the larger side at
    1600, convert to a BGR NumPy array, resample the centre to 256x256, crop
    the centre to ``network_input_size`` and evaluate ``output_layer``.

    Returns:
        On success a dict with the keys ``id``, ``project``, ``iteration``,
        ``created`` and ``predictions``; ``predictions`` is a list of dicts
        (``tagName``, ``probability``, ``tagId``, ``boundingBox``), one per
        label whose probability, rounded to 8 decimals, exceeds 1e-8.
        On any exception, a string starting with
        ``'Error: Could not preprocess image for prediction. '``.
    """
    log_msg('Predicting image')
    try:
        # Update orientation based on EXIF tags. This runs before the RGB
        # conversion because convert() returns a new image object, which may
        # no longer expose the `_getexif` method update_orientation looks for.
        # NOTE(review): confirm against the installed Pillow version.
        image = update_orientation(image)

        if image.mode != "RGB":
            log_msg("Converting to RGB")
            image = image.convert("RGB")

        w, h = image.size
        log_msg("Image size: " + str(w) + "x" + str(h))

        # If the image has either w or h greater than 1600 we resize it down respecting
        # aspect ratio such that the largest dimension is 1600
        image = resize_down_to_1600_max_dim(image)

        # Convert image to numpy array
        image = convert_to_nparray(image)

        # Crop the center square and resize that square down to 256x256
        resized_image = extract_and_resize_to_256_square(image)

        # Crop the center for the specified network_input_size
        cropped_image = crop_center(resized_image, network_input_size, network_input_size)

        tf.compat.v1.reset_default_graph()
        tf.import_graph_def(graph_def, name='')

        with tf.compat.v1.Session() as sess:
            prob_tensor = sess.graph.get_tensor_by_name(output_layer)
            # Batch of one image; unpack the single row of probabilities.
            predictions, = sess.run(prob_tensor, {input_node: [cropped_image]})

        result = []
        for p, label in zip(predictions, labels):
            truncated_probability = np.float64(round(p, 8))
            if truncated_probability > 1e-8:
                result.append({
                    'tagName': label,
                    'probability': truncated_probability,
                    'tagId': '',
                    'boundingBox': None,
                })

        response = {
            'id': '',
            'project': '',
            'iteration': '',
            'created': datetime.utcnow().isoformat(),
            'predictions': result,
        }

        log_msg("Results: " + str(response))
        return response

    except Exception as e:
        # Top-level boundary: callers receive an error string instead of an
        # exception.
        log_msg(str(e))
        return 'Error: Could not preprocess image for prediction. ' + str(e)
0 commit comments