Skip to content

Commit 96d132c

Browse files
committed
Create data_utils.py
1 parent 61dfca1 commit 96d132c

File tree

1 file changed

+284
-0
lines changed

1 file changed

+284
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
# ------------------------------------------------------------------------------
2+
# Copyright (c) ETRI. All rights reserved.
3+
# Licensed under the BSD 3-Clause License.
4+
# This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project.
5+
# You can refer to details of AIR project at https://aiforrobots.github.io
6+
# Written by Youngwoo Yoon (youngwoo@etri.re.kr)
7+
# ------------------------------------------------------------------------------
8+
9+
import glob
10+
import matplotlib
11+
import cv2
12+
import re
13+
import json
14+
import _pickle as pickle
15+
from webvtt import WebVTT
16+
from config import my_config
17+
18+
19+
###############################################################################
20+
# SKELETON
21+
def draw_skeleton_on_image(img, skeleton, thickness=15):
    """Return a copy of *img* with the upper-body skeleton drawn on it.

    Args:
        img: BGR image (numpy array) to draw on; the input is not modified.
        skeleton: flat OpenPose keypoint list [x0, y0, c0, x1, y1, c1, ...];
            a falsy value (None / empty list) leaves the image untouched.
        thickness: line thickness in pixels.

    Returns:
        A new image with colored limb lines, or *img* itself when *skeleton*
        is empty.
    """
    if not skeleton:
        return img

    new_img = img.copy()
    for pair in SkeletonWrapper.skeleton_line_pairs:
        pt1 = (int(skeleton[pair[0] * 3]), int(skeleton[pair[0] * 3 + 1]))
        pt2 = (int(skeleton[pair[1] * 3]), int(skeleton[pair[1] * 3 + 1]))
        # OpenPose encodes an undetected keypoint as (0, 0); skip the line
        # when either endpoint is missing. (The original tested pt1's x
        # against pt2's y — an asymmetric check that also skipped valid
        # points lying exactly on the image edge.)
        if pt1 == (0, 0) or pt2 == (0, 0):
            continue
        # matplotlib returns RGB in [0, 1]; scale to [0, 255] and reverse
        # to BGR channel order for OpenCV.
        rgb = [v * 255 for v in matplotlib.colors.to_rgba(pair[2])][:3]
        cv2.line(new_img, pt1, pt2, color=rgb[::-1], thickness=thickness)

    return new_img
36+
37+
38+
def is_list_empty(my_list):
    """Return True when *my_list* is a list containing (recursively) no values.

    A non-list value is never considered empty; a list is empty when every
    element is itself an empty (possibly nested) list.
    """
    if not isinstance(my_list, list):
        return False
    return all(is_list_empty(element) for element in my_list)
40+
41+
42+
def get_closest_skeleton(frame, selected_body):
    """Find the person in *frame* whose upper-body pose is closest to *selected_body*.

    Compares the mean absolute difference over the x/y coordinates of the
    first eight joints and returns the matching person entry, or None when
    no candidate is close enough to the selected skeleton.
    """
    # x and y indices of the eight upper-body joints in the flat keypoint list
    upper_body_indices = [j * 3 for j in range(8)] + [j * 3 + 1 for j in range(8)]

    best_score = 10000000
    best_person = None
    for candidate in frame:  # each detected person in this frame
        keypoints = get_skeleton_from_frame(candidate)

        total = 0
        count = 0
        for idx in upper_body_indices:
            # compare only coordinates detected (> 0) in both skeletons
            if keypoints[idx] > 0 and selected_body[idx] > 0:
                total += abs(keypoints[idx] - selected_body[idx])
                count += 1
        if count > 0:
            mean_diff = total / count
            if mean_diff < best_score:
                best_score = mean_diff
                best_person = candidate

    # distance tolerance scaled from the selected skeleton's own proportions
    # (nose-to-neck height and shoulder-to-shoulder width)
    base_distance = max(abs(selected_body[0 * 3 + 1] - selected_body[1 * 3 + 1]) * 3,
                        abs(selected_body[2 * 3] - selected_body[5 * 3]) * 2)
    if best_person and best_score > base_distance:  # too far away: tracking failed
        best_person = None

    return best_person
69+
70+
71+
def get_skeleton_from_frame(frame):
    """Extract the flat keypoint list from one OpenPose person entry.

    Supports both the newer 'pose_keypoints_2d' key and the older
    'pose_keypoints' key (checked in that order); returns None when
    neither key is present.
    """
    for key in ('pose_keypoints_2d', 'pose_keypoints'):
        if key in frame:
            return frame[key]
    return None
