Skip to content

Commit

Permalink
Removing the deprecated categorical_feature parameter from `lightgbm.train(...)` function calls. (#454)
Browse files Browse the repository at this point in the history

Using `categorical_feature` parameter in `lightgbm.Dataset()` instead of
`lightgbm.train(...)` eliminates the following warnings:
```
test/gbdt/test_gbdt.py: 60 warnings
  /usr/local/lib/python3.10/dist-packages/lightgbm/engine.py:187: LGBMDeprecationWarning: Argument 'categorical_feature' 
to train() is deprecated and will be removed in a future release. Set 'categorical_feature' when calling lightgbm.Dataset() 
instead. See microsoft/LightGBM#6435.
    _emit_dataset_kwarg_warning("train", "categorical_feature")
```

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
drivanov and pre-commit-ci[bot] authored Sep 25, 2024
1 parent cae7ccf commit 9c6cc61
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions torch_frame/gbdt/tuned_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def objective(
trial: Any, # optuna.trial.Trial
train_data: Any, # lightgbm.Dataset
eval_data: Any, # lightgbm.Dataset
cat_features: list[int],
num_boost_round: int,
) -> float:
r"""Objective function to be optimized.
Expand All @@ -112,8 +111,6 @@ def objective(
trial (optuna.trial.Trial): Optuna trial object.
train_data (lightgbm.Dataset): Train data.
eval_data (lightgbm.Dataset): Validation data.
cat_features (list[int]): Array containing indexes of
categorical features.
num_boost_round (int): Number of boosting round.
Returns:
Expand Down Expand Up @@ -169,8 +166,7 @@ def objective(

boost = lightgbm.train(
self.params, train_data, num_boost_round=num_boost_round,
categorical_feature=cat_features, valid_sets=[eval_data],
callbacks=[
valid_sets=[eval_data], callbacks=[
lightgbm.early_stopping(stopping_rounds=50, verbose=False),
lightgbm.log_evaluation(period=2000)
])
Expand Down Expand Up @@ -199,19 +195,18 @@ def _tune(
assert train_y is not None
assert val_y is not None
train_data = lightgbm.Dataset(train_x, label=train_y,
categorical_feature=cat_features,
free_raw_data=False)
eval_data = lightgbm.Dataset(val_x, label=val_y, free_raw_data=False)

study.optimize(
lambda trial: self.objective(trial, train_data, eval_data,
cat_features, num_boost_round),
num_trials)
num_boost_round), num_trials)
self.params.update(study.best_params)

self.model = lightgbm.train(
self.params, train_data, num_boost_round=num_boost_round,
categorical_feature=cat_features, valid_sets=[eval_data],
callbacks=[
valid_sets=[eval_data], callbacks=[
lightgbm.early_stopping(stopping_rounds=50, verbose=False),
lightgbm.log_evaluation(period=2000)
])
Expand Down

0 comments on commit 9c6cc61

Please sign in to comment.