1
0
This commit is contained in:
Kelvin Davis 2018-05-23 23:17:55 +10:00
commit 16b9365a4d
2 changed files with 22 additions and 5 deletions

View File

@ -6,6 +6,21 @@ import numpy as np
from matplotlib import pyplot as plt
import math
import cv2
from skimage import color, exposure, transform
'''
Defines some basic preprocessing for the images
'''
def preprocess(img):
# Histogram normalization in v channel
hsv = color.rgb2hsv(img)
hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
img = color.hsv2rgb(hsv)
img = img/255 # Scaling images down to values of 0-255
img = np.rollaxis(img, -1) # rolls colour axis to 0
return img
def gen_data(w_path, n_w_path):
waldo_file_list = os.listdir(os.path.join(w_path))
@ -15,12 +30,12 @@ def gen_data(w_path, n_w_path):
imgs_raw = [] # Images
imgs_lbl = [] # Image labels
#imgs_raw = np.array([np.array(imread(wdir + "waldo/"+fname)) for fname in os.listdir(wdir + "waldo")])
i = 0
for image_name in waldo_file_list:
pic = cv2.imread(os.path.join(w_path, image_name)) # NOTE: cv2.imread() returns a numpy array in BGR not RGB
pic = pic/255 # Scaling images down to values of 0-255
imgs_raw.append(np.rollaxis(pic, -1)) # rolls colour axis to 0
pic = preprocess(pic)
imgs_raw.append(pic)
imgs_lbl.append(1) # Value of 1 as Waldo is present in the image
print('Completed: {0}/{1} Waldo images'.format(i+1, total_w))
@ -29,8 +44,9 @@ def gen_data(w_path, n_w_path):
i = 0
for image_name in not_waldo_file_list:
pic = cv2.imread(os.path.join(n_w_path, image_name))
pic = pic/255 # Scaling images down to values of 0-255
imgs_raw.append(np.rollaxis(pic, -1))
pic = preprocess(pic)
imgs_raw.append(pic)
imgs_lbl.append(0)
print('Completed: {0}/{1} non-Waldo images'.format(i+1, total_nw))

View File

@ -8,6 +8,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from keras.layers import Dense, Dropout, Activation, Flatten, Input
from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.models import Model
from keras import metrics
from sklearn import svm, tree, naive_bayes, ensemble
from sklearn.metrics import accuracy_score