Skip to content

Commit

Permalink
get feature importance function for decision tree
Browse files Browse the repository at this point in the history
  • Loading branch information
tigist13 committed Sep 3, 2022
1 parent d75ba76 commit 49988c8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
47 changes: 46 additions & 1 deletion scripts/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,49 @@ def train_model(self, folds=1):
loss_arr.append(loss)


return self.clf, accuracy_arr, loss_arr
return self.clf, accuracy_arr, loss_arr

def test_model(self):

y_pred = self.clf.predict(self.X_test)

accuracy = self.calculate_score(y_pred, self.y_test)
self.__printAccuracy(accuracy, label="Test")

report = self.report(y_pred, self.y_test)
matrix = self.confusion_matrix(y_pred, self.y_test)

loss = calculate_loss_function(self.y_test, y_pred)

return accuracy, loss, report, matrix

def __printLoss(self, loss, step=1, label=""):
print(f"step {step}: {label} Loss of DecisionTreesModel is: {loss:.3f}")

def calculate_score(self, pred, actual):
return metrics.accuracy_score(actual, pred)

def __printAccuracy(self, acc, step=1, label=""):
print(f"step {step}: {label} Accuracy of DecisionTreesModel is: {acc:.3f}")

def report_outcome(self, pred, actual):
print("Test Metrics")
print("================")
print(metrics.classification_report(pred, actual))
return metrics.classification_report(pred, actual)

def get_feature_importance(self):
importance = self.clf.feature_importances_
featureimportance_df = pd.DataFrame()

featureimportance_df['feature'] = self.X_train.columns.to_list()
featureimportance_df['feature_importances'] = importance

return featureimportance_df

def confusion_matrix(self, pred, actual):
ax=sns.heatmap(pd.DataFrame(metrics.confusion_matrix(pred, actual)))
plt.title('Confusion Matrix')
plt.ylabel('Actual')
plt.xlabel('Predicted')
return metrics.confusion_matrix(pred, actual)
2 changes: 1 addition & 1 deletion scripts/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __printAccuracy(self, acc, step=1, label=""):
print(f"step {step}: {label} Accuracy of LogesticRegression: {acc:.3f}")


def report_output(self, pred, actual):
def report_outcome(self, pred, actual):
print("Test Metrics")
print("================")
print(metrics.classification_report(pred, actual))
Expand Down

0 comments on commit 49988c8

Please sign in to comment.