Hẳn các bạn đã biết, trong hầu hết các bài toán AI, chúng ta không chỉ train model 1 lần rồi thôi (mình không nói đến việc thử-sai trong quá trình tuning model). Tại thời điểm này, model hoạt động tốt đúng như những gì ta mong đợi, nhưng sau một thời gian, hiệu năng của model có thể giảm xuống. Đó là một trong những dấu hiệu chỉ ra rằng ta phải retrain lại model. Trong bài hôm nay, mình sẽ cùng các bạn tìm hiểu chi tiết hơn về vấn đề này.
1. Model Drift
Model Drift là khái niệm mô tả hiện tượng hiệu năng dự đoán của model suy giảm theo thời gian do có sự thay đổi của môi trường làm sai lệch các giả thiết ban đầu của model. Thuật ngữ Model Drift
(model chuyển dịch) có thể khiến chúng ta hơi bối rối 1 chút, vì bản chất là model không thay đổi, chỉ có các yếu tố môi trường bên ngoài thay đổi, input data thay đổi.
2. Làm sao để nhận biệt hiện tượng Model Drift
2.1 Kiểm tra độ chính xác của model
Biểu hiên trực tiếp và rõ ràng nhất của Model Drift là độ chính xác dự đoán (độ chính xác ở đây dùng chung cho tất cả các metrics đánh giá model) giảm dần theo thời gian. Nhưng việc giám sát việc này không phải lúc nào cũng đơn giản bởi vì ta phải có cả kết quả dự đoán của model và ground truth
, đặc biệt khi model đang chạy trong sản phầm thực tế (môi trường production hay online).
Có một cách đơn giản hơn để kiểm tra độ chính xác của model có bị suy giảm hay không, đó là offline monitor
. Cách này được thực hiện trước khi model triển khai model vào môi trường production. Giả sử ra có dữ liệu từ 01/2019 đến 01/2021. Ta sẽ sử dụng dữ liệu từ 01/2019 đến 06/2020 đê train và đánh giá model, sau đó sử dụng model này để dự đoán trên dữ liệu tháng 07/2020 đến 01/2021. Kết quả dự đoán được lưu lại để đánh giá xem độ chính xác của model có suy giảm hay không, nếu có thì mức độ suy giảm như thế nào? … Sử dụng cách này cho phép chúng ta ước lượng được tốc độ suy giảm độ chính xác, từ đó lên kế hoạch retrain lại model.
2.2 Kiểm tra phân bố của dữ liệu
Nếu phân bố của dữ liệu mới có sự sai khác so với dữ liệu huấn luyện model từ ban đầu thì độ chính xác của model cũng sẽ giảm. Vì thế, đây cũng là một dấu hiệu nhận biết sớm của hiện tượng Model Drift.
Để đánh giá sự phân bố của dữ liệu, có thể dựa vào các yếu tố sau:
Facets là một công cụ cho phép chúng ta nhanh chóng nhận ra sự thay đổi trong phân bố dữ liệu dựa trên sự quan sát các đồ thị phân bố trên dashboards. Việc theo dõi này có thể được thực hiện một cách tự động và nó sẽ gửi thống báo cho chúng ta khi sự phân bố dữ liệu thay đổi vượt quá một ngưỡng nào đó.
2.3 Kiểm tra sự tương quan giữa các features trong dữ liệu
Mối qua hệ giữa các features cũng ảnh hướng đến độ chính xác của model. Vì vậy, kiểm tra sự tương quan giữa các features từng đôi một xem chúng thay đổi ra sao cũng là một cách để nhận biết Model Drift.
3. Hiểu đúng về Model Retraining
Chúng ta đều hiểu rằng Model Retraining tức là training lại model, tạo ra model mới tốt hơn model cũ. Nhưng nếu chỉ chung chung như thế thì có rất nhiều cách để retraining model:
Giữa những cách retraining model kể trên, đâu là cách đúng nhất để loại bỏ hiện tượng Model Drift?
Quay lại khái niệm của Model Drift, đó là hiện tượng độ chính xác của model suy giảm do có sự thay đổi trong phân phối dữ liệu
. Vậy ta chỉ cần train lại model trên tập dữ liệu mới và giữ nguyên tất cả những cái khác: hyper-parameters, thuật toán, features, … Hiểu một cách đơn giản hơn thì tức là ta sẽ không thay đổi dòng code nào cả, chỉ thay đổi nội dung của file chứa dữ liệu mới để train model.
Nói vậy, không có nghĩa là chúng ta bỏ qua hoàn toàn các cách retraing model khác. Nếu bạn có đủ thời gian, công sức, bạn hay các thành viên trong dự án của bạn hoàn toàn có thể thử nghiệm cách retraining model kể trên. Sau đó sử dụng chiến lược A/B Test để đánh giá các models dựa trên các tiêu chí của bài toán. Model nào cho cho kết quả tốt hơn thì sẽ được sử dụng trong môi trường production.
4. Tần suất Retrain Model
Một vấn đề tiếp theo cần quan tâm là tần suất retrain model như thế nào là hợp lý?
Câu trả lời là không có một quy định, quy tắc cụ thể nào cả. Tùy từng bài toán mà ta có cách xử lý khác nhau.
Đối với cách thứ 2&3, cần phải có một hạ tầng độc để giám sát và đưa ra cảnh báo khi sự thay đổi đạt đến mức quy định. Việc chọn ngưỡng cho các metrics cũng cần phải xem xét cẩn thận. Ngưỡng quá thấp sẽ làm cho tần suất retrain model thường xuyên hơn, dẫn đến tốn kém chi phí tính toán (đặc biệt quan trong trường trường hợp sử dụng tài nguyên trên cloud). Ngưỡng quá cao làm cho model không thay đổi kịp với sự thay đổi của môi trường, dẫn đến không tối ưu hóa lợi nhuận, …
Đặc biệt trong trường hợp model cần thay đổi realtime mỗi khi có bất cứ dữ liệu mới (VD model dự đoán giao dịch ngân hàng an toàn hay không an toàn) thì nên sử dụng phương pháp học tăng dần, Incremental Learning / Online Learning. Phương pháp này khác các cách Retrain Model đã đề cập ở chỗ model được retrain (cập nhật) chỉ sử dụng dữ liệu mới, không phải retrain trên toàn bộ dữ liệu.
5. Chạy Retrain Model tự động
Cách cấu hình để Retrain Model tự động liên quan đến tần suất retrain model của bạn.
Nếu model được retrained định kỳ, chúng ta có thể sử dụng Kubernetes CronJobs hoặc Jenkins để lập lịch cho model chạy retrain.
Nếu model đươc retrained dựa vào trigger khi các metrics thay đổi đến ngưỡng được phát hiện, chúng ta có thể sử dụng Kubernetes Jobs hoặc Jenkins để làm việc này.
Cuối cùng, nếu model cần retrain realtime, sử dụng phương pháp Online Learning
. River là thư viện lý tưởng cho việc này. Tên cũ của nó là Creme.
6. Implement code prototype
6.1 Query Data by Date Range function
Bởi vì quá trình retraining dựa trên dữ liệu mới, nên chúng ta cần 1 hàm lấy ra những dữ liệu đó, theo 1 khoảng thời gian quy đinh. Dữ liệu mới có thể được lưu ở SQL database, S3, local storage, …
def get_raw_data(end_date, date_window=365):
'''
Retrieve all data in date range (end_date - date_window, end_date)
'''
Trong đó:
Để nhận dữ liệu mới cho việc retraining model, chúng ta sẽ gọi:
from datetime import date
training_data = get_new_data(date.today())
6.2 Generate a Machine Learning Model function
Hàm này chịu trách nhiệm train AI model: chia dataset thành tập train và tập test, trích xuất vector đặc trừng từ dữ liệu, thực hiện tuning hyper-parameters, huấn luyện model, đánh giá model, …
find_optimal_model(data, ...):
'''
Split data, generate features, tune hyper-parameters, train model, ...
'''
Tham số data
là dữ liệu để huấn luyện mode. Kết quả thực thi của hàm sẽ trả về model đã được trained và các training metrics.
6.3 Store Trained Model
Một khi model được trained xong, ta cần lưu nó lại để sử dụng về sau. Cách đơn giản nhất là sử dụng thư viện pickle
có sẵn của python. Ngoài ta, bạn cũng có thể sử dụng ONNX hoặc PMML.
# Serialize and store model on local storage
def serialize_model(training_arfifacts):
'''
Return a local path to serialized model
'''
6.4 Registry Model
Tham khảo bài Model Registry
6.5 Model Retraining Enpoint
Tập hợp tất cả các hàm lại trong một script để đơn giản hóa quy trình, retrain.py
. Script sẽ chấp nhận một tham số từ command line là end_date
, nhận về dữ liệu mới, train mode, store model và registry model.
from datetime import date
import sys
def retrain(end_date):
'''Model retraining loop.'''
data = get_raw_data(end_date)
training_artifacts = find_optimal_model(data, ...)
local_path = serialize_model(training_artifacts)
model_registry(local_path, training_artifacts)
if __name__ == '__main__':
retrain(sys.argv[1])
6.6 Scheduling the Retraining Procedure
Script retrain.py
lại tiếp tục được đóng gói trong bash script, retrain.sh
:
today_date='date +”%m/%d/%Y”'
python retrain.py $today_date
Để trigger event gọi đến bash script này, chúng ta có thể lập lịch, sử dụng một trong các công cụ sau:
Các công cụ này đều hỗ trợ đầy đủ việc xử lý ngoại lệ, cơ chế retry, … Tuy nhiên, viêc thiết lập và cài đặt sẽ tương đối mất thời gian nếu bạn chưa quen thuộc với chúng.
6.7 Retrieve the Model at Inference Time
Tham khảo bài Model Registry
7. Kết luận
Bài này, chúng ta đã bàn rất nhiều về Model Retraining. Hi vọng là bạn đã hiểu được phần nào tất cả các khía cạnh của nó để xem xét áp dụng vào dự án của bạn.
Bài viết tiếp theo, chúng ta sẽ cùng tìm hiểu về Kubernetes và áp dụng nó cho các bài toàn AI. Mời các bạn đón đọc!
8. Tham khảo