Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added data/vg-30.pb
Binary file not shown.
Binary file added input/jaguar.mp4
Binary file not shown.
74 changes: 74 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging
import os
import sys
from datetime import datetime

# Set the native-log threshold BEFORE TensorFlow is imported. The usual
# guidance is that the TensorFlow runtime may read this variable as soon as
# the library is loaded, so assigning it after the import can be too late
# for it to take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf  # noqa: E402  (must follow the env-var assignment)

# 1. import `Video` and `save_video` from the correct module of package
#    "styler"
from styler.video import Video  # noqa: E402
from styler.utils import save_video  # noqa: E402


# Frozen graph file and the name scope used to look up its tensors in main().
model_file = 'data/vg-30.pb'
model_name = 'vg-30'

# Send INFO-level (and above) records to stdout with a short timestamp.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s %(levelname)s:%(message)s',
    level=logging.INFO,
    datefmt='%I:%M:%S')


def main():
    """Run every frame of the input video through the graph and save it.

    Loads the serialized graph from ``model_file``, looks up the input and
    output tensors under the ``model_name`` scope, feeds each frame read
    from the input video through the session one at a time, and hands the
    results to ``save_video``.
    """
    # Parse the serialized GraphDef and import it into the default graph.
    graph_def = tf.GraphDef()
    with open(model_file, 'rb') as model_fh:
        graph_def.ParseFromString(model_fh.read())
    tf.import_graph_def(graph_def)
    graph = tf.get_default_graph()

    session_config = tf.ConfigProto(intra_op_parallelism_threads=4)
    with tf.Session(config=session_config) as session:

        logging.info("Initializing graph")
        session.run(tf.global_variables_initializer())

        image = graph.get_tensor_by_name("import/%s/image_in:0" % model_name)
        out = graph.get_tensor_by_name("import/%s/output:0" % model_name)

        # Indices 1 and 2 of the input tensor's static shape are used as the
        # frame height and width.
        input_shape = image.get_shape().as_list()
        target_h = input_shape[1]
        target_w = input_shape[2]

        # 2. set the `path` to your input
        with Video('input/jaguar.mp4') as v:
            frames = v.read_frames(image_h=target_h, image_w=target_w)

        logging.info("Processing image")
        start_time = datetime.now()

        def stylize(frame):
            # The frame is wrapped in a list so it is fed as a batch of one.
            return session.run(out, feed_dict={image: [frame]})

        # 3. iterate through all frames and let TensorFlow process each one
        processed = [stylize(frame) for frame in frames]

        # 4. pass the results as an argument into the writer function
        save_video('result.avi',
                   fps=30, h=target_h, w=target_w,
                   frames=processed)

        elapsed_seconds = (datetime.now() - start_time).total_seconds()
        logging.info("Processing took %f" % elapsed_seconds)
        logging.info("Done")


if __name__ == '__main__':
    main()
Binary file added result.avi
Binary file not shown.
Binary file added styler/__pycache__/utils.cpython-36.pyc
Binary file not shown.
Binary file added styler/__pycache__/video.cpython-36.pyc
Binary file not shown.
48 changes: 48 additions & 0 deletions styler/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import cv2
import numpy as np

# Per-channel offsets, shaped (1, 1, 3) so they broadcast over an
# H x W x 3 image. resize() subtracts them and post_process() adds them back.
MEAN_VALUES = np.asarray([[[123, 117, 104]]])


