Skip to content

Commit

Permalink
[dask] allow parameter aliases for local_listen_port, num_threads, tr…
Browse files Browse the repository at this point in the history
…ee_learner (fixes #3671) (#3789)

* [dask] allow parameter aliases for tree_learner and local_listen_port (fixes #3671)

* num_thread too

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* empty commit

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
jameslamb and StrikerRUS authored Jan 20, 2021
1 parent 4007b34 commit d107872
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
12 changes: 12 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ class _ConfigAliases:
"sparse"},
"label_column": {"label_column",
"label"},
"local_listen_port": {"local_listen_port",
"local_port",
"port"},
"machines": {"machines",
"workers",
"nodes"},
Expand All @@ -255,12 +258,21 @@ class _ConfigAliases:
"num_rounds",
"num_boost_round",
"n_estimators"},
"num_threads": {"num_threads",
"num_thread",
"nthread",
"nthreads",
"n_jobs"},
"objective": {"objective",
"objective_type",
"app",
"application"},
"pre_partition": {"pre_partition",
"is_pre_partition"},
"tree_learner": {"tree_learner",
"tree",
"tree_type",
"tree_learner_type"},
"two_round": {"two_round",
"two_round_loading",
"use_two_round_loading"},
Expand Down
39 changes: 34 additions & 5 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import socket
from collections import defaultdict
from copy import deepcopy
from typing import Dict, Iterable
from urllib.parse import urlparse

Expand All @@ -19,7 +20,7 @@
from dask import delayed
from dask.distributed import Client, default_client, get_worker, wait

from .basic import _LIB, _safe_call
from .basic import _ConfigAliases, _LIB, _safe_call
from .sklearn import LGBMClassifier, LGBMRegressor

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -170,6 +171,8 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
Weights of training data.
"""
params = deepcopy(params)

# Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
data_parts = _split_to_parts(data, is_matrix=True)
label_parts = _split_to_parts(label, is_matrix=False)
Expand Down Expand Up @@ -197,21 +200,47 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
master_worker = next(iter(worker_map))
worker_ncores = client.ncores()

if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}:
logger.warning('Parameter tree_learner not set or set to incorrect value '
'(%s), using "data" as default', params.get("tree_learner", None))
tree_learner = None
for tree_learner_param in _ConfigAliases.get('tree_learner'):
tree_learner = params.get(tree_learner_param)
if tree_learner is not None:
break

allowed_tree_learners = {
'data',
'data_parallel',
'feature',
'feature_parallel',
'voting',
'voting_parallel'
}
if tree_learner is None:
logger.warning('Parameter tree_learner not set. Using "data" as default')
params['tree_learner'] = 'data'
elif tree_learner.lower() not in allowed_tree_learners:
logger.warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
params['tree_learner'] = 'data'

local_listen_port = 12400
for port_param in _ConfigAliases.get('local_listen_port'):
val = params.get(port_param)
if val is not None:
local_listen_port = val
break

# find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its
# own port
local_listen_port = params.get('local_listen_port', 12400)
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=local_listen_port
)

# num_threads is set below, so remove it and all aliases of it from params
for num_thread_alias in _ConfigAliases.get('num_threads'):
params.pop(num_thread_alias, None)

# Tell each worker to train on the parts that it has locally
futures_classifiers = [client.submit(_train_part,
model_factory=model_factory,
Expand Down
11 changes: 7 additions & 4 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_classifier_local_predict(client, listen_port):

dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port,
local_port=listen_port,
n_estimators=10,
num_leaves=10
)
Expand All @@ -148,7 +148,8 @@ def test_regressor(output, client, listen_port):
time_out=5,
local_listen_port=listen_port,
seed=42,
num_leaves=10
num_leaves=10,
tree='data'
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX)
Expand Down Expand Up @@ -181,7 +182,8 @@ def test_regressor_quantile(output, client, listen_port, alpha):
objective='quantile',
alpha=alpha,
n_estimators=10,
num_leaves=10
num_leaves=10,
tree_learner_type='data_parallel'
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX).compute()
Expand Down Expand Up @@ -210,7 +212,8 @@ def test_regressor_local_predict(client, listen_port):
local_listen_port=listen_port,
seed=42,
n_estimators=10,
num_leaves=10
num_leaves=10,
tree_type='data'
)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_regressor.predict(dX)
Expand Down

0 comments on commit d107872

Please sign in to comment.