diff --git a/mini_proj/waldo_model.py b/mini_proj/waldo_model.py index 08160e0..c29b8c4 100644 --- a/mini_proj/waldo_model.py +++ b/mini_proj/waldo_model.py @@ -38,12 +38,12 @@ def FCN(): conv3 = Conv2D(32, (3, 3), activation='relu', padding='same')(m_pool2) # drop2 = Dropout(0.2)(conv3) # Drop some portion of features to prevent overfitting # m_pool2 = MaxPooling2D(pool_size=(2, 2))(drop2) - + # conv4 = Conv2D(64, (2, 2), activation='relu', padding='same')(m_pool2) flat = Flatten()(conv3) # Makes data 1D dense = Dense(64, activation='relu')(flat) # Fully connected layer - drop3 = Dropout(0.2)(dense) + drop3 = Dropout(0.2)(dense) classif = Dense(2, activation='sigmoid')(drop3) # Final layer to classify ## Define the model structure @@ -53,6 +53,22 @@ def FCN(): return model +def precision(y_true, y_pred): + y_pred = np.round(y_pred) + num = np.sum(np.logical_and(y_true, y_pred)) + den = np.sum(y_pred) + return np.divide(num, den) + +def recall(y_true, y_pred): + y_pred = np.round(y_pred) + num = np.sum(np.logical_and(y_true, y_pred)) + den = np.sum(y_true) + return np.divide(num, den) + +def f_measure(y_true, y_pred): + p = precision(y_true, y_pred) + r = recall(y_true, y_pred) + return 2 * p * r / (p + r) ## Open data im_train = np.load('Waldo_train_data.npy') @@ -71,7 +87,7 @@ model = FCN() # ensemble_iclf = ImageClassifier(ensemble.RandomForestClassifier) ## Define training parameters -epochs = 20 # an epoch is one forward pass and back propogation of all training data +epochs = 20 # an epoch is one forward pass and back propogation of all training data batch_size = 150 # batch size - number of training example used in one forward/backward pass # (higher batch size uses more memory, smaller batch size takes more time) #lrate = 0.01 # Learning rate of the model - controls magnitude of weight changes in training the NN @@ -131,4 +147,3 @@ accuracy = accuracy_score(lbl_test, pred_lbl) print("Accuracy: " + str(accuracy)) print("Images generated in {} seconds".format(end - start)) np.save('predicted_results.npy', pred_lbl) -