Skip to content

Commit

Permalink
Revert "[FSTORE-1453] Move client, decorators, variable_api and const…
Browse files Browse the repository at this point in the history
…ants to hopsworks_common (logicalclocks#229)"

This reverts commit 85deafd.
  • Loading branch information
kennethmhc committed Jul 19, 2024
1 parent 85deafd commit d741178
Show file tree
Hide file tree
Showing 32 changed files with 2,394 additions and 2,180 deletions.
83 changes: 58 additions & 25 deletions python/hopsworks/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright 2024 Hopsworks AB
# Copyright 2022 Logical Clocks AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,27 +14,60 @@
# limitations under the License.
#

from hopsworks_common.client import (
auth,
base,
exceptions,
external,
get_instance,
hopsworks,
init,
online_store_rest_client,
stop,
)


__all__ = [
auth,
base,
exceptions,
external,
get_instance,
hopsworks,
init,
online_store_rest_client,
stop,
]
from hopsworks.client import external, hopsworks


_client = None
_python_version = None


def init(
client_type,
host=None,
port=None,
project=None,
hostname_verification=None,
trust_store_path=None,
cert_folder=None,
api_key_file=None,
api_key_value=None,
):
global _client
if not _client:
if client_type == "hopsworks":
_client = hopsworks.Client()
elif client_type == "external":
_client = external.Client(
host,
port,
project,
hostname_verification,
trust_store_path,
cert_folder,
api_key_file,
api_key_value,
)


def get_instance():
global _client
if _client:
return _client
raise Exception("Couldn't find client. Try reconnecting to Hopsworks.")


def get_python_version():
global _python_version
return _python_version


def set_python_version(python_version):
global _python_version
_python_version = python_version


def stop():
global _client
if _client:
_client._close()
_client = None
33 changes: 22 additions & 11 deletions python/hopsworks/client/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright 2024 Hopsworks AB
# Copyright 2022 Logical Clocks AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,15 +14,26 @@
# limitations under the License.
#

from hopsworks_common.client.auth import (
ApiKeyAuth,
BearerAuth,
OnlineStoreKeyAuth,
)
import requests


__all__ = [
ApiKeyAuth,
BearerAuth,
OnlineStoreKeyAuth,
]
class BearerAuth(requests.auth.AuthBase):
"""Class to encapsulate a Bearer token."""

def __init__(self, token):
self._token = token

def __call__(self, r):
r.headers["Authorization"] = "Bearer " + self._token.strip()
return r


class ApiKeyAuth(requests.auth.AuthBase):
"""Class to encapsulate an API key."""

def __init__(self, token):
self._token = token

def __call__(self, r):
r.headers["Authorization"] = "ApiKey " + self._token.strip()
return r
175 changes: 168 additions & 7 deletions python/hopsworks/client/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright 2024 Hopsworks AB
# Copyright 2022 Logical Clocks AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,11 +14,172 @@
# limitations under the License.
#

from hopsworks_common.client.base import (
Client,
)
import os
from abc import ABC, abstractmethod

import furl
import requests
import urllib3
from hopsworks.client import auth, exceptions
from hopsworks.decorators import connected

__all__ = [
Client,
]

urllib3.disable_warnings(urllib3.exceptions.SecurityWarning)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class Client(ABC):
TOKEN_FILE = "token.jwt"
APIKEY_FILE = "api.key"
REST_ENDPOINT = "REST_ENDPOINT"
HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST"

@abstractmethod
def __init__(self):
"""To be implemented by clients."""
pass

def _get_verify(self, verify, trust_store_path):
"""Get verification method for sending HTTP requests to Hopsworks.
Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c
:param verify: perform hostname verification, 'true' or 'false'
:type verify: str
:param trust_store_path: path of the truststore locally if it was uploaded manually to
the external environment such as AWS Sagemaker
:type trust_store_path: str
:return: if verify is true and the truststore is provided, then return the trust store location
if verify is true but the truststore wasn't provided, then return true
if verify is false, then return false
:rtype: str or boolean
"""
if verify == "true":
if trust_store_path is not None:
return trust_store_path
else:
return True

return False

def _get_host_port_pair(self):
"""
Removes "http or https" from the rest endpoint and returns a list
[endpoint, port], where endpoint is on the format /path.. without http://
:return: a list [endpoint, port]
:rtype: list
"""
endpoint = self._base_url
if "http" in endpoint:
last_index = endpoint.rfind("/")
endpoint = endpoint[last_index + 1 :]
host, port = endpoint.split(":")
return host, port

def _read_jwt(self):
"""Retrieve jwt from local container."""
return self._read_file(self.TOKEN_FILE)

def _read_apikey(self):
"""Retrieve apikey from local container."""
return self._read_file(self.APIKEY_FILE)

def _read_file(self, secret_file):
"""Retrieve secret from local container."""
with open(os.path.join(self._secrets_dir, secret_file), "r") as secret:
return secret.read()

def _get_credentials(self, project_id):
"""Makes a REST call to hopsworks for getting the project user certificates needed to connect to services such as Hive
:param project_id: id of the project
:type project_id: int
:return: JSON response with credentials
:rtype: dict
"""
return self._send_request("GET", ["project", project_id, "credentials"])

def _write_pem_file(self, content: str, path: str) -> None:
with open(path, "w") as f:
f.write(content)

@connected
def _send_request(
self,
method,
path_params,
query_params=None,
headers=None,
data=None,
stream=False,
files=None,
with_base_path_params=True,
):
"""Send REST request to Hopsworks.
Uses the client it is executed from. Path parameters are url encoded automatically.
:param method: 'GET', 'PUT' or 'POST'
:type method: str
:param path_params: a list of path params to build the query url from starting after
the api resource, for example `["project", 119, "featurestores", 67]`.
:type path_params: list
:param query_params: A dictionary of key/value pairs to be added as query parameters,
defaults to None
:type query_params: dict, optional
:param headers: Additional header information, defaults to None
:type headers: dict, optional
:param data: The payload as a python dictionary to be sent as json, defaults to None
:type data: dict, optional
:param stream: Set if response should be a stream, defaults to False
:type stream: boolean, optional
:param files: dictionary for multipart encoding upload
:type files: dict, optional
:raises RestAPIError: Raised when request wasn't correctly received, understood or accepted
:return: Response json
:rtype: dict
"""
f_url = furl.furl(self._base_url)
if with_base_path_params:
base_path_params = ["hopsworks-api", "api"]
f_url.path.segments = base_path_params + path_params
else:
f_url.path.segments = path_params
url = str(f_url)

request = requests.Request(
method,
url=url,
headers=headers,
data=data,
params=query_params,
auth=self._auth,
files=files,
)

prepped = self._session.prepare_request(request)
response = self._session.send(prepped, verify=self._verify, stream=stream)

if response.status_code == 401 and self.REST_ENDPOINT in os.environ:
# refresh token and retry request - only on hopsworks
self._auth = auth.BearerAuth(self._read_jwt())
# Update request with the new token
request.auth = self._auth
prepped = self._session.prepare_request(request)
response = self._session.send(prepped, verify=self._verify, stream=stream)

if response.status_code // 100 != 2:
raise exceptions.RestAPIError(url, response)

if stream:
return response
else:
# handle different success response codes
if len(response.content) == 0:
return None
return response.json()

def _close(self):
"""Closes a client. Can be implemented for clean up purposes, not mandatory."""
self._connected = False
Loading

0 comments on commit d741178

Please sign in to comment.