|
| 1 | +# ------------------------------------------------------------------------------ |
| 2 | +# Copyright (c) ETRI. All rights reserved. |
| 3 | +# Licensed under the BSD 3-Clause License. |
| 4 | +# This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. |
| 5 | +# You can refer to details of AIR project at https://aiforrobots.github.io |
| 6 | +# Written by Youngwoo Yoon (youngwoo@etri.re.kr) |
| 7 | +# ------------------------------------------------------------------------------ |
| 8 | + |
| 9 | +import glob |
| 10 | +import matplotlib |
| 11 | +import cv2 |
| 12 | +import re |
| 13 | +import json |
| 14 | +import _pickle as pickle |
| 15 | +from webvtt import WebVTT |
| 16 | +from config import my_config |
| 17 | + |
| 18 | + |
| 19 | +############################################################################### |
| 20 | +# SKELETON |
def draw_skeleton_on_image(img, skeleton, thickness=15):
    """Draw upper-body skeleton segments on a copy of an image.

    Args:
        img: BGR image (numpy array); not modified in place.
        skeleton: flat OpenPose keypoint list [x0, y0, c0, x1, y1, c1, ...].
            A falsy value means "no skeleton" and img is returned unchanged.
        thickness: line thickness in pixels.

    Returns:
        A copy of img with colored limb lines drawn, or img itself when
        skeleton is empty.
    """
    if not skeleton:
        return img

    new_img = img.copy()
    for joint_a, joint_b, color_name in SkeletonWrapper.skeleton_line_pairs:
        pt1 = (int(skeleton[joint_a * 3]), int(skeleton[joint_a * 3 + 1]))
        pt2 = (int(skeleton[joint_b * 3]), int(skeleton[joint_b * 3 + 1]))
        # OpenPose reports undetected joints as (0, 0); skip segments that
        # touch one. BUGFIX: the original tested `pt1[0] == 0 or pt2[1] == 0`
        # (x of one endpoint, y of the other), which also dropped legitimate
        # joints lying on the image border.
        if pt1 == (0, 0) or pt2 == (0, 0):
            continue
        # matplotlib gives RGB in 0..1; cv2 wants BGR in 0..255
        rgb = [v * 255 for v in matplotlib.colors.to_rgba(color_name)][:3]
        cv2.line(new_img, pt1, pt2, color=rgb[::-1], thickness=thickness)

    return new_img
| 36 | + |
| 37 | + |
def is_list_empty(my_list):
    """Return True when my_list is a (possibly nested) list holding no
    non-list items; any non-list value counts as content."""
    if isinstance(my_list, list):
        for element in my_list:
            if not is_list_empty(element):
                return False
        return True
    return False
| 40 | + |
| 41 | + |
def get_closest_skeleton(frame, selected_body):
    """Find the person in `frame` whose upper-body pose is closest to
    the skeleton being tracked.

    Args:
        frame: list of OpenPose person entries for one frame.
        selected_body: flat keypoint list [x, y, conf, ...] of the tracked skeleton.

    Returns:
        The best-matching person entry, or None when tracking failed (no
        comparable joints, or mean offset exceeds a pose-scaled threshold).
    """
    # x and y indices of the first 8 joints (upper body) in the flat layout
    diff_idx = [i * 3 for i in range(8)] + [i * 3 + 1 for i in range(8)]

    min_diff = 10000000
    tracked_person = None
    for person in frame:  # people
        body = get_skeleton_from_frame(person)
        if not body:
            continue  # entry without keypoints cannot be compared

        diff = 0
        n_diff = 0
        for i in diff_idx:
            # compare only joints detected (> 0) in both skeletons
            if body[i] > 0 and selected_body[i] > 0:
                diff += abs(body[i] - selected_body[i])
                n_diff += 1
        # BUGFIX: the original compared `diff` (still 0) even when n_diff == 0,
        # so a person sharing zero valid joints looked like a perfect match.
        # Consider only candidates with at least one comparable joint.
        if n_diff > 0:
            diff /= n_diff
            if diff < min_diff:
                min_diff = diff
                tracked_person = person

    # allowed mean offset scales with body size (neck length / shoulder width)
    base_distance = max(abs(selected_body[0 * 3 + 1] - selected_body[1 * 3 + 1]) * 3,
                        abs(selected_body[2 * 3] - selected_body[5 * 3]) * 2)
    if tracked_person and min_diff > base_distance:  # tracking failed
        tracked_person = None

    return tracked_person
| 69 | + |
| 70 | + |
def get_skeleton_from_frame(frame):
    """Return the 2D pose keypoint list from an OpenPose person entry.

    Newer OpenPose output uses 'pose_keypoints_2d', older output uses
    'pose_keypoints'; returns None when neither key is present.
    """
    for key in ('pose_keypoints_2d', 'pose_keypoints'):
        if key in frame:
            return frame[key]
    return None
