Skip to content

Commit

Permalink
Support custom filename to be provided to URLFile (#2004)
Browse files Browse the repository at this point in the history
This commit works around an issue where the basename of the URL may not
actually contain a file extension, so the uploader logic cannot infer
the MIME type for the file.

We stash the name when pickling and extract it again when unpickling.
The __getattr__ function then supports returning the underlying name
value rather than proxying to the underlying request object.

I also ran into a small bug whereby the __del__ method was triggering
a network request: some private attributes accessed during teardown
would trigger the __wrapped__ code. I've overridden the superclass
method to disable this, though I'm unclear whether it is just the
test suite doing this cleanup.
  • Loading branch information
aron authored Nov 28, 2024
1 parent 7d26f5d commit 9cd4738
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
38 changes: 34 additions & 4 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,9 @@ class URLFile(io.IOBase):
URL that can survive pickling/unpickling.
"""

__slots__ = ("__target__", "__url__")
__slots__ = ("__target__", "__url__", "name")

def __init__(self, url: str) -> None:
object.__setattr__(self, "__url__", url)
def __init__(self, url: str, filename: Optional[str] = None) -> None:
parsed = urllib.parse.urlparse(url)
if parsed.scheme not in {
"http",
Expand All @@ -298,13 +297,42 @@ def __init__(self, url: str) -> None:
"URLFile requires URL to conform to HTTP or HTTPS protocol"
)
object.__setattr__(self, "name", os.path.basename(parsed.path))
object.__setattr__(self, "__url__", url)

if parsed.scheme not in {
"http",
"https",
}:
raise ValueError(
"URLFile requires URL to conform to HTTP or HTTPS protocol"
)

if not filename:
filename = os.path.basename(parsed.path)

object.__setattr__(self, "name", filename)
object.__setattr__(self, "__url__", url)

# Finalizer: only defer to the inherited __del__ when "__target__" has
# already been set on this instance.
def __del__(self) -> None:
try:
# Look the slot up via object.__getattribute__ directly, so a missing
# value surfaces as AttributeError here instead of falling back to this
# class's proxying __getattr__.
object.__getattribute__(self, "__target__")
except AttributeError:
# Do nothing when tearing down the object if the response object
# hasn't been created yet.
return

# "__target__" exists, so run the superclass finalizer as usual.
super().__del__()

# We provide __getstate__ and __setstate__ explicitly to ensure that the
# object is always picklable.
def __getstate__(self) -> Dict[str, Any]:
return {"url": object.__getattribute__(self, "__url__")}
return {
"name": object.__getattribute__(self, "name"),
"url": object.__getattribute__(self, "__url__"),
}

# Unpickling hook: restore the "name" and "__url__" values from the state
# dict. object.__setattr__ is called directly rather than going through this
# class's own __setattr__.
# NOTE(review): a state dict without a "name" key (e.g. one pickled before
# the name was stashed) would raise KeyError here -- confirm whether such
# pickles can still be encountered.
def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "name", state["name"])
object.__setattr__(self, "__url__", state["url"])

# Proxy getattr/setattr/delattr through to the response object.
Expand All @@ -317,6 +345,8 @@ def __setattr__(self, name: str, value: Any) -> None:
# Fallback attribute lookup. Internal names raise AttributeError, "name" is
# read from this object, and everything else is delegated to
# self.__wrapped__.
def __getattr__(self, name: str) -> Any:
# These names are internal to the proxy; never forward them.
if name in ("__target__", "__wrapped__", "__url__"):
raise AttributeError(name)
# "name" is stored on this object itself, so return it directly instead of
# asking the wrapped object for it.
elif name == "name":
return object.__getattribute__(self, "name")
# Any other attribute is proxied through to the wrapped object.
return getattr(self.__wrapped__, name)

def __delattr__(self, name: str) -> None:
Expand Down
16 changes: 13 additions & 3 deletions python/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def test_urlfile_protocol_validation():
URLFile("data:text/plain,hello")


# The explicit filename argument, not the URL's last path segment
# ("some-path"), is what .name is expected to report.
def test_urlfile_custom_filename():
u = URLFile("https://example.com/some-path", filename="my_file.txt")
assert u.name == "my_file.txt"


@responses.activate
def test_urlfile_acts_like_response():
responses.get(
Expand Down Expand Up @@ -61,18 +66,23 @@ def test_urlfile_can_be_pickled():

@responses.activate
def test_urlfile_can_be_pickled_even_once_loaded():
responses.get(
mock = responses.get(
"https://example.com/some/url",
json={"message": "hello world"},
status=200,
)

u = URLFile("https://example.com/some/url")
u.read()
u = URLFile("https://example.com/some/url", filename="my_file.txt")
assert u.name == "my_file.txt"
assert u.read() == b'{"message": "hello world"}'

result = pickle.loads(pickle.dumps(u))

assert isinstance(result, URLFile)
assert result.name == "my_file.txt"
assert result.read() == b'{"message": "hello world"}'

assert mock.call_count == 2


@pytest.mark.parametrize(
Expand Down

0 comments on commit 9cd4738

Please sign in to comment.