Machine Learning X g boost

XGBoost - Bài 3: Xây dựng XGBoost model

XGBoost - Bài 3: Xây dựng XGBoost model

XGBoost là một thuật toán rất mạnh mẽ, tối ưu hóa về tốc độ và hiệu năng cho việc xây dựng các mô hình dự đoán. Một thống kê chỉ ra rằng, hầu hết những người chiến thắng trong các cuộc thi trên Kaggle đều sử dụng thuật toán này. Trong bài viết này, hãy cùng nhau xây dựng một mô hình XGBoost đơn giản để có thể hiểu được cách thức làm việc của nó.

Nội dung bài viết chia thành các phần:

  • Cài đặt thư viện XGBoost
  • Chuẩn bị dữ liệu
  • Train XGBoost model
  • Đánh giá XGBoost model
  • Nguồn tham khảo

1. Cài đặt thư viện XGBoost

Có 2 cách để cài đặt thư viện XGBoost. Sử dụng pip hoặc biên dịch từ mã nguồn:

1.1 Sử dụng pip để cài đặt:

pip install XGBoost

Để cập nhật thư viện, sử dụng lệnh sau:

pip install --upgrade XGBoost

1.2 Biên dịch từ mã nguồn

Sử dụng cách này nếu muốn cài đặt phiên bản mới nhất của XGBoost.

git clone --recursive https://github.com/dmlc/XGBoost
cd XGBoost
cp make/minimum.mk ./config.mk
make -j8
cd python-package
sudo python setup.py install

Tại thời điểm viết bài, phiên bản của XGBoost là 1.2

  1. Chuẩn bị dữ liệu

Trong bài viết này, chúng ta sẽ sử dụng dataset về bênh tiểu đường của Ấn Độ. Dataset bao gồm 8 features, miêu tả chi tiết tình trạng của mỗi bệnh nhân và một feature tương ứng chỉ ra bênh nhân có bị tiểu đường hay không. Chi tiết về dataset này, bạn có thể tham khảo trên UCI Machine Learning Repository website

Đây là một dataset khá đơn giản bởi vì tất cả các features của nó đều đã ở dạng số và vấn đề chỉ là “binary classification”.

6,148,72,35,0,33.6,0.627,50,1
1,85,66,29,0,26.6,0.351,31,0
8,183,64,0,0,23.3,0.672,32,1
1,89,66,23,94,28.1,0.167,21,0
0,137,40,35,168,43.1,2.288,33,1

Tải dataset và đặt nó trong thư mục làm việc hiện tại của bạn với tên là pima-indians-diabetes.csv.

Tiếp theo, load dataset từ file vừa tải về để chuẩn bị cho trainingevaluating XGBoost model.

  • Import các thư viện sử dụng:
from numpy import loadtxt
from XGBoost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
  • Load csv file
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=",")
  • Chia dataset thành dữ liệu input (X) và output (Y)
X = dataset[:, 0:8]
y = dataset[:, 8]
  • Chia X và y thành data trainingdata testing

Training data được sử dụng để train XGBoost model, trong khi testing data được sử dụng để đánh giá độ chính xác của model đó. Để làm điều này, ta có thể sử dụng hàm train_test_split() trong thư viện scikit-learn.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_seed=42)

Đến đây, dữ liệu đã được chuẩn bị sẵn sàng cho việc train XGBoost model.

2. train XGBoost model

Thư viện XGBoost cung cấp một “Wrapper class” cho phép sử dụng XGBoost model tương tự như như làm việc với thư viện scikit-learn. XGBoost model trong thư viện XGBoost là XGBClassifier.

Tạo XGBoost model và thực hiện train:

model = XGBClassifier()
model.fit(X_train, y_train)

Ở đây, chúng ta đang sử dụng giá trị mặc định của các tham số. Mình sẽ có các bài việc về việc *tuning papameters" cho XGBoost model, mời các bạn đón đọc.

Bạn có thể quan sát các tham số sử dụng trong model bằng lệnh sau:

print(model)

3. Đánh giá XGBoost model

Để sử dụng model đã train để dự đoán, sử dụng hàm model.predict():

predictions = model.predict(X_test)

Ta có thể đánh giá độ chính xác của model bằng cách so sánh kết quả dự đoán của model với kêt quả thực tế. Hàm accuracy_score() giúp chúng ta thực hiện việc này:

accuracy = accuracy_score(y_test, predictions)
print('Accuracy: %.2f%%' % (accuracy*100))

Kết quả cuối cùng:

Accuracy: 77.95%

Kết quả khá tốt đối với bài toán này.

5. Tổng kết

Trong bài viết này, chúng ta đã xây dựng XGBoost model sử dụng thư viện XGBoost. Cụ thể, chúng ta đã học:

  • Cách cài đặt thư viện XGBoost
  • Chuẩn bị dữ liệu train model
  • Đánh giá model

Trong bài tiếp theo, chúng ta sẽ bàn luận về một số phương pháp chuẩn bị dữ liệu train cho XGBoost model.

Toàn bộ source code của bài này các bạn có thể tham khảo trên github cá nhân của mình tại github.

Bài viết có tham khảo tại tham khảo