---
id: decision-tree
title: Decision Tree Algorithm
sidebar_label: Decision Tree
description: "In this post, we'll explore the Decision Tree Algorithm, a popular machine learning model used for classification and regression tasks."
tags: [machine learning, algorithms, decision tree, classification, regression]
---

### Definition:
The **Decision Tree Algorithm** is a supervised machine learning algorithm used for both **classification** and **regression** tasks. It splits data into subsets based on feature values, creating a tree-like structure where each internal node represents a decision based on a feature, and each leaf node represents an outcome (a class label or a continuous value).

### Characteristics:
- **Recursive Partitioning**:
A decision tree recursively splits the dataset on specific conditions (feature values) to form a tree structure that leads to a prediction at each leaf node.

- **Interpretability**:
Decision trees are highly interpretable because the decision-making process is clear and transparent. You can follow the path of decisions from root to leaf to understand how a prediction is made (see the `export_text` sketch after this list).

- **Non-parametric**:
The algorithm makes no assumptions about the underlying data distribution, which makes it suitable for a wide variety of datasets.

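As a small sketch of that interpretability, scikit-learn's `export_text` prints the root-to-leaf decision rules of a fitted tree as plain text; the iris dataset and `max_depth=2` below are just for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# Each indented block is one decision; each "class:" line is a leaf prediction.
print(export_text(clf, feature_names=list(iris.feature_names)))
```
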
### Types of Decision Trees:
1. **Classification Tree**:
Used when the target variable is categorical (e.g., yes/no, spam/ham). It predicts a class label based on the features.

2. **Regression Tree**:
Used when the target variable is continuous (e.g., predicting a price). It predicts a real number based on the features. Both types are illustrated in the short sketch after this list.

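Here, hedged and with made-up toy data, is how the two types look with scikit-learn's `DecisionTreeClassifier` and `DecisionTreeRegressor`:

```python
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Classification tree: categorical target (e.g., spam = 1, ham = 0).
X_cls = [[0, 1], [1, 0], [1, 1], [0, 0]]
y_cls = [1, 0, 1, 0]
print(DecisionTreeClassifier(random_state=0).fit(X_cls, y_cls).predict([[1, 1]]))

# Regression tree: continuous target (e.g., a price).
X_reg = [[1], [2], [3], [4]]
y_reg = [10.0, 20.0, 30.0, 40.0]
print(DecisionTreeRegressor(random_state=0).fit(X_reg, y_reg).predict([[2]]))
```
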
### Steps Involved:
1. **Choose the Best Split**:
The dataset is split on the feature that optimizes a splitting criterion. For classification, the **Gini Index** or **Entropy** (via information gain) is commonly used; for regression, **mean squared error** is the usual criterion.

2. **Recursively Split**:
The process is repeated recursively for each subset, creating branches until no further meaningful split can be made (e.g., the data at a node is perfectly classified, or a stopping criterion such as maximum depth is reached).

3. **Assign Leaf Nodes**:
Once splitting stops, each leaf node is assigned a class label (for classification) or a value (for regression). A minimal from-scratch sketch of these steps appears below.

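This sketch assumes numeric features and Gini impurity as the split criterion; the helper names (`gini`, `best_split`, `build_tree`) are illustrative, not a standard API:

```python
# Minimal decision-tree sketch: pick the best split, recurse, assign leaves.
import numpy as np

def gini(labels):
    """Gini impurity of a set of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Step 1: choose the (feature, threshold) pair with the lowest weighted Gini."""
    best, best_score = None, float("inf")
    n_samples, n_features = X.shape
    for feature in range(n_features):
        for threshold in np.unique(X[:, feature]):
            left = y[X[:, feature] <= threshold]
            right = y[X[:, feature] > threshold]
            if len(left) == 0 or len(right) == 0:
                continue
            score = (len(left) * gini(left) + len(right) * gini(right)) / n_samples
            if score < best_score:
                best_score, best = score, (feature, threshold)
    return best

def majority_class(y):
    values, counts = np.unique(y, return_counts=True)
    return values[np.argmax(counts)]

def build_tree(X, y, depth=0, max_depth=3):
    """Steps 2 and 3: split recursively, then assign a label at each leaf."""
    split = best_split(X, y)
    if len(np.unique(y)) == 1 or depth == max_depth or split is None:
        return {"leaf": majority_class(y)}          # step 3: leaf node
    feature, threshold = split
    mask = X[:, feature] <= threshold
    return {                                        # step 2: internal node
        "feature": feature,
        "threshold": threshold,
        "left": build_tree(X[mask], y[mask], depth + 1, max_depth),
        "right": build_tree(X[~mask], y[~mask], depth + 1, max_depth),
    }

# Tiny toy dataset: one numeric feature, classes separable at x <= 2.
X_toy = np.array([[1.0], [2.0], [3.0], [4.0]])
y_toy = np.array([0, 0, 1, 1])
print(build_tree(X_toy, y_toy))
```
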
### Problem Statement:
Given a dataset with features and target values, the goal is to build a decision tree that can **classify** data points into categories or **predict** continuous values based on the features.

### Key Concepts:
- **Root Node**:
The top-most node in the tree, where the first feature split occurs.

- **Internal Nodes**:
Nodes where further feature splits occur based on the conditions.

- **Leaf Nodes**:
Terminal nodes that hold the final prediction (class or value).

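As a hedged sketch of how these concepts surface in practice, scikit-learn exposes a fitted tree's structure through its `tree_` attribute: node 0 is the root, nodes whose children are `-1` are leaves, and the rest are internal nodes (the iris dataset and `max_depth=2` are used here purely for illustration):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

t = clf.tree_
for node in range(t.node_count):
    if t.children_left[node] == -1:                   # leaf node
        print(f"node {node}: leaf, class distribution = {t.value[node].ravel()}")
    else:                                             # root or internal node
        kind = "root" if node == 0 else "internal"
        print(f"node {node}: {kind}, splits on feature {t.feature[node]} "
              f"at threshold {t.threshold[node]:.2f}")
```
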
### Split Criteria:
- **Gini Impurity** (for classification):
Measures the probability of a randomly chosen element being misclassified.

![image](https://github.com/user-attachments/assets/505d8d21-3f41-4bba-9120-4ad98beec8c1)

Where $p_i$ is the proportion of data points belonging to class $i$ at the node.

- **Entropy** (for classification):
Measures the disorder or impurity in the dataset.

![image](https://github.com/user-attachments/assets/954fa9ab-f3d6-4055-8c12-7630dd8e6a50)

Where $p_i$ is the proportion of data points in class $i$.

- **Information Gain**:
Measures the reduction in entropy after the dataset is split on a feature.

![image](https://github.com/user-attachments/assets/e0eecef3-6e4d-414c-a963-19708922ed91)

- **Mean Squared Error (MSE)** (for regression):
Measures the average squared difference between the actual target values at a node and the node's predicted value (the mean of those targets).

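The sketch below computes these four criteria on toy data, assuming NumPy; the function names are illustrative rather than a standard API:

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum_i p_i^2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    """Entropy: -sum_i p_i * log2(p_i)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(parent, left, right):
    """Entropy of the parent minus the weighted entropy of its children."""
    n = len(parent)
    weighted = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - weighted

def node_mse(targets):
    """MSE of a regression node that predicts the mean of its targets."""
    return np.mean((targets - np.mean(targets)) ** 2)

labels = np.array(["Yes", "Yes", "Yes", "No"])
print(gini(labels))                                       # 0.375
print(entropy(labels))                                    # ~0.811
print(information_gain(labels, labels[:3], labels[3:]))   # 0.811 (pure split)
print(node_mse(np.array([10.0, 12.0, 14.0])))             # ~2.67
```
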
### Time Complexity:
- **Best, Average, and Worst Case: $O(n \log n)$**
The dominant cost is sorting and scanning feature values at each node. For `n` data points, a reasonably balanced tree has depth $O(\log n)$ and does $O(n)$ work per level, giving roughly $O(n \log n)$ when the number of features is treated as a constant; highly unbalanced trees can be more expensive.

### Space Complexity:
- **Space Complexity: $O(n)$**
The space complexity arises from storing the tree structure and the data used for training.

### Example:
Consider a dataset for predicting whether someone buys a product based on their **age** and **income**:

- Dataset:
  ```
  | Age | Income | Buys Product |
  |-----|--------|--------------|
  | 25  | High   | Yes          |
  | 45  | Medium | No           |
  | 35  | Low    | Yes          |
  | 22  | Low    | Yes          |
  ```

Step-by-Step Execution:

1. **Choose the best split**:
Using a criterion such as Gini Impurity or Information Gain, the algorithm decides which feature (Age or Income) provides the best split.

2. **Recursively split**:
It splits the data and continues until leaf nodes are formed, which represent the final decision (Yes/No for classification), as shown in the sketch below.

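Here is a hedged sketch of this example with scikit-learn; the ordinal encoding of Income (Low = 0, Medium = 1, High = 2) and the new data point are assumptions made purely for illustration:

```python
from sklearn.tree import DecisionTreeClassifier

# Encode the categorical Income column by hand for this toy example.
income_code = {"Low": 0, "Medium": 1, "High": 2}
X = [[25, income_code["High"]],
     [45, income_code["Medium"]],
     [35, income_code["Low"]],
     [22, income_code["Low"]]]
y = ["Yes", "No", "Yes", "Yes"]

clf = DecisionTreeClassifier(criterion="gini", random_state=0)
clf.fit(X, y)

# Predict for a hypothetical 30-year-old with medium income.
print(clf.predict([[30, income_code["Medium"]]]))
```
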
### Python Implementation:
Here is a basic implementation of a Decision Tree in Python using **scikit-learn**:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import matplotlib.pyplot as plt

# Load dataset
iris = load_iris()
X, y = iris.data, iris.target

# Create Decision Tree classifier
clf = DecisionTreeClassifier()

# Train model
clf = clf.fit(X, y)

# Plot the tree
plt.figure(figsize=(12, 8))
tree.plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.show()
```

### Summary:
The **Decision Tree Algorithm** is a simple yet powerful model that can be used for both classification and regression tasks. Its ease of interpretation, ability to handle both categorical and numerical data, and non-parametric nature make it a versatile choice for many machine learning problems. However, decision trees are prone to **overfitting**, which can be mitigated by techniques such as **pruning**, **random forests**, or other **ensemble methods**.

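As a hedged illustration of those mitigation techniques, the sketch below compares an unconstrained tree with a depth-limited tree, a cost-complexity-pruned tree (`ccp_alpha`), and a random forest using cross-validation on the iris dataset; the particular parameter values are arbitrary:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

models = {
    "unconstrained tree": DecisionTreeClassifier(random_state=0),
    "depth-limited tree": DecisionTreeClassifier(max_depth=3, random_state=0),
    "pruned tree": DecisionTreeClassifier(ccp_alpha=0.01, random_state=0),
    "random forest": RandomForestClassifier(n_estimators=100, random_state=0),
}
for name, model in models.items():
    scores = cross_val_score(model, X, y, cv=5)
    print(f"{name}: mean cross-validated accuracy = {scores.mean():.3f}")
```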