Skip to content

Commit

Permalink
Adding support for Sklearn linear regression in EvaDB (#1162)
Browse files Browse the repository at this point in the history
Starting this PR for integrating sklearn linear regression for EvaDB.

---------

Co-authored-by: Jineet Desai <[email protected]>
  • Loading branch information
jineetd and Jineet Desai authored Sep 22, 2023
1 parent 56c933a commit 0de4cd2
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 0 deletions.
11 changes: 11 additions & 0 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
outputs.append(column)
else:
inputs.append(column)
elif string_comparison_case_insensitive(node.function_type, "sklearn"):
assert (
"predict" in arg_map
), f"Creating {node.function_type} functions expects 'predict' metadata."
# We only support a single predict column for now
predict_columns = set([arg_map["predict"]])
for column in all_column_list:
if column.name in predict_columns:
outputs.append(column)
else:
inputs.append(column)
elif string_comparison_case_insensitive(node.function_type, "forecasting"):
# Forecasting models have only one input column which is horizon
inputs = [ColumnDefinition("horizon", ColumnType.INTEGER, None, None)]
Expand Down
53 changes: 53 additions & 0 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
string_comparison_case_insensitive,
try_to_import_forecast,
try_to_import_ludwig,
try_to_import_sklearn,
try_to_import_torch,
try_to_import_ultralytics,
)
Expand Down Expand Up @@ -117,6 +118,50 @@ def handle_ludwig_function(self):
self.node.metadata,
)

def handle_sklearn_function(self):
    """Train a scikit-learn linear regression model and describe the function.

    Runs the single child operator to collect the training data, fits a
    ``LinearRegression`` that predicts the column named by the ``predict``
    metadata entry from all remaining columns, pickles the fitted model
    into the configured ``model_dir`` under the function's name, and appends
    a ``model_path`` entry to the node metadata.

    Returns:
        A tuple ``(name, impl_path, function_type, io_list, metadata)``.

    Raises:
        AssertionError: If the executor does not have exactly one child.
        KeyError: If the metadata has no ``predict`` entry.
    """
    try_to_import_sklearn()
    from sklearn.linear_model import LinearRegression

    assert (
        len(self.children) == 1
    ), "Create sklearn function expects 1 child, finds {}.".format(
        len(self.children)
    )

    # Materialize every batch produced by the child into one training batch.
    aggregated_batch_list = list(self.children[0].exec())
    aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
    aggregated_batch.drop_column_alias()

    arg_map = {arg.key: arg.value for arg in self.node.metadata}
    target_column = arg_map["predict"]

    model = LinearRegression()
    Y = aggregated_batch.frames[target_column]
    # Drop the target in place so that only feature columns are left for X.
    aggregated_batch.frames.drop([target_column], axis=1, inplace=True)
    model.fit(X=aggregated_batch.frames, y=Y)

    model_path = os.path.join(
        self.db.config.get_value("storage", "model_dir"), self.node.name
    )
    # Use a context manager so the file is flushed and closed even if
    # pickling raises part-way through.
    with open(model_path, "wb") as model_file:
        pickle.dump(model, model_file)
    self.node.metadata.append(
        FunctionMetadataCatalogEntry("model_path", model_path)
    )

    impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
    io_list = self._resolve_function_io(None)
    return (
        self.node.name,
        impl_path,
        self.node.function_type,
        io_list,
        self.node.metadata,
    )

def handle_ultralytics_function(self):
"""Handle Ultralytics functions"""
try_to_import_ultralytics()
Expand Down Expand Up @@ -332,6 +377,14 @@ def exec(self, *args, **kwargs):
io_list,
metadata,
) = self.handle_ludwig_function()
elif string_comparison_case_insensitive(self.node.function_type, "Sklearn"):
(
name,
impl_path,
function_type,
io_list,
metadata,
) = self.handle_sklearn_function()
elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
(
name,
Expand Down
47 changes: 47 additions & 0 deletions evadb/functions/sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle

import pandas as pd

from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.utils.generic_utils import try_to_import_sklearn


class GenericSklearnModel(AbstractFunction):
    """Function wrapper that serves predictions from a pickled sklearn model."""

    @property
    def name(self) -> str:
        return "GenericSklearnModel"

    def setup(self, model_path: str, **kwargs):
        """Load the pickled model stored at ``model_path`` into ``self.model``.

        Args:
            model_path: Filesystem path of the pickled model.
            **kwargs: Accepted and ignored.
        """
        try_to_import_sklearn()

        # NOTE: unpickling executes arbitrary code, so model_path must point
        # to a trusted file. The context manager closes the handle promptly.
        with open(model_path, "rb") as model_file:
            self.model = pickle.load(model_file)

    def forward(self, frames: pd.DataFrame) -> pd.DataFrame:
        """Predict the target for every row of ``frames``.

        Args:
            frames: Input rows whose last column is the target variable; all
                preceding columns are passed to the model as features.

        Returns:
            A DataFrame of predictions whose column ``0`` is renamed to the
            name of the last column of ``frames``.
        """
        # The last column is the predictor variable column. Hence we do not
        # pass that column in the predict method for sklearn.
        predictions = self.model.predict(frames.iloc[:, :-1])
        predict_df = pd.DataFrame(predictions)
        # Name the output column after the last input column, since that is
        # the variable we want to predict.
        predict_df.rename(columns={0: frames.columns[-1]}, inplace=True)
        return predict_df

    def to_device(self, device: str):
        """Return ``self`` unchanged; ``device`` is currently ignored."""
        # TODO figure out how to control the GPU for sklearn models
        return self
11 changes: 11 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,17 @@ def is_forecast_available() -> bool:
return False


def try_to_import_sklearn():
    """Check that scikit-learn and its ``LinearRegression`` can be imported.

    Raises:
        ValueError: If the import fails. The original ``ImportError`` is
            chained as ``__cause__`` so the root cause stays visible.
    """
    try:
        import sklearn  # noqa: F401
        from sklearn.linear_model import LinearRegression  # noqa: F401
    except ImportError as err:
        raise ValueError(
            "Could not import sklearn.\n"
            "Please install it with `pip install scikit-learn`."
        ) from err


##############################
## VISION
##############################
Expand Down
16 changes: 16 additions & 0 deletions test/integration_tests/long/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,22 @@ def test_ludwig_automl(self):
self.assertEqual(len(result.columns), 1)
self.assertEqual(len(result), 10)

def test_sklearn_regression(self):
    """Train a Sklearn function on HomeRentals, then query it for 10 rows."""
    # Training statement: rental_price is the column to predict.
    training_statement = (
        "\n"
        "CREATE FUNCTION IF NOT EXISTS PredictHouseRent FROM\n"
        "( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals )\n"
        "TYPE Sklearn\n"
        "PREDICT 'rental_price';\n"
    )
    execute_query_fetch_all(self.evadb, training_statement)

    # Inference statement: the trained function should yield one output
    # column and honour the LIMIT.
    inference_statement = (
        "\n"
        "SELECT PredictHouseRent(number_of_rooms, number_of_bathrooms, days_on_market, rental_price) FROM HomeRentals LIMIT 10;\n"
    )
    predictions = execute_query_fetch_all(self.evadb, inference_statement)
    self.assertEqual(len(predictions.columns), 1)
    self.assertEqual(len(predictions), 10)


if __name__ == "__main__":
unittest.main()

0 comments on commit 0de4cd2

Please sign in to comment.