Khi giải quyết một bài toán AI, rất hiếm khi số lượng model được huấn luyện và đưa vào sử dụng dừng lại ở con số 1. Bởi vì theo thời gian, dữ liệu thay đổi, yêu cầu thay đổi, … dẫn đến việc chúng ta cần cập nhật lên model mới hơn. Trong trường hợp đó, làm thế nào để quản lý được tất cả các models đó một các hợp lý, đảm bảo sử dụng đúng model mong muốn để thực hiện inference và không làm gián đoạn quá trình inference đang chạy? Đó chính là câu chuyện của Model Registry.
Trong bài hôm nay, chúng ta sẽ cùng nhau implement một Model Registry đơn giản, sử dụng SQLite. Bạn cũng sử dụng bất kỳ database nào bạn muốn.
1. Model Registry Database
Đầu tiên, hãy tạo một database, registry.db
, như sau:
import sqlite3
conn = sqlite3.connect('registry.db')
Đối tượng conn
tạo ra một kết nối đến registry.db
. Chúng ta sẽ sử dụng nó để thực thi các câu lệnh truy vấn sql.
Tiếp theo, tạo bảng model_registry
bao gồm các trường thông tin của model.
cur = conn.cursor()
cur.execute("""
CREATE TABLE model_registry (
id INTEGER PRIMARY KEY ASC,
name TEXT UNIQUE NOT NULL,
version TEXT NOT NULL,
registered_date TEXT DEFAULT CURRENT_TIMESTAMP,
metrics TEXT NOT NULL,
remote_path TEXT NOT NULL,
stage TEXT DEFAULT 'DEVELOPMENT' NOT NULL
);
""")
cur.close()
2. Xây dựng Model Registry API
Mục đích của việc xây dựng các API là làm đơn giản hóa quá trình thao tác với database. Tất cả các công việc chung một hành động sẽ được gom vào thành một API.
Chúng ta sẽ xây dựng các API sau:
Dưới đây là implement các API:
import panda as pd
class ModelRegistry:
def __init__(self, conn, table_name='model_registry'):
self.conn = conn
self.table_name = table_name
def _insert(self, values):
query = """
INSERT INTO {}
(name, version, metrics, remote_path)
VALUES (?, ?, ?, ?)""".format(self.table_name)
self._query(query, values)
def _query(self, query, values=None):
cur = self.conn.cursor()
cur.execute(query, values)
cur.close()
def publish_model(self, model_name, model_metrics):
model_version_query = """
SELECT version
FROM {}
WHERE name = '{}'
ORDER BY registered_date DESC
LIMIT 1
;""".format(self.table_name, model_name)
model_version = pd.read_sql_query(model_version_query, conn)
if model_version is not None:
model_version = int(version.iloc[0]['version'])
model_version = model_version + 1
# Assume that trained models are stored on S3
model_path = 's3://models/{}::v{}'.format(model_name, model_version)
self._insert((model_name, model_version, model_metrics, model_path))
def update_stage(self, model_name, model_version, model_stage):
query = """
UPDATE {}
SET stage = ?
WHERE name = ? AND version = ?
;""".format(self.table_name)
self._query(query, (model_stage, model_name, model_version))
def get_production_model(self, model_name):
query = """
SELECT *
FROM {}
WHERE name = '{}' AND stage = 'PRODUCTION'
;""".format(self.table_name, model_name)
return pd.read_sql_query(query, self.conn)
Code implement API khá đơn giản, hi vọng bạn có thể hiểu được dễ dàng, :)
3. Sử dụng Model Registry API
Trên thực tế , Training và Inference là 2 quá trình cùng chạy đồng thời và Model Registry cung cấp cơ chế trao đổi thông tin giữa 2 quá trình này thông qua database.
Giả sử rằng chúng ta đã trained xong model thoả mãn yêu cầu đề bài, giờ là lúc ta sử dụng Model Registry API.
Chú ý: Sau mỗi đoạn code ví dụ, ta sẽ sử dụng câu truy vấn sau đây để kiểm tra kết quả:
pd.read_sql_query("SELECT * FROM model_registry;", conn)
3.1 Model Training
conn = sqlite3.connect('registry.db')
model_registry = ModelRegistry(conn=conn)
model = None # This would be replaced by the trained model.
name = 'house_price_prediction'
metrics = {'accuracy': 0.87}
model_registry.publish_model(model=model, name=name, metrics=metrics)
id | name | version | registered_data | remote_path | stage |
---|---|---|---|---|---|
1 | house_price_prediction | 1 | 2021-01-10 12:42:25 | s3://models/house_price_prediction::v1 | DEVELOPMENT |
model = None # This would be replaced by the trained model.
name = 'house_price_prediction'
metrics = {'accuracy': 0.89}
model_registry.publish_model(model=model, name=name, metrics=metrics)
id | name | version | registered_data | remote_path | stage |
---|---|---|---|---|---|
1 | house_price_prediction | 1 | 2020-07-12 12:45:27 | s3://models/house_price_prediction::v1 | DEVELOPMENT |
2 | house_price_prediction | 2 | 2021-01-10 12:42:25 | s3://models/house_price_prediction::v2 | DEVELOPMENT |
3.2 Chuyển model sang trạng thái sẵn sàng sử dụng cho sản phẩm thực tế
model_registry.update_stage(name=name, version='2', stage="PRODUCTION")
id | name | version | registered_data | remote_path | stage |
---|---|---|---|---|---|
1 | house_price_prediction | 1 | 2020-07-12 12:45:27 | s3://models/house_price_prediction::v1 | DEVELOPMENT |
2 | house_price_prediction | 2 | 2021-01-10 12:42:25 | s3://models/house_price_prediction::v2 | PRODUCTION |
3.3 Lấy thông tin model
model_registry.get_production_model(name=name)
id | name | version | registered_data | remote_path | stage |
---|---|---|---|---|---|
2 | house_price_prediction | 2 | 2021-01-10 12:42:25 | s3://models/house_price_prediction::v2 | PRODUCTION |
4. Kết luận
Như vậy là chúng ta đã implemented xong Model Register, sử dụng SQLite database. Bạn hoàn toàn có thể áp dụng những gì được trình bày trong bài viết này vào trong dự án của bạn.
Hiện nay cũng có một số open-source giúp bạn thực hiện việc này một cách trực quan hơn. Nổi bật trong số đó là MLflow. Mình sẽ có một bài viết hướng dẫn sử dụng MLflow cho Model Registry trong tương lai.
Bài viết tiếp theo, mình sẽ thảo luận về vấn đề Retraining model. Mời các bạn đón đọc!
5. Tham khảo