78+
79+
80+
class SkeletonWrapper:
    """Loads and serves per-frame OpenPose skeleton data for one video.

    The per-frame JSON files are parsed once and cached as a pickle next to
    the JSON directory, so subsequent loads read a single file.
    """
    # color names: https://matplotlib.org/mpl_examples/color/named_colors.png
    visualization_line_pairs = [(0, 1, 'b'), (1, 2, 'darkred'), (2, 3, 'r'), (3, 4, 'gold'), (1, 5, 'darkgreen'), (5, 6, 'g'),
                                (6, 7, 'lightgreen'),
                                (1, 8, 'darkcyan'), (8, 9, 'c'), (9, 10, 'skyblue'), (1, 11, 'deeppink'), (11, 12, 'hotpink'), (12, 13, 'lightpink')]
    skeletons = []
    skeleton_line_pairs = [(0, 1, 'b'), (1, 2, 'darkred'), (2, 3, 'r'), (3, 4, 'gold'), (1, 5, 'darkgreen'),
                           (5, 6, 'g'), (6, 7, 'lightgreen')]

    def __init__(self, basepath, vid):
        # Prefer the cached pickle; otherwise parse the per-frame JSON files
        # and write the pickle for the next load.
        cached = glob.glob(basepath + '/' + vid + '.pickle')
        if cached:
            with open(cached[0], 'rb') as f:
                self.skeletons = pickle.load(f)
        else:
            json_files = glob.glob(basepath + '/' + vid + '/*.json')
            # presumably skips videos with too few frames to be useful
            if len(json_files) > 10:
                self.skeletons = [self.read_skeleton_json(path)
                                  for path in sorted(json_files)]
                with open(basepath + '/' + vid + '.pickle', 'wb') as f:
                    pickle.dump(self.skeletons, f)
            else:
                self.skeletons = []

    def read_skeleton_json(self, file):
        """Return the 'people' entries from one OpenPose JSON frame file."""
        with open(file) as json_file:
            return json.load(json_file)['people']

    def get(self, start_frame_no, end_frame_no, interval=1):
        """Return skeletons for frames [start_frame_no, end_frame_no).

        An *interval* greater than 1 subsamples every interval-th frame.
        Returns [] when the requested range contains no skeleton data.
        """
        chunk = self.skeletons[start_frame_no:end_frame_no]
        if is_list_empty(chunk):
            return []
        return chunk[::int(interval)] if interval > 1 else chunk
126+
127+
128+
###############################################################################
129+
# VIDEO
130+
def read_video(base_path, vid):
    """Open the .mp4 file for video id *vid* under *base_path*.

    Args:
        base_path: directory containing the downloaded videos.
        vid: video id; matched as '*<vid>.mp4'.

    Returns:
        A VideoWrapper around the matching file, or None when no file matches.

    Raises:
        AssertionError: when more than one file matches, which indicates an
            ambiguous dataset layout that must be resolved manually.
    """
    files = glob.glob(base_path + '/*' + vid + '.mp4')
    if len(files) == 0:
        return None
    if len(files) >= 2:
        # raise explicitly instead of `assert False` so the check survives
        # `python -O` and reports which files collided
        raise AssertionError('multiple video files match vid {}: {}'.format(vid, files))
    return VideoWrapper(files[0])
141+
142+
143+
class VideoWrapper:
    """Thin convenience wrapper around cv2.VideoCapture for one video file."""
    video = []

    def __init__(self, filepath):
        self.filepath = filepath
        reader = cv2.VideoCapture(filepath)
        self.video = reader
        # cache commonly used stream properties up front
        self.total_frames = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
        self.height = reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
        self.framerate = reader.get(cv2.CAP_PROP_FPS)

    def get_video_reader(self):
        """Return the underlying cv2.VideoCapture object."""
        return self.video

    def frame2second(self, frame_no):
        """Convert a frame index to a timestamp in seconds."""
        return frame_no / self.framerate

    def second2frame(self, second):
        """Convert a timestamp in seconds to the nearest frame index."""
        return int(round(second * self.framerate))

    def set_current_frame(self, cur_frame_no):
        """Seek the reader so the next read returns frame *cur_frame_no*."""
        self.video.set(cv2.CAP_PROP_POS_FRAMES, cur_frame_no)
164+
165+
166+
###############################################################################
167+
# CLIP
168+
def load_clip_data(vid):
    """Load the clip-division JSON for video id *vid*.

    Returns the parsed data, or None when no clip file exists for this video.
    """
    clip_path = "{}/{}.json".format(my_config.CLIP_PATH, vid)
    try:
        with open(clip_path) as data_file:
            return json.load(data_file)
    except FileNotFoundError:
        return None
175+
176+
177+
def load_clip_filtering_aux_info(vid):
    """Load the auxiliary clip-filtering info JSON for video id *vid*.

    Returns the parsed data, or None when the aux-info file is missing.
    """
    aux_path = "{}/{}_aux_info.json".format(my_config.CLIP_PATH, vid)
    try:
        with open(aux_path) as data_file:
            return json.load(data_file)
    except FileNotFoundError:
        return None
