Skip to content

Commit 0ca2916

Browse files
committed
video: add crop, bug fix
1 parent cdb1fdb commit 0ca2916

File tree

1 file changed

+70
-13
lines changed

1 file changed

+70
-13
lines changed

dlclivegui/video.py

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
def create_labeled_video(video_file,
1919
ts_file,
2020
dlc_file,
21-
out_file=None,
21+
out_dir=None,
2222
save_images=False,
2323
cut=(0, np.Inf),
24-
cmap='bgy',
24+
crop=None,
25+
cmap='bmy',
2526
radius=3,
2627
lik_thresh=0.5,
2728
write_ts=False,
@@ -46,7 +47,7 @@ def create_labeled_video(video_file,
4647
cut : tuple, optional
4748
time of video to use. Will only save labeled video for time after cut[0] and before cut[1], by default (0, np.Inf)
4849
cmap : str, optional
49-
a :package:`colorcet` colormap, by default 'bgy'
50+
a :package:`colorcet` colormap, by default 'bmy'
5051
radius : int, optional
5152
radius for keypoints, by default 3
5253
lik_thresh : float, optional
@@ -68,26 +69,34 @@ def create_labeled_video(video_file,
6869

6970

7071
lab = "LABELED" if label else "UNLABELED"
71-
if out_file:
72-
out_file = f"{out_file}_VIDEO_{lab}.avi"
73-
out_times_file = f"{out_file}_TS_{lab}.npy"
72+
if out_dir:
73+
out_file = f"{out_dir}/{os.path.splitext(os.path.basename(video_file))[0]}_{lab}.avi"
74+
out_times_file = f"{out_dir}/{os.path.splitext(os.path.basename(ts_file))[0]}_{lab}.npy"
7475
else:
7576
out_file = f"{os.path.splitext(video_file)[0]}_{lab}.avi"
76-
out_times_file = f"{os.path.splitext(ts_file)[0]}_L{lab}.npy"
77+
out_times_file = f"{os.path.splitext(ts_file)[0]}_{lab}.npy"
78+
79+
os.makedirs(os.path.normpath(os.path.dirname(out_file)), exist_ok=True)
7780

7881
if save_images:
7982
im_dir = os.path.splitext(out_file)[0]
8083
os.makedirs(im_dir, exist_ok=True)
8184

85+
im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
86+
if crop is not None:
87+
crop = np.max(np.vstack((crop, [0, im_size[1], 0, im_size[0]])), axis=0)
88+
im_size = (crop[3]-crop[2], crop[1]-crop[0])
89+
8290
fourcc = cv2.VideoWriter_fourcc(*'DIVX')
8391
fps = cap.get(cv2.CAP_PROP_FPS)
84-
im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
8592
vwriter = cv2.VideoWriter(out_file, fourcc, fps, im_size)
8693
label_times = []
87-
94+
8895
if write_ts:
8996
ts_font = cv2.FONT_HERSHEY_PLAIN
90-
ts_coord = (0, im_size[1])
97+
ts_w = 0 if crop is None else crop[0]
98+
ts_h = im_size[1] if crop is None else crop[1]
99+
ts_coord = (ts_w, ts_h)
91100
ts_color = (255, 255, 255)
92101
ts_size = 2
93102

@@ -114,7 +123,8 @@ def create_labeled_video(video_file,
114123

115124
frame_times_sub = cam_frame_times[(cam_frame_times-cam_frame_times[0] > cut[0]) & (cam_frame_times-cam_frame_times[0] < cut[1])]
116125
iterator = tqdm(range(ind, ind+frame_times_sub.size)) if progress else range(ind, ind+frame_times_sub.size)
117-
126+
this_pose = np.zeros((bodyparts.size, 3))
127+
118128
for i in iterator:
119129

120130
cur_time = cam_frame_times[i]
@@ -124,8 +134,10 @@ def create_labeled_video(video_file,
124134
if not ret:
125135
raise Exception(f"Could not read frame = {i+1} at time = {cur_time-cam_frame_times[0]}.")
126136

127-
cur_pose_time = pose_times[np.where(pose_times - cur_time > 0)[0][0]]
128-
this_pose = poses[poses['pose_time']==cur_pose_time]
137+
poses_before_index = np.where(pose_times < cur_time)[0]
138+
if poses_before_index.size > 0:
139+
cur_pose_time = pose_times[poses_before_index[-1]]
140+
this_pose = poses[poses['pose_time']==cur_pose_time]
129141

130142
if label:
131143
for j in range(bodyparts.size):
@@ -134,6 +146,9 @@ def create_labeled_video(video_file,
134146
x = int(this_bp[0])
135147
y = int(this_bp[1])
136148
frame = cv2.circle(frame, (x, y), radius, colors[j], thickness=-1)
149+
150+
if crop is not None:
151+
frame = frame[crop[0]:crop[1], crop[2]:crop[3]]
137152

138153
if write_ts:
139154
frame = cv2.putText(frame, f"{vid_time:0.3f}", ts_coord, ts_font, write_scale, ts_color, ts_size)
@@ -153,3 +168,45 @@ def create_labeled_video(video_file,
153168

154169
vwriter.release()
155170
np.save(out_times_file, label_times)
171+
172+
173+
def main():
174+
175+
import argparse
176+
import os
177+
178+
parser = argparse.ArgumentParser()
179+
parser.add_argument('file', type=str)
180+
parser.add_argument('-o', '--out-dir', type=str, default=None)
181+
parser.add_argument('-s', '--save-images', action='store_true')
182+
parser.add_argument('-u', '--cut', nargs='+', type=float, default=[0, np.Inf])
183+
parser.add_argument('-c', '--crop', nargs='+', type=int, default=None)
184+
parser.add_argument('-m', '--cmap', type=str, default='bmy')
185+
parser.add_argument('-r', '--radius', type=int, default=3)
186+
parser.add_argument('-l', '--lik-thresh', type=float, default=0.5)
187+
parser.add_argument('-w', '--write-ts', action='store_true')
188+
parser.add_argument('--write-scale', type=int, default=2)
189+
parser.add_argument('-d', '--display', action='store_true')
190+
parser.add_argument('--no-progress', action='store_false')
191+
parser.add_argument('--no-label', action='store_false')
192+
args = parser.parse_args()
193+
194+
vid_file = os.path.normpath(f"{args.file}_VIDEO.avi")
195+
ts_file = os.path.normpath(f"{args.file}_TS.npy")
196+
dlc_file = os.path.normpath(f"{args.file}_DLC.hdf5")
197+
198+
create_labeled_video(vid_file,
199+
ts_file,
200+
dlc_file,
201+
out_dir=args.out_dir,
202+
save_images=args.save_images,
203+
cut=tuple(args.cut),
204+
crop=args.crop,
205+
cmap=args.cmap,
206+
radius=args.radius,
207+
lik_thresh=args.lik_thresh,
208+
write_ts=args.write_ts,
209+
write_scale=args.write_scale,
210+
display=args.display,
211+
progress=args.no_progress,
212+
label=args.no_label)

0 commit comments

Comments
 (0)