| 78 | + |
| 79 | + |
class SkeletonWrapper:
    """Loads and serves per-frame OpenPose skeleton data for one video,
    caching parsed JSON frames into a pickle for faster reloads."""
    # color names: https://matplotlib.org/mpl_examples/color/named_colors.png
    visualization_line_pairs = [(0, 1, 'b'), (1, 2, 'darkred'), (2, 3, 'r'), (3, 4, 'gold'),
                                (1, 5, 'darkgreen'), (5, 6, 'g'), (6, 7, 'lightgreen'),
                                (1, 8, 'darkcyan'), (8, 9, 'c'), (9, 10, 'skyblue'),
                                (1, 11, 'deeppink'), (11, 12, 'hotpink'), (12, 13, 'lightpink')]
    skeletons = []
    skeleton_line_pairs = [(0, 1, 'b'), (1, 2, 'darkred'), (2, 3, 'r'), (3, 4, 'gold'),
                           (1, 5, 'darkgreen'), (5, 6, 'g'), (6, 7, 'lightgreen')]

    def __init__(self, basepath, vid):
        """Load skeletons from <basepath>/<vid>.pickle if it exists; otherwise
        parse the per-frame JSONs in <basepath>/<vid>/ and cache them."""
        cached = glob.glob(basepath + '/' + vid + '.pickle')
        if cached:
            with open(cached[0], 'rb') as fh:
                self.skeletons = pickle.load(fh)
            return

        json_paths = glob.glob(basepath + '/' + vid + '/*.json')
        if len(json_paths) > 10:
            # one JSON per frame, sorted so list index == frame number
            self.skeletons = [self.read_skeleton_json(path) for path in sorted(json_paths)]
            with open(basepath + '/' + vid + '.pickle', 'wb') as fh:
                pickle.dump(self.skeletons, fh)
        else:
            self.skeletons = []


    def read_skeleton_json(self, file):
        """Return the 'people' entries of one OpenPose output JSON file."""
        with open(file) as json_file:
            return json.load(json_file)['people']


    def get(self, start_frame_no, end_frame_no, interval=1):
        """Return skeletons for frames [start_frame_no, end_frame_no),
        subsampled every `interval` frames; [] when the whole span holds
        no detections at all."""
        chunk = self.skeletons[start_frame_no:end_frame_no]
        if is_list_empty(chunk):
            return []
        return chunk[::int(interval)] if interval > 1 else chunk
| 126 | + |
| 127 | + |
| 128 | +############################################################################### |
| 129 | +# VIDEO |
def read_video(base_path, vid):
    """Open the .mp4 in base_path whose name ends with vid.

    Returns a VideoWrapper, or None when no matching file exists.
    Aborts (AssertionError) when the id matches more than one file.
    """
    matches = glob.glob(base_path + '/*' + vid + '.mp4')
    if not matches:
        return None
    # exactly one file is expected per video id
    assert len(matches) < 2
    return VideoWrapper(matches[0])
| 141 | + |
| 142 | + |
class VideoWrapper:
    """Thin convenience wrapper around a cv2.VideoCapture with
    frame/second conversion helpers."""
    video = []

    def __init__(self, filepath):
        cap = cv2.VideoCapture(filepath)
        self.filepath = filepath
        self.video = cap
        # basic stream properties, queried once at open time
        self.framerate = cap.get(cv2.CAP_PROP_FPS)
        self.height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def get_video_reader(self):
        """Expose the underlying cv2.VideoCapture object."""
        return self.video

    def frame2second(self, frame_no):
        """Convert a frame index to a timestamp in seconds."""
        return frame_no / self.framerate

    def second2frame(self, second):
        """Convert a timestamp in seconds to the nearest frame index."""
        return int(round(second * self.framerate))

    def set_current_frame(self, cur_frame_no):
        """Seek the capture to the given frame index."""
        self.video.set(cv2.CAP_PROP_POS_FRAMES, cur_frame_no)
| 164 | + |
| 165 | + |
| 166 | +############################################################################### |
| 167 | +# CLIP |
def load_clip_data(vid):
    """Load the clip-segmentation JSON for video id `vid` from
    my_config.CLIP_PATH; returns the parsed data, or None when missing."""
    path = "{}/{}.json".format(my_config.CLIP_PATH, vid)
    try:
        with open(path) as data_file:
            return json.load(data_file)
    except FileNotFoundError:
        return None
| 175 | + |
| 176 | + |
def load_clip_filtering_aux_info(vid):
    """Load the clip-filtering auxiliary-info JSON for video id `vid`
    from my_config.CLIP_PATH; returns parsed data, or None when missing."""
    path = "{}/{}_aux_info.json".format(my_config.CLIP_PATH, vid)
    try:
        with open(path) as data_file:
            return json.load(data_file)
    except FileNotFoundError:
        return None
