130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
'''
|
|
Created by Tony Silvestre to prepare images from a Kaggle Where's Waldo dataset for use in training and testing.
|
|
'''
|
|
import os
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
import math
|
|
import cv2
|
|
from skimage import color, exposure, transform
|
|
import skimage as sk
|
|
import random
|
|
|
|
def preprocess(img):
    """Defines some basic preprocessing for the images.

    Equalises the brightness of the image: it is converted to HSV, the
    histogram of the value (V) channel is equalised, and the result is
    converted back.

    Args:
        img: colour image with the channel axis last (gen_data passes the
            array returned by cv2.imread, whose channel order is BGR).

    Returns:
        The image converted back from HSV with the equalised V channel.
    """
    hsv_img = color.rgb2hsv(img)
    brightness = hsv_img[:, :, 2]
    hsv_img[:, :, 2] = exposure.equalize_hist(brightness)
    return color.hsv2rgb(hsv_img)
|
|
|
|
def augment(img, *, max_angle=25, flip_prob=0.25, invert_prob=0.25):
    """Return a randomly augmented copy of an image to expand the dataset.

    Code was adapted from Medium.com/@thimblot (Thomas Himblot).

    The transformations are applied in this order:

    1. rotation by a random angle drawn uniformly from
       [-max_angle, max_angle] degrees (sk.transform.rotate);
    2. random noise (sk.util.random_noise with its default settings);
    3. with probability flip_prob, reversal of axis 1, which is the
       horizontal axis for an (H, W, C) image -- NOTE(review): callers are
       assumed to pass channels-last images, as gen_data does;
    4. with probability invert_prob, inversion of the intensity range
       (sk.util.invert).

    Args:
        img: image array to augment. It is not modified; rotate() builds a
            new array.
        max_angle: largest absolute rotation in degrees (default 25).
        flip_prob: chance of the horizontal flip (default 0.25).
        invert_prob: chance of the intensity inversion (default 0.25).

    Returns:
        The augmented image.
    """
    # Random rotation between max_angle degrees to the left and to the right
    random_degree = random.uniform(-max_angle, max_angle)
    img = sk.transform.rotate(img, random_degree)

    # Add random noise to the image
    img = sk.util.random_noise(img)

    # random.random() < p is true with probability exactly p. The previous
    # `randint(0, 4) == 1` test has inclusive bounds (5 outcomes), so it only
    # fired 20% of the time instead of the documented 25%.
    if random.random() < flip_prob:
        img = img[:, ::-1]

    if random.random() < invert_prob:
        img = sk.util.invert(img)

    return img
|
|
|
|
def _load_image(path):
    """Read the image at *path* with cv2 and run preprocess() on it.

    cv2.imread returns None (rather than raising) for files it cannot
    decode, e.g. stray non-image files in a dataset folder. Those files are
    skipped with a warning and None is returned.
    """
    pic = cv2.imread(path)  # NOTE: cv2.imread() returns a numpy array in BGR not RGB
    if pic is None:
        print('WARNING: skipping unreadable file {0}'.format(path))
        return None
    return preprocess(pic)


def _group_split(groups, train_fraction):
    """Randomly split sample indices into train/test without splitting groups.

    groups[i] is the integer source id (0 .. n_groups - 1) of sample i.
    int(train_fraction * n_groups) randomly chosen groups go to training and
    the rest to testing, so all samples that share a source end up on the
    same side. Both returned index lists are in random order.
    """
    n_groups = max(groups, default=-1) + 1
    n_train_groups = int(train_fraction * n_groups)
    in_train = np.zeros(n_groups, dtype=bool)
    in_train[np.random.permutation(n_groups)[:n_train_groups]] = True
    order = np.random.permutation(len(groups))
    train_indices = [i for i in order if in_train[groups[i]]]
    test_indices = [i for i in order if not in_train[groups[i]]]
    return train_indices, test_indices


