1
0
This repository has been archived on 2025-03-06. You can view files and clone it, but cannot push or open issues or pull requests.
ResearchMethods/mini_proj/_image_classifier.py
Kelvin Davis 76e3023750 Added traditional classifiers:
svm, decision tree, gaussian naive bayes, random forest.
2018-05-22 18:14:48 +10:00

51 lines
1.7 KiB
Python

class ImageClassifier:
    """Wrap a flat-feature classifier so it accepts image-shaped input.

    The wrapped classifier is built as ``clf(*args, **kwargs)`` and stored in
    ``self.clf``. ``fit``, ``predict`` and ``score`` flatten every sample of
    ``X`` into one feature row (``X.reshape((len(X), -1))``) before delegating
    to it; ``get_params`` and ``set_params`` are passed straight through.
    These are the 5 methods that are common amongst classifiers.

    Args:
        clf: Classifier *class* (or other callable) to instantiate.
        *args, **kwargs: Forwarded unchanged to ``clf`` on construction.
    """

    def __init__(self, clf, *args, **kwargs):
        self.clf = clf(*args, **kwargs)

    def __repr__(self):
        # Show the wrapped estimator so printed reports are informative
        # instead of the default "<ImageClassifier object at 0x...>".
        return '%s(%r)' % (type(self).__name__, self.clf)

    @staticmethod
    def _flatten(X):
        """Return ``X`` with each sample flattened into a single feature row.

        Inputs that already provide ``reshape`` are used as-is; plain
        sequences (e.g. nested lists) are first converted with
        ``numpy.asarray`` so they can be reshaped too.
        """
        if not hasattr(X, 'reshape'):
            import numpy as np  # local: only needed for plain sequences
            X = np.asarray(X)
        # -1 lets reshape infer the per-sample feature count, e.g. an
        # (h, w) image becomes h * w features.
        return X.reshape((len(X), -1))

    def fit(self, X, *args, **kwargs):
        """Flatten ``X`` and fit the wrapped classifier; return its result."""
        return self.clf.fit(self._flatten(X), *args, **kwargs)

    def predict(self, X, *args, **kwargs):
        """Flatten ``X`` and return the wrapped classifier's predictions."""
        return self.clf.predict(self._flatten(X), *args, **kwargs)

    def score(self, X, *args, **kwargs):
        """Flatten ``X`` and return the wrapped classifier's score."""
        return self.clf.score(self._flatten(X), *args, **kwargs)

    def get_params(self, *args, **kwargs):
        """Return the wrapped classifier's parameters."""
        return self.clf.get_params(*args, **kwargs)

    def set_params(self, **params):
        """Set parameters on the wrapped classifier and return its result."""
        # Must delegate to self.clf: calling self.set_params here would
        # recurse until RecursionError.
        return self.clf.set_params(**params)
if __name__ == '__main__':
    # Demo: train on the scikit-learn digits dataset, passing the images to
    # ImageClassifier without flattening them first.
    # Import datasets, classifiers and performance metrics
    from sklearn import datasets, svm, metrics

    # The digits dataset
    digits = datasets.load_digits()
    n_samples = len(digits.images)
    # Index where the training half ends and the test half begins.
    half = n_samples // 2
    # Pass the images as-is; ImageClassifier flattens each sample itself.
    data = digits.images

    # Create a classifier: a support vector classifier
    classifier = ImageClassifier(svm.SVC, gamma=0.001)

    # We learn the digits on the first half of the digits
    classifier.fit(data[:half], digits.target[:half])

    # Now predict the value of the digit on the second half:
    expected = digits.target[half:]
    predicted = classifier.predict(data[half:])

    # Print the wrapped estimator (with its hyperparameters) rather than the
    # wrapper object, whose default repr is just a memory address.
    print("Classification report for classifier %s:\n%s\n"
          % (classifier.clf, metrics.classification_report(expected, predicted)))
    print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))