Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
b235fc7e63
56
mini_proj/traditionals.py
Normal file
56
mini_proj/traditionals.py
Normal file
@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
import time as t
|
||||
from sklearn import svm, ensemble, naive_bayes, neighbors
|
||||
from _image_classifier import ImageClassifier
|
||||
|
||||
def precision(y_true, y_pred):
|
||||
y_pred = np.round(y_pred)
|
||||
num = np.sum(np.logical_and(y_true, y_pred))
|
||||
den = np.sum(y_pred)
|
||||
return np.divide(num, den)
|
||||
|
||||
def recall(y_true, y_pred):
|
||||
y_pred = np.round(y_pred)
|
||||
num = np.sum(np.logical_and(y_true, y_pred))
|
||||
den = np.sum(y_true)
|
||||
return np.divide(num, den)
|
||||
|
||||
def f_measure(y_true, y_pred):
|
||||
p = precision(y_true, y_pred)
|
||||
r = recall(y_true, y_pred)
|
||||
return 2 * p * r / (p + r)
|
||||
|
||||
def metric_test(iclf, metric, test_X, test_Y):
|
||||
return metric(test_Y, iclf.predict(test_X))
|
||||
|
||||
## Open data
|
||||
im_train = np.load('Waldo_train_data.npy')
|
||||
im_test = np.load('Waldo_test_data.npy')
|
||||
|
||||
lbl_train = np.load('Waldo_train_lbl.npy')
|
||||
lbl_test = np.load('Waldo_test_lbl.npy')
|
||||
|
||||
# lbl_train = to_categorical(lbl_train) # One hot encoding the labels
|
||||
# lbl_test = to_categorical(lbl_test)
|
||||
|
||||
my_metric_test = lambda iclf, f: metric_test(iclf, f, im_test, lbl_test)
|
||||
|
||||
# ## Define model
|
||||
svm_iclf = ImageClassifier(svm.SVC)
|
||||
tree_iclf = ImageClassifier(neighbors.KNeighborsClassifier)
|
||||
naive_bayes_iclf = ImageClassifier(naive_bayes.GaussianNB)
|
||||
ensemble_iclf = ImageClassifier(ensemble.RandomForestClassifier)
|
||||
|
||||
classifiers = [
|
||||
svm_iclf,
|
||||
tree_iclf,
|
||||
naive_bayes_iclf,
|
||||
ensemble_iclf,
|
||||
]
|
||||
|
||||
for clf in classifiers:
|
||||
start = t.time() # Records time before training
|
||||
clf.fit(im_train, lbl_train)
|
||||
end = t.time() # Records time after tranining
|
||||
print("training time:", end-start)
|
||||
print(clf.score(im_test, lbl_test))
|
Reference in New Issue
Block a user