65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
'''
Created by Tony Silvestre to prepare images from a Kaggle "Where's Waldo"
dataset for use as training and test data.
'''
|
|
import os
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
import math
|
|
import cv2
|
|
|
|
def _load_folder(folder, description, loader):
    """Load every readable image in *folder*, in sorted file-name order.

    Sorting makes the later train/test split reproducible, because
    ``os.listdir`` returns names in arbitrary order. Files for which *loader*
    returns ``None`` (``cv2.imread`` does this for files it cannot decode)
    are reported and skipped instead of being stored as ``None``.

    Args:
        folder: directory to read.
        description: class name used in progress messages ('Waldo', ...).
        loader: callable mapping a file path to an image array or ``None``.

    Returns:
        List of the successfully loaded images.
    """
    names = sorted(os.listdir(folder))
    total = len(names)
    images = []
    for count, name in enumerate(names, start=1):
        file_path = os.path.join(folder, name)
        pic = loader(file_path)
        if pic is None:
            print('WARNING: could not read {0}, skipping it'.format(file_path))
        else:
            images.append(pic)
        print('Completed: {0}/{1} {2} images'.format(count, total, description))
    return images


def gen_data(w_path, n_w_path, test_frac=0.3, out_dir='.', loader=None):
    """Build train/test splits from Waldo / non-Waldo image folders and save them.

    For each class, the first ``floor(test_frac * n)`` loaded images (in sorted
    file-name order) form the test set and every remaining image forms the
    training set, so with the default ``test_frac`` the split is 70% / 30%.
    Every loaded image ends up in exactly one of the two sets. Waldo images
    are labelled 1, non-Waldo images 0; within each set Waldo images come
    first. Four files are written to *out_dir*: ``Waldo_train_data.npy``,
    ``Waldo_train_lbl.npy``, ``Waldo_test_data.npy`` and ``Waldo_test_lbl.npy``.

    Args:
        w_path: folder containing images with Waldo.
        n_w_path: folder containing images without Waldo.
        test_frac: fraction of each class held out for testing, in [0, 1].
        out_dir: folder the ``.npy`` files are written to (default: cwd).
        loader: callable mapping a file path to an image array, or ``None``
            for a failed read. Defaults to ``cv2.imread``, which returns
            arrays in BGR (not RGB) channel order.

    Raises:
        ValueError: if *test_frac* is outside [0, 1].
    """
    if not 0 <= test_frac <= 1:
        raise ValueError('test_frac must be in [0, 1], got {0!r}'.format(test_frac))
    if loader is None:
        loader = cv2.imread

    waldo_imgs = _load_folder(w_path, 'Waldo', loader)
    not_waldo_imgs = _load_folder(n_w_path, 'non-Waldo', loader)

    # Size of each class's test share, based on images actually loaded.
    n_test_w = math.floor(test_frac * len(waldo_imgs))
    n_test_nw = math.floor(test_frac * len(not_waldo_imgs))

    # Complementary slices ([:k] / [k:]) so no image is dropped or duplicated.
    test_imgs = waldo_imgs[:n_test_w] + not_waldo_imgs[:n_test_nw]
    test_lbls = [1] * n_test_w + [0] * n_test_nw
    train_imgs = waldo_imgs[n_test_w:] + not_waldo_imgs[n_test_nw:]
    train_lbls = ([1] * (len(waldo_imgs) - n_test_w)
                  + [0] * (len(not_waldo_imgs) - n_test_nw))

    outputs = [
        ('Waldo_train_data.npy', np.array(train_imgs)),
        # dtype=int keeps labels integral even when a split is empty.
        ('Waldo_train_lbl.npy', np.array(train_lbls, dtype=int)),
        ('Waldo_test_data.npy', np.array(test_imgs)),
        ('Waldo_test_lbl.npy', np.array(test_lbls, dtype=int)),
    ]

    try:
        # Save the data as numpy files
        for filename, array in outputs:
            np.save(os.path.join(out_dir, filename), array)
        print("All data saved")
    except OSError as err:
        # Best-effort as before, but report why saving failed.
        print("ERROR: Data may not be completely saved ({0})".format(err))
|
|
|
|
|
if __name__ == "__main__":
    # Source folders for the two classes: images with Waldo, images without.
    waldo_path, n_waldo_path = 'waldo_data/64/waldo', 'waldo_data/64/notwaldo'
    gen_data(w_path=waldo_path, n_w_path=n_waldo_path)
|