| 184 | + |
| 185 | + |
| 186 | +################################################################################# |
| 187 | +#SUBTITLE |
class SubtitleWrapper:
    """Loads word-level subtitles for a video, either from Gentle
    forced-alignment results ('gentle') or from YouTube auto-generated
    VTT captions ('auto').

    After construction, `self.subtitle` is a list of word dicts
    ({'word', 'start', 'end'} with times in seconds), or None when
    loading failed.
    """
    # matches "[hh:]mm:ss.mmm" (comma also accepted before milliseconds)
    TIMESTAMP_PATTERN = re.compile(r'(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})')

    def __init__(self, vid, mode):
        """Load subtitles for video id `vid`; mode is 'auto' or 'gentle'."""
        self.subtitle = []
        if mode == 'auto':
            self.load_auto_subtitle_data(vid)
        elif mode == 'gentle':
            self.laod_gentle_subtitle(vid)

    def get(self):
        """Return the loaded word list (None when loading failed)."""
        return self.subtitle

    # using gentle lib
    def laod_gentle_subtitle(self, vid):
        """Load Gentle alignment results, keeping successfully aligned words.

        NOTE(review): the method name keeps its historical typo so existing
        callers still work; prefer the `load_gentle_subtitle` alias below.
        """
        try:
            with open("{}/{}_align_results.json".format(my_config.VIDEO_PATH, vid)) as data_file:
                data = json.load(data_file)
                if 'words' in data:
                    raw_subtitle = data['words']
                    for word in raw_subtitle:
                        # keep only words Gentle aligned successfully
                        if word['case'] == 'success':
                            self.subtitle.append(word)
                else:
                    self.subtitle = None
                return data
        except FileNotFoundError:
            self.subtitle = None

    # correctly-spelled, backward-compatible alias
    load_gentle_subtitle = laod_gentle_subtitle

    # using youtube automatic subtitle
    def load_auto_subtitle_data(self, vid):
        """Parse a YouTube auto-caption VTT file into per-word timing dicts.

        Each caption line with inline timing tags like "word<00:00:01.000><c> ..."
        is split into words whose start is the previous word's end and whose
        end comes from the embedded timestamp (last word ends with the chunk).
        """
        lang = my_config.LANG
        postfix_in_filename = '-' + lang + '-auto.vtt'
        file_list = glob.glob(my_config.SUBTITLE_PATH + '/*' + vid + postfix_in_filename)
        if len(file_list) > 1:
            print('more than one subtitle. check this.', file_list)
            self.subtitle = None
            assert False
        if len(file_list) == 1:
            for subtitle_chunk in WebVTT().read(file_list[0]):
                raw_subtitle = str(subtitle_chunk.raw_text)
                # BUGFIX: the original used `if raw_subtitle.find('\n'):`, which is
                # falsy only when the newline sits at index 0 — the loop below then
                # iterated over single characters. Splitting unconditionally is safe
                # (newline-free text yields a one-element list).
                lines = raw_subtitle.split('\n')

                for line in lines:
                    # only lines carrying inline timestamps contain word timings
                    if self.TIMESTAMP_PATTERN.search(line) is None:
                        continue

                    # removes html tags and timing tags from caption text
                    line = line.replace("</c>", "")
                    line = re.sub(r"<c[.]\w+>", '', line)

                    word_list = []
                    chunk_start = subtitle_chunk.start_in_seconds
                    chunk_end = subtitle_chunk.end_in_seconds

                    word_chunk = line.split('<c>')

                    for word_idx, word in enumerate(word_chunk):
                        word_info = {}

                        if word_idx == len(word_chunk) - 1:
                            # last word has no trailing timestamp; it ends with the chunk.
                            # BUGFIX: guard the single-word case, where the original
                            # indexed an empty word_list and crashed.
                            word_info['word'] = word
                            word_info['start'] = word_list[-1]['end'] if word_list else chunk_start
                            word_info['end'] = chunk_end
                            word_list.append(word_info)
                            break

                        # "word<00:00:01.000" -> text before '<', timestamp after it
                        parts = word.split("<")
                        word_info['word'] = parts[0]
                        word_info['end'] = self.get_seconds(parts[1][:-1])

                        if word_idx == 0:
                            word_info['start'] = chunk_start
                        else:
                            word_info['start'] = word_list[word_idx - 1]['end']
                        word_list.append(word_info)

                    self.subtitle.extend(word_list)
        else:
            print('subtitle file is not exist')
            self.subtitle = None

    # convert timestamp to second
    def get_seconds(self, word_time_e):
        """Convert an '[hh:]mm:ss.mmm' timestamp string to seconds (float).

        Exits the process on an unparsable timestamp (original behavior kept
        so callers' expectations don't change).
        """
        time_value = self.TIMESTAMP_PATTERN.match(word_time_e)
        if not time_value:
            print('wrong time stamp pattern')
            exit()

        # missing hour group (None) counts as zero
        values = list(map(lambda x: int(x) if x else 0, time_value.groups()))
        hours, minutes, seconds, milliseconds = values[0], values[1], values[2], values[3]

        return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000