Skip to content

Commit

Permalink
fix design issues inside ssh.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mpenkov committed Mar 5, 2024
1 parent 6b7904a commit 2fdb17e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 28 deletions.
54 changes: 28 additions & 26 deletions smart_open/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""

import copy
import getpass
import os
import logging
Expand Down Expand Up @@ -111,39 +112,33 @@ def parse_uri(uri_as_string):


def open_uri(uri, mode, transport_params):
# `connect_kwargs` is a legitimate param *only* for sftp, so this filters it out of validation
# (otherwise every call with this present complains it's not valid)
params_to_validate = {k: v for k, v in transport_params.items() if k != 'connect_kwargs'}
smart_open.utils.check_kwargs(open, params_to_validate)
smart_open.utils.check_kwargs(open, transport_params)
parsed_uri = parse_uri(uri)
uri_path = parsed_uri.pop('uri_path')
parsed_uri.pop('scheme')
return open(uri_path, mode, transport_params=transport_params, **parsed_uri)
connect_kwargs = transport_params.get('connect_kwargs')
return open(uri_path, mode, connect_kwargs=connect_kwargs, **parsed_uri)


def _connect_ssh(hostname, username, port, password, transport_params):
def _connect_ssh(hostname, username, port, password, connect_kwargs):
ssh = paramiko.SSHClient()
ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
kwargs = transport_params.get('connect_kwargs', {}).copy()
kwargs = connect_kwargs.copy()
if 'key_filename' not in kwargs:
kwargs.setdefault('password', password)
kwargs.setdefault('username', username)
ssh.connect(hostname, port, **kwargs)
return ssh


def _maybe_fetch_config(host, username=None, password=None, port=None, transport_params=None):
def _maybe_fetch_config(host, username=None, password=None, port=None, connect_kwargs=None):
# If all fields are set, return as-is.
if not any(arg is None for arg in (host, username, password, port, transport_params)):
return host, username, password, port, transport_params
if not any(arg is None for arg in (host, username, password, port, connect_kwargs)):
return host, username, password, port, connect_kwargs

if not host:
raise ValueError('you must specify the host to connect to')
if not transport_params:
transport_params = {}
if "connect_kwargs" not in transport_params:
transport_params["connect_kwargs"] = {}

# Attempt to load an OpenSSH config.
#
Expand All @@ -160,14 +155,21 @@ def _maybe_fetch_config(host, username=None, password=None, port=None, transport
# - compression selection
# - GSS configuration
#
connect_params = transport_params["connect_kwargs"]
config_files = [f for f in _SSH_CONFIG_FILES if os.path.exists(f)]
#
# This is the actual name of the host. The input host may actually be an
# alias.
#
actual_hostname = ""

#
# Avoid modifying the caller's copy of connect_kwargs, as that would create
# an unexpected side-effect.
#
if connect_kwargs:
connect_kwargs = copy.deepcopy(connect_kwargs)
else:
connect_kwargs = {}
for config_filename in config_files:
try:
cfg = paramiko.SSHConfig.from_path(config_filename)
Expand Down Expand Up @@ -198,14 +200,14 @@ def _maybe_fetch_config(host, username=None, password=None, port=None, transport
# that the identityfile list has len > 0. This should be redundant, but
# keeping it for safety.
#
if connect_params.get("key_filename") is None:
if connect_kwargs.get("key_filename") is None:
identityfile = cfg.get("identityfile", [])
if len(identityfile):
connect_params["key_filename"] = identityfile
connect_kwargs["key_filename"] = identityfile

for param_name, (sshcfg_name, from_str) in _PARAMIKO_CONFIG_MAP.items():
if connect_params.get(param_name) is None and sshcfg_name in cfg:
connect_params[param_name] = from_str(cfg[sshcfg_name])
if connect_kwargs.get(param_name) is None and sshcfg_name in cfg:
connect_kwargs[param_name] = from_str(cfg[sshcfg_name])

#
# Continue working through other config files, if there are any,
Expand All @@ -221,10 +223,10 @@ def _maybe_fetch_config(host, username=None, password=None, port=None, transport
if actual_hostname:
host = actual_hostname

return host, username, password, port, transport_params
return host, username, password, port, connect_kwargs


def open(path, mode='r', host=None, user=None, password=None, port=None, transport_params=None):
def open(path, mode='r', host=None, user=None, password=None, port=None, connect_kwargs=None):
"""Open a file on a remote machine over SSH.
Expects authentication to be already set up via existing keys on the local machine.
Expand All @@ -244,7 +246,7 @@ def open(path, mode='r', host=None, user=None, password=None, port=None, transpo
The password to use to login to the remote machine.
port: int, optional
The port to connect to.
transport_params: dict, optional
connect_kwargs: dict, optional
Any additional settings to be passed to paramiko.SSHClient.connect
Returns
Expand All @@ -259,8 +261,8 @@ def open(path, mode='r', host=None, user=None, password=None, port=None, transpo
If ``username`` or ``password`` are specified in *both* the uri and
``transport_params``, ``transport_params`` will take precedence
"""
host, user, password, port, transport_params = _maybe_fetch_config(
host, user, password, port, transport_params
host, user, password, port, connect_kwargs = _maybe_fetch_config(
host, user, password, port, connect_kwargs
)

key = (host, user)
Expand All @@ -273,9 +275,9 @@ def open(path, mode='r', host=None, user=None, password=None, port=None, transpo
# and if not, refresh the connection
if not ssh.get_transport().active:
ssh.close()
ssh = _SSH[key] = _connect_ssh(host, user, port, password, transport_params)
ssh = _SSH[key] = _connect_ssh(host, user, port, password, connect_kwargs)
except KeyError:
ssh = _SSH[key] = _connect_ssh(host, user, port, password, transport_params)
ssh = _SSH[key] = _connect_ssh(host, user, port, password, connect_kwargs)

try:
transport = ssh.get_transport()
Expand Down
4 changes: 2 additions & 2 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def test_read_ssh(self, mock_open):
obj = smart_open.open(
"ssh://ubuntu:pass@ip_address:1022/some/path/lines.txt",
mode='rb',
transport_params=dict(hello='world'),
transport_params={'connect_kwargs': {'hello': 'world'}},
)
obj.__iter__()
mock_open.assert_called_with(
Expand All @@ -450,7 +450,7 @@ def test_read_ssh(self, mock_open):
user='ubuntu',
password='pass',
port=1022,
transport_params={'hello': 'world'},
connect_kwargs={'hello': 'world'},
)

@responses.activate
Expand Down

0 comments on commit 2fdb17e

Please sign in to comment.