def resize(img, image_h, image_w, zoom=False):
    """Centre-crop ``img`` to the target aspect ratio, resize it and centre it.

    Args:
        img: Image array indexed as ``img[row, column]``; only ``shape[0]``
            (height) and ``shape[1]`` (width) are inspected here.
        image_h: Target height in pixels.
        image_w: Target width in pixels.
        zoom: When true, keep only the central two thirds of the
            aspect-corrected crop before resizing.

    Returns:
        The resized image with ``MEAN_VALUES`` subtracted.
    """
    # crop image from center
    ratio = float(image_h) / image_w
    height, width = int(img.shape[0]), int(img.shape[1])
    yy, xx = 0, 0
    if height > width * ratio:  # too tall
        yy = int(height - width * ratio) // 2
        height = int(width * ratio)
    else:  # too wide
        xx = int(width - height / ratio) // 2
        width = int(height / ratio)
    if zoom:
        # Keeping the central 2/3 of each axis means shifting the origin by
        # 1/6 of that same axis; the horizontal shift must use the width.
        yy += int(height / 6)
        xx += int(width / 6)
        height = int(height * 2 / 3)
        width = int(width * 2 / 3)
    crop_img = img[yy:yy + height, xx:xx + width]

    # cv2.resize expects dsize as (width, height), not (height, width).
    resized_img = cv2.resize(crop_img, (image_w, image_h))
    centered_img = resized_img - MEAN_VALUES

    return centered_img


def save_video(filepath, fps, w, h, frames):
    """Write ``frames`` to a video file at ``filepath``.

    Args:
        filepath: Destination path of the video file.
        fps: Frame rate passed to the writer.
        w: Frame width in pixels.
        h: Frame height in pixels.
        frames: Iterable of 4-D arrays; element ``[0]`` along the first axis
            of each is post-processed and written (presumably a batch of
            one per frame -- confirm against the caller).

    Raises:
        IOError: If the video writer could not be opened, e.g. because the
            selected codec is not available on this machine.
    """
    codecs = ['WMV1', 'MJPG', 'XVID', 'PIM1']
    # If you cannot write the video file, change the index below to pick
    # another codec from `codecs`.
    used_codec = codecs[2]
    fourcc = cv2.VideoWriter_fourcc(*used_codec)
    out = cv2.VideoWriter(filepath, fourcc, fps, (w, h))
    try:
        # Fail loudly instead of silently producing an empty/missing file.
        if not out.isOpened():
            raise IOError(
                'Cannot open video writer for {} with codec {}'.format(
                    filepath, used_codec))
        for frame in frames:
            out.write(post_process(frame[0, :, :, :]))
    finally:
        # Release the writer even if a frame fails to process.
        out.release()


def post_process(image):
    """Undo the mean-centring of ``image`` and convert it for writing.

    Adds ``MEAN_VALUES`` back, clamps the result to the 0-255 range, casts it
    to ``uint8`` and returns it converted with ``cv2.COLOR_RGB2BGR``.
    """
    restored = image + MEAN_VALUES
    as_bytes = np.clip(restored, 0, 255).astype('uint8')
    bgr = cv2.cvtColor(as_bytes, cv2.COLOR_RGB2BGR)
    return bgr
47 changes: 47 additions & 0 deletions styler/video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import cv2
import time

from styler.utils import resize


class Video:
    """Context manager that reads resized frames from a video file.

    The underlying ``cv2.VideoCapture`` is created in ``__init__`` and
    released when the ``with`` block exits.
    """

    def __init__(self, path):
        """Open a capture on ``path`` and start with an empty frame list."""
        self.path = path
        self.cap = cv2.VideoCapture(self.path)
        self.frames = []

    def __enter__(self):
        """Return ``self``; raise ``Exception`` if the capture is not open."""
        if not self.cap.isOpened():
            # __exit__ is not invoked when __enter__ raises, so the capture
            # has to be released here.
            self.cap.release()
            raise Exception('Cannot open video: {}'.format(self.path))
        return self

    def __len__(self):
        """Return the number of frames stored by the last ``read_frames``."""
        return len(self.frames)

    def read_frames(self, image_h, image_w):
        """Read all remaining frames, passing each through ``resize()``.

        Args:
            image_h: Target frame height, converted with ``int()``.
            image_w: Target frame width, converted with ``int()``.

        Returns:
            The list of processed frames, also stored on ``self.frames``.
        """
        frames = []
        while self.cap.isOpened():
            ret, cap_frame = self.cap.read()
            if not ret:
                # No frame was returned: stop reading.
                break
            frames.append(
                resize(cap_frame, image_h=int(image_h), image_w=int(image_w)))

        print(len(frames))
        self.frames = frames
        return frames

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the capture; exceptions from the block are not suppressed."""
        self.cap.release()