184+
185+
186+
#################################################################################
187+
#SUBTITLE
188+
class SubtitleWrapper:
    """Loads word-level subtitle data for one video.

    Two modes: 'auto' parses YouTube auto-generated VTT captions with inline
    per-word timing tags; 'gentle' loads gentle forced-alignment results.
    After construction, `self.subtitle` holds the word list, or None when
    loading failed.
    """
    # raw string: with a plain string, \d is an invalid escape sequence and
    # triggers a SyntaxWarning on modern Python (the pattern text is identical)
    TIMESTAMP_PATTERN = re.compile(r'(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})')

    def __init__(self, vid, mode):
        self.subtitle = []
        if mode == 'auto':
            self.load_auto_subtitle_data(vid)
        elif mode == 'gentle':
            self.load_gentle_subtitle(vid)

    def get(self):
        """Return the loaded subtitle word list (None when loading failed)."""
        return self.subtitle

    # using gentle lib
    def load_gentle_subtitle(self, vid):
        """Load word alignments produced by the gentle forced aligner.

        Keeps only words whose alignment 'case' is 'success'; sets
        self.subtitle to None when the result file is missing or has no
        'words' section. Returns the raw parsed data (or None implicitly
        on FileNotFoundError).
        """
        try:
            with open("{}/{}_align_results.json".format(my_config.VIDEO_PATH, vid)) as data_file:
                data = json.load(data_file)
                if 'words' in data:
                    for word in data['words']:
                        if word['case'] == 'success':
                            self.subtitle.append(word)
                else:
                    self.subtitle = None
                return data
        except FileNotFoundError:
            self.subtitle = None

    # backward-compatible alias for the original (misspelled) method name
    laod_gentle_subtitle = load_gentle_subtitle

    # using youtube automatic subtitle
    def load_auto_subtitle_data(self, vid):
        """Parse a YouTube auto-generated VTT caption file into per-word timings.

        Caption lines embed per-word tags like
        ``word1<00:00:01.234><c> word2</c>``; the line is split on '<c>' and
        each word gets 'start'/'end' seconds from those tags (the first and
        last word fall back to the caption chunk's own start/end times).
        """
        lang = my_config.LANG
        postfix_in_filename = '-' + lang + '-auto.vtt'
        file_list = glob.glob(my_config.SUBTITLE_PATH + '/*' + vid + postfix_in_filename)
        if len(file_list) > 1:
            print('more than one subtitle. check this.', file_list)
            self.subtitle = None
            assert False
        if len(file_list) == 1:
            for subtitle_chunk in WebVTT().read(file_list[0]):
                # split multi-line captions into lines. The original guarded
                # this with `if raw_subtitle.find('\n'):`, but str.find
                # returns -1 (truthy!) when absent, so the guard was a no-op;
                # splitting unconditionally is the intended behavior.
                raw_lines = str(subtitle_chunk.raw_text).split('\n')

                for line in raw_lines:
                    # only lines containing inline timestamps carry word timing
                    if self.TIMESTAMP_PATTERN.search(line) is None:
                        continue

                    # removes html tags and timing tags from caption text
                    line = line.replace("</c>", "")
                    line = re.sub(r"<c[.]\w+>", '', line)

                    word_list = []
                    chunk_start = subtitle_chunk.start_in_seconds
                    chunk_end = subtitle_chunk.end_in_seconds

                    word_chunk = line.split('<c>')

                    for i, word in enumerate(word_chunk):
                        word_info = {}

                        if i == len(word_chunk) - 1:
                            # last word has no trailing timestamp tag: it ends
                            # at the chunk's end. Fall back to the chunk start
                            # when it is also the only word (the original
                            # indexed word_list[i-1] and crashed on that case).
                            word_info['word'] = word
                            word_info['start'] = word_list[-1]['end'] if word_list else chunk_start
                            word_info['end'] = chunk_end
                            word_list.append(word_info)
                            break

                        # 'word<hh:mm:ss.mmm' -> the text and its end timestamp
                        word = word.split("<")
                        word_info['word'] = word[0]
                        word_info['end'] = self.get_seconds(word[1][:-1])
                        # each word starts when the previous one ends
                        word_info['start'] = chunk_start if i == 0 else word_list[i - 1]['end']
                        word_list.append(word_info)

                    self.subtitle.extend(word_list)
        else:
            print('subtitle file is not exist')
            self.subtitle = None

    # convert timestamp to second
    def get_seconds(self, word_time_e):
        """Convert an '[hh:]mm:ss.mmm' (or ',mmm') timestamp string to seconds.

        NOTE(review): a malformed timestamp prints a message and calls exit(),
        terminating the process — kept for compatibility, though raising an
        exception would be friendlier to callers.
        """
        time_value = re.match(self.TIMESTAMP_PATTERN, word_time_e)
        if not time_value:
            print('wrong time stamp pattern')
            exit()

        # a missing group (e.g. absent hours field) defaults to 0
        values = [int(g) if g else 0 for g in time_value.groups()]
        hours, minutes, seconds, milliseconds = values

        return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000

0 commit comments

Comments
 (0)