def gen_data(w_path, n_w_path, *, n_augment=10, non_waldo_per_waldo=10,
             train_fraction=0.75, out_dir='.'):
    """Build shuffled train/test sets from Waldo / non-Waldo image folders.

    Every readable image in *w_path* contributes its preprocessed self plus
    *n_augment* augmented copies, all labelled 1. Images from *n_w_path* are
    labelled 0. Their number is capped at *non_waldo_per_waldo* times the
    number of readable Waldo images. Each stored image has its colour axis
    moved to axis 0 and keeps cv2's BGR channel order.

    The split is made per source image rather than per sample, so an image
    and its augmented copies never appear in both the training and the test
    set.

    Args:
        w_path: directory of images containing Waldo.
        n_w_path: directory of images without Waldo.
        n_augment: augmented copies generated per Waldo image (default 10).
        non_waldo_per_waldo: cap on non-Waldo images per readable Waldo
            image (default 10).
        train_fraction: fraction of source images assigned to the training
            set (default 0.75).
        out_dir: directory receiving the output files (default: the current
            directory).

    Side effects:
        Prints progress. Writes Waldo_train_data.npy, Waldo_train_lbl.npy,
        Waldo_test_data.npy and Waldo_test_lbl.npy into *out_dir*. A save
        failure caused by OSError or ValueError is reported on stdout and is
        not raised.
    """
    # Sorted listings make a run reproducible for a given NumPy/random seed
    waldo_file_list = sorted(os.listdir(w_path))
    total_w = len(waldo_file_list)
    not_waldo_file_list = sorted(os.listdir(n_w_path))
    total_nw = len(not_waldo_file_list)

    imgs_raw = []  # Images, colour axis first
    imgs_lbl = []  # Image labels: 1 = Waldo present, 0 = no Waldo
    imgs_grp = []  # Source-image id; an image and its augmented copies share one id
    n_groups = 0

    for w, image_name in enumerate(waldo_file_list, start=1):
        pic = _load_image(os.path.join(w_path, image_name))
        if pic is not None:
            # Each copy augments the preprocessed original rather than the
            # previous copy, so rotations and noise do not compound.
            variants = [pic] + [augment(pic) for _ in range(n_augment)]
            for variant in variants:
                imgs_raw.append(np.rollaxis(variant, -1))  # rolls colour axis to 0
                imgs_lbl.append(1)  # Waldo is present in the original and every transformed copy
                imgs_grp.append(n_groups)
            n_groups += 1
        print('Completed: {0}/{1} Waldo images'.format(w, total_w))

    # Keep the classes from becoming too unbalanced: at most
    # non_waldo_per_waldo non-Waldo images per readable Waldo image.
    n_waldo = n_groups
    max_nw = non_waldo_per_waldo * n_waldo
    nw = 0
    for i, image_name in enumerate(not_waldo_file_list, start=1):
        if nw >= max_nw:
            print("Non_Waldo files restricted")
            break
        pic = _load_image(os.path.join(n_w_path, image_name))
        if pic is not None:
            imgs_raw.append(np.rollaxis(pic, -1))  # rolls colour axis to 0
            imgs_lbl.append(0)
            imgs_grp.append(n_groups)
            n_groups += 1
            nw += 1
        print('Completed: {0}/{1} non-Waldo images'.format(i, total_nw))

    ## Randomise and split data into training and test sets
    # Code was modified from code written by: Kyle O'Brien (medium.com/@kylepob61392)
    train_indices, test_indices = _group_split(imgs_grp, train_fraction)

    outputs = (
        ('Waldo_train_data.npy', [imgs_raw[i] for i in train_indices]),
        ('Waldo_train_lbl.npy', [imgs_lbl[i] for i in train_indices]),
        ('Waldo_test_data.npy', [imgs_raw[i] for i in test_indices]),
        ('Waldo_test_lbl.npy', [imgs_lbl[i] for i in test_indices]),
    )

    try:
        # Save the data as numpy files
        for file_name, data in outputs:
            np.save(os.path.join(out_dir, file_name), data)
        print("All data saved")
    except (OSError, ValueError) as err:
        # OSError: disk/permission problems. ValueError: e.g. images of
        # differing shapes that NumPy cannot stack into one array.
        print("ERROR: Data may not be completely saved ({0})".format(err))
|
|
|
|
|
|
if __name__ == "__main__":
    # Input folders of the 64x64 crops: one with Waldo, one without
    waldo_path, n_waldo_path = 'waldo_data/64/waldo', 'waldo_data/64/notwaldo'
    gen_data(w_path=waldo_path, n_w_path=n_waldo_path)
|