From 7372137391e1159f5e38323fea6e8747f3d11b3b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 22 Aug 2017 11:24:34 +0000 Subject: [PATCH] preprocess_img in tf --- osvos.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/osvos.py b/osvos.py index 880838a..413cd3e 100644 --- a/osvos.py +++ b/osvos.py @@ -17,7 +17,7 @@ from PIL import Image slim = tf.contrib.slim - +mean = np.array((104.00699, 116.66877, 122.67892), dtype=np.float32) def osvos_arg_scope(weight_decay=0.0002): """Defines the OSVOS arg scope. @@ -160,6 +160,20 @@ def interp_surgery(variables): return interp_tensors +def preprocess_img_tf(image): + """Preprocess the image to adapt it to network requirements + Args: + Image we want to input the network (W,H,3) numpy array + Returns: + Image ready to input the network (1,W,H,3) + """ + img = tf.cast(image, tf.float32) + img = tf.reverse(img, axis=[-1]) + img = tf.subtract(img, mean) + img = tf.expand_dims(img, 0) + return img + + # TO DO: Move preprocessing into Tensorflow def preprocess_img(image): """Preprocess the image to adapt it to network requirements @@ -544,7 +558,7 @@ def _train(dataset, initial_ckpt, supervison, learning_rate, logs_path, max_trai # Average the gradient for _ in range(0, iter_mean_grad): batch_image, batch_label = dataset.next_batch(batch_size, 'train') - image = preprocess_img(batch_image[0]) + image = preprocess_img_tf(batch_image[0]).eval() label = preprocess_labels(batch_label[0]) run_res = sess.run([total_loss, merged_summary_op] + grad_accumulator_ops, feed_dict={input_image: image, input_label: label}) @@ -564,7 +578,7 @@ def _train(dataset, initial_ckpt, supervison, learning_rate, logs_path, max_trai # Save a checkpoint if step % save_step == 0: if test_image_path is not None: - curr_output = sess.run(img_summary, feed_dict={input_image: preprocess_img(test_image_path)}) + curr_output = sess.run(img_summary, feed_dict={input_image: preprocess_img_tf(test_image_path)}) summary_writer.add_summary(curr_output, step) save_path = saver.save(sess, model_name, global_step=global_step) print "Model saved in file: %s" % save_path @@ -644,7 +658,7 @@ def test(dataset, checkpoint_file, result_path, config=None): for frame in range(0, dataset.get_test_size()): img, curr_img = dataset.next_batch(batch_size, 'test') curr_frame = curr_img[0].split('/')[-1].split('.')[0] + '.png' - image = preprocess_img(img[0]) + image = preprocess_img_tf(img[0]).eval() res = sess.run(probabilities, feed_dict={input_image: image}) res_np = res.astype(np.float32)[0, :, :, 0] > 162.0/255.0 scipy.misc.imsave(os.path.join(result_path, curr_frame), res_np.astype(np.float32))