Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make http handler take an optional requests.Session #825

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions smart_open/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def open_uri(uri, mode, transport_params):


def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None, buffer_size=DEFAULT_BUFFER_SIZE):
headers=None, timeout=None, session=None, buffer_size=DEFAULT_BUFFER_SIZE):
"""Implement streamed reader from a web site.

Supports Kerberos and Basic HTTP authentication.
Expand All @@ -73,6 +73,9 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
Any headers to send in the request. If ``None``, the default headers are sent:
``{'Accept-Encoding': 'identity'}``. To use no headers at all,
set this variable to an empty dict, ``{}``.
session: object, optional
The ``requests.Session`` object used to issue the HTTP GET requests.
If ``None``, each request is made with a plain ``requests.get`` call.
Can be used for OAuth2 clients.
buffer_size: int, optional
The buffer size to use when performing I/O.

Expand All @@ -86,7 +89,7 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
fobj = SeekableBufferedInputBase(
uri, mode, buffer_size=buffer_size, kerberos=kerberos,
user=user, password=password, cert=cert,
headers=headers, timeout=timeout,
headers=headers, session=session, timeout=timeout,
)
fobj.name = os.path.basename(urllib.parse.urlparse(uri).path)
return fobj
Expand All @@ -97,7 +100,10 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
class BufferedInputBase(io.BufferedIOBase):
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None):
headers=None, session=None, timeout=None):

self.session = session

if kerberos:
import requests_kerberos
auth = requests_kerberos.HTTPKerberosAuth()
Expand All @@ -116,7 +122,14 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,

self.timeout = timeout

self.response = requests.get(
self.response = session.get(
url,
auth=auth,
cert=cert,
stream=True,
headers=self.headers,
timeout=self.timeout,
) if session is not None else requests.get(
url,
auth=auth,
cert=cert,
Expand Down Expand Up @@ -217,7 +230,7 @@ class SeekableBufferedInputBase(BufferedInputBase):

def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None):
headers=None, session=None, timeout=None):
"""
If Kerberos is True, will attempt to use the local Kerberos credentials.
If cert is set, will try to use a client certificate
Expand All @@ -227,6 +240,8 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
"""
self.url = url

self.session = session
mpenkov marked this conversation as resolved.
Show resolved Hide resolved

if kerberos:
import requests_kerberos
self.auth = requests_kerberos.HTTPKerberosAuth()
Expand Down Expand Up @@ -332,12 +347,20 @@ def _partial_request(self, start_pos=None):
if start_pos is not None:
self.headers.update({"range": smart_open.utils.make_range_string(start_pos)})

response = requests.get(
response = self.session.get(
self.url,
auth=self.auth,
stream=True,
cert=self.cert,
headers=self.headers,
timeout=self.timeout,
) if self.session is not None else requests.get(
self.url,
auth=self.auth,
stream=True,
cert=self.cert,
headers=self.headers,
timeout=self.timeout,
)

return response
10 changes: 9 additions & 1 deletion smart_open/tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import smart_open.http
import smart_open.s3
import smart_open.constants

import requests

BYTES = b'i tried so hard and got so far but in the end it doesn\'t even matter'
URL = 'http://localhost'
Expand Down Expand Up @@ -161,6 +161,14 @@ def test_timeout_attribute(self):
assert hasattr(reader, 'timeout')
assert reader.timeout == timeout

@responses.activate
def test_session_attribute(self):
    """A session supplied via ``transport_params`` is exposed on the reader."""
    custom_session = requests.Session()
    responses.add_callback(responses.GET, URL, callback=request_callback)

    # Hand the session to the HTTP transport through transport_params.
    transport_params = {'session': custom_session}
    reader = smart_open.open(URL, "rb", transport_params=transport_params)

    assert hasattr(reader, 'session')
    assert reader.session == custom_session
mpenkov marked this conversation as resolved.
Show resolved Hide resolved


@responses.activate
def test_seek_implicitly_enabled(numbytes=10):
Expand Down