From d74117896cb62c6b2d39204e9d87e81ccfdc28fd Mon Sep 17 00:00:00 2001 From: kenneth Date: Fri, 19 Jul 2024 10:27:00 +0200 Subject: [PATCH] Revert "[FSTORE-1453] Move client, decorators, variable_api and constants to hopsworks_common (#229)" This reverts commit 85deafd4da93c99151a0e996ffbcfb66a2893185. --- python/hopsworks/client/__init__.py | 83 ++-- python/hopsworks/client/auth.py | 33 +- python/hopsworks/client/base.py | 175 +++++++- python/hopsworks/client/exceptions.py | 104 +++-- python/hopsworks/client/external.py | 157 ++++++- python/hopsworks/client/hopsworks.py | 227 +++++++++- .../client/online_store_rest_client.py | 28 -- python/hopsworks/connection.py | 7 + python/hopsworks/core/variable_api.py | 106 ++++- python/hopsworks/decorators.py | 59 ++- python/hopsworks_common/client/__init__.py | 72 ---- python/hopsworks_common/client/auth.py | 52 --- python/hopsworks_common/client/base.py | 293 ------------- python/hopsworks_common/client/exceptions.py | 143 ------ python/hopsworks_common/client/external.py | 407 ------------------ python/hopsworks_common/client/hopsworks.py | 236 ---------- .../client/online_store_rest_client.py | 385 ----------------- python/hopsworks_common/core/constants.py | 51 --- python/hopsworks_common/core/variable_api.py | 117 ----- python/hopsworks_common/decorators.py | 86 ---- python/hsfs/client/__init__.py | 80 ++-- python/hsfs/client/auth.py | 45 +- python/hsfs/client/base.py | 278 +++++++++++- python/hsfs/client/exceptions.py | 130 ++++-- python/hsfs/client/external.py | 372 +++++++++++++++- python/hsfs/client/hopsworks.py | 175 +++++++- .../hsfs/client/online_store_rest_client.py | 376 +++++++++++++++- python/hsfs/core/constants.py | 67 ++- python/hsfs/core/variable_api.py | 68 ++- python/hsfs/decorators.py | 83 +++- python/hsml/decorators.py | 59 ++- .../core/test_online_store_rest_client.py | 20 +- 32 files changed, 2394 insertions(+), 2180 deletions(-) delete mode 100644 python/hopsworks/client/online_store_rest_client.py delete mode 100644 python/hopsworks_common/client/__init__.py delete mode 100644 python/hopsworks_common/client/auth.py delete mode 100644 python/hopsworks_common/client/base.py delete mode 100644 python/hopsworks_common/client/exceptions.py delete mode 100644 python/hopsworks_common/client/external.py delete mode 100644 python/hopsworks_common/client/hopsworks.py delete mode 100644 python/hopsworks_common/client/online_store_rest_client.py delete mode 100644 python/hopsworks_common/core/constants.py delete mode 100644 python/hopsworks_common/core/variable_api.py delete mode 100644 python/hopsworks_common/decorators.py diff --git a/python/hopsworks/client/__init__.py b/python/hopsworks/client/__init__.py index 19e0feb8d..004e49c8b 100644 --- a/python/hopsworks/client/__init__.py +++ b/python/hopsworks/client/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,27 +14,60 @@ # limitations under the License. # -from hopsworks_common.client import ( - auth, - base, - exceptions, - external, - get_instance, - hopsworks, - init, - online_store_rest_client, - stop, -) - - -__all__ = [ - auth, - base, - exceptions, - external, - get_instance, - hopsworks, - init, - online_store_rest_client, - stop, -] +from hopsworks.client import external, hopsworks + + +_client = None +_python_version = None + + +def init( + client_type, + host=None, + port=None, + project=None, + hostname_verification=None, + trust_store_path=None, + cert_folder=None, + api_key_file=None, + api_key_value=None, +): + global _client + if not _client: + if client_type == "hopsworks": + _client = hopsworks.Client() + elif client_type == "external": + _client = external.Client( + host, + port, + project, + hostname_verification, + trust_store_path, + cert_folder, + api_key_file, + api_key_value, + ) + + +def get_instance(): + global _client + if _client: + return _client + raise Exception("Couldn't find client. Try reconnecting to Hopsworks.") + + +def get_python_version(): + global _python_version + return _python_version + + +def set_python_version(python_version): + global _python_version + _python_version = python_version + + +def stop(): + global _client + if _client: + _client._close() + _client = None diff --git a/python/hopsworks/client/auth.py b/python/hopsworks/client/auth.py index e912b1daf..8bbd4ae53 100644 --- a/python/hopsworks/client/auth.py +++ b/python/hopsworks/client/auth.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,15 +14,26 @@ # limitations under the License. # -from hopsworks_common.client.auth import ( - ApiKeyAuth, - BearerAuth, - OnlineStoreKeyAuth, -) +import requests -__all__ = [ - ApiKeyAuth, - BearerAuth, - OnlineStoreKeyAuth, -] +class BearerAuth(requests.auth.AuthBase): + """Class to encapsulate a Bearer token.""" + + def __init__(self, token): + self._token = token + + def __call__(self, r): + r.headers["Authorization"] = "Bearer " + self._token.strip() + return r + + +class ApiKeyAuth(requests.auth.AuthBase): + """Class to encapsulate an API key.""" + + def __init__(self, token): + self._token = token + + def __call__(self, r): + r.headers["Authorization"] = "ApiKey " + self._token.strip() + return r diff --git a/python/hopsworks/client/base.py b/python/hopsworks/client/base.py index 3ff35d800..852259639 100644 --- a/python/hopsworks/client/base.py +++ b/python/hopsworks/client/base.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,172 @@ # limitations under the License. # -from hopsworks_common.client.base import ( - Client, -) +import os +from abc import ABC, abstractmethod +import furl +import requests +import urllib3 +from hopsworks.client import auth, exceptions +from hopsworks.decorators import connected -__all__ = [ - Client, -] + +urllib3.disable_warnings(urllib3.exceptions.SecurityWarning) +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +class Client(ABC): + TOKEN_FILE = "token.jwt" + APIKEY_FILE = "api.key" + REST_ENDPOINT = "REST_ENDPOINT" + HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST" + + @abstractmethod + def __init__(self): + """To be implemented by clients.""" + pass + + def _get_verify(self, verify, trust_store_path): + """Get verification method for sending HTTP requests to Hopsworks. + + Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c + + :param verify: perform hostname verification, 'true' or 'false' + :type verify: str + :param trust_store_path: path of the truststore locally if it was uploaded manually to + the external environment such as AWS Sagemaker + :type trust_store_path: str + :return: if verify is true and the truststore is provided, then return the trust store location + if verify is true but the truststore wasn't provided, then return true + if verify is false, then return false + :rtype: str or boolean + """ + if verify == "true": + if trust_store_path is not None: + return trust_store_path + else: + return True + + return False + + def _get_host_port_pair(self): + """ + Removes "http or https" from the rest endpoint and returns a list + [endpoint, port], where endpoint is on the format /path.. without http:// + + :return: a list [endpoint, port] + :rtype: list + """ + endpoint = self._base_url + if "http" in endpoint: + last_index = endpoint.rfind("/") + endpoint = endpoint[last_index + 1 :] + host, port = endpoint.split(":") + return host, port + + def _read_jwt(self): + """Retrieve jwt from local container.""" + return self._read_file(self.TOKEN_FILE) + + def _read_apikey(self): + """Retrieve apikey from local container.""" + return self._read_file(self.APIKEY_FILE) + + def _read_file(self, secret_file): + """Retrieve secret from local container.""" + with open(os.path.join(self._secrets_dir, secret_file), "r") as secret: + return secret.read() + + def _get_credentials(self, project_id): + """Makes a REST call to hopsworks for getting the project user certificates needed to connect to services such as Hive + + :param project_id: id of the project + :type project_id: int + :return: JSON response with credentials + :rtype: dict + """ + return self._send_request("GET", ["project", project_id, "credentials"]) + + def _write_pem_file(self, content: str, path: str) -> None: + with open(path, "w") as f: + f.write(content) + + @connected + def _send_request( + self, + method, + path_params, + query_params=None, + headers=None, + data=None, + stream=False, + files=None, + with_base_path_params=True, + ): + """Send REST request to Hopsworks. + + Uses the client it is executed from. Path parameters are url encoded automatically. + + :param method: 'GET', 'PUT' or 'POST' + :type method: str + :param path_params: a list of path params to build the query url from starting after + the api resource, for example `["project", 119, "featurestores", 67]`. + :type path_params: list + :param query_params: A dictionary of key/value pairs to be added as query parameters, + defaults to None + :type query_params: dict, optional + :param headers: Additional header information, defaults to None + :type headers: dict, optional + :param data: The payload as a python dictionary to be sent as json, defaults to None + :type data: dict, optional + :param stream: Set if response should be a stream, defaults to False + :type stream: boolean, optional + :param files: dictionary for multipart encoding upload + :type files: dict, optional + :raises RestAPIError: Raised when request wasn't correctly received, understood or accepted + :return: Response json + :rtype: dict + """ + f_url = furl.furl(self._base_url) + if with_base_path_params: + base_path_params = ["hopsworks-api", "api"] + f_url.path.segments = base_path_params + path_params + else: + f_url.path.segments = path_params + url = str(f_url) + + request = requests.Request( + method, + url=url, + headers=headers, + data=data, + params=query_params, + auth=self._auth, + files=files, + ) + + prepped = self._session.prepare_request(request) + response = self._session.send(prepped, verify=self._verify, stream=stream) + + if response.status_code == 401 and self.REST_ENDPOINT in os.environ: + # refresh token and retry request - only on hopsworks + self._auth = auth.BearerAuth(self._read_jwt()) + # Update request with the new token + request.auth = self._auth + prepped = self._session.prepare_request(request) + response = self._session.send(prepped, verify=self._verify, stream=stream) + + if response.status_code // 100 != 2: + raise exceptions.RestAPIError(url, response) + + if stream: + return response + else: + # handle different success response codes + if len(response.content) == 0: + return None + return response.json() + + def _close(self): + """Closes a client. Can be implemented for clean up purposes, not mandatory.""" + self._connected = False diff --git a/python/hopsworks/client/exceptions.py b/python/hopsworks/client/exceptions.py index b34ef198f..637146492 100644 --- a/python/hopsworks/client/exceptions.py +++ b/python/hopsworks/client/exceptions.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,37 +14,71 @@ # limitations under the License. # -from hopsworks_common.client.exceptions import ( - DatasetException, - DataValidationException, - EnvironmentException, - ExternalClientError, - FeatureStoreException, - GitException, - JobException, - JobExecutionException, - KafkaException, - OpenSearchException, - ProjectException, - RestAPIError, - UnknownSecretStorageError, - VectorDatabaseException, -) - - -__all__ = [ - DatasetException, - DataValidationException, - EnvironmentException, - ExternalClientError, - FeatureStoreException, - GitException, - JobException, - JobExecutionException, - KafkaException, - OpenSearchException, - ProjectException, - RestAPIError, - UnknownSecretStorageError, - VectorDatabaseException, -] + +class RestAPIError(Exception): + """REST Exception encapsulating the response object and url.""" + + def __init__(self, url, response): + try: + error_object = response.json() + except Exception: + error_object = {} + message = ( + "Metadata operation error: (url: {}). Server response: \n" + "HTTP code: {}, HTTP reason: {}, body: {}, error code: {}, error msg: {}, user " + "msg: {}".format( + url, + response.status_code, + response.reason, + response.content, + error_object.get("errorCode", ""), + error_object.get("errorMsg", ""), + error_object.get("usrMsg", ""), + ) + ) + super().__init__(message) + self.url = url + self.response = response + + +class UnknownSecretStorageError(Exception): + """This exception will be raised if an unused secrets storage is passed as a parameter.""" + + +class GitException(Exception): + """Generic git exception""" + + +class JobException(Exception): + """Generic job exception""" + + +class EnvironmentException(Exception): + """Generic python environment exception""" + + +class KafkaException(Exception): + """Generic kafka exception""" + + +class DatasetException(Exception): + """Generic dataset exception""" + + +class ProjectException(Exception): + """Generic project exception""" + + +class OpenSearchException(Exception): + """Generic opensearch exception""" + + +class JobExecutionException(Exception): + """Generic job executions exception""" + + +class ExternalClientError(TypeError): + """Raised when external client cannot be initialized due to missing arguments.""" + + def __init__(self, message): + super().__init__(message) diff --git a/python/hopsworks/client/external.py b/python/hopsworks/client/external.py index 1384b1c20..d0a277e71 100644 --- a/python/hopsworks/client/external.py +++ b/python/hopsworks/client/external.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,154 @@ # limitations under the License. # -from hopsworks_common.client.external import ( - Client, -) +import base64 +import os +import requests +from hopsworks.client import auth, base, exceptions -__all__ = [ - Client, -] + +class Client(base.Client): + def __init__( + self, + host, + port, + project, + hostname_verification, + trust_store_path, + cert_folder, + api_key_file, + api_key_value, + ): + """Initializes a client in an external environment such as AWS Sagemaker.""" + if not host: + raise exceptions.ExternalClientError("host") + + self._host = host + self._port = port + self._base_url = "https://" + self._host + ":" + str(self._port) + self._project_name = project + if project is not None: + project_info = self._get_project_info(project) + self._project_id = str(project_info["projectId"]) + else: + self._project_id = None + + if api_key_value is not None: + api_key = api_key_value + elif api_key_file is not None: + file = None + if os.path.exists(api_key_file): + try: + file = open(api_key_file, mode="r") + api_key = file.read() + finally: + file.close() + else: + raise IOError( + "Could not find api key file on path: {}".format(api_key_file) + ) + else: + raise exceptions.ExternalClientError( + "Either api_key_file or api_key_value must be set when connecting to" + " hopsworks from an external environment." + ) + + self._auth = auth.ApiKeyAuth(api_key) + + self._session = requests.session() + self._connected = True + self._verify = self._get_verify(self._host, trust_store_path) + + self._cert_folder_base = os.path.join(cert_folder, host) + + def download_certs(self, project_name): + project_info = self._get_project_info(project_name) + project_id = str(project_info["projectId"]) + + project_cert_folder = os.path.join(self._cert_folder_base, project_name) + + trust_store_path = os.path.join(project_cert_folder, "trustStore.jks") + key_store_path = os.path.join(project_cert_folder, "keyStore.jks") + + os.makedirs(project_cert_folder, exist_ok=True) + credentials = self._get_credentials(project_id) + self._write_b64_cert_to_bytes( + str(credentials["kStore"]), + path=key_store_path, + ) + self._write_b64_cert_to_bytes( + str(credentials["tStore"]), + path=trust_store_path, + ) + + self._write_pem_file( + credentials["caChain"], self._get_ca_chain_path(project_name) + ) + self._write_pem_file( + credentials["clientCert"], self._get_client_cert_path(project_name) + ) + self._write_pem_file( + credentials["clientKey"], self._get_client_key_path(project_name) + ) + + with open(os.path.join(project_cert_folder, "material_passwd"), "w") as f: + f.write(str(credentials["password"])) + + def _close(self): + """Closes a client and deletes certificates.""" + # TODO: Implement certificate cleanup. Currently do not remove certificates as it may break users using hsfs python ingestion. + self._connected = False + + def _get_jks_trust_store_path(self): + return self._trust_store_path + + def _get_jks_key_store_path(self): + return self._key_store_path + + def _get_ca_chain_path(self, project_name) -> str: + return os.path.join(self._cert_folder_base, project_name, "ca_chain.pem") + + def _get_client_cert_path(self, project_name) -> str: + return os.path.join(self._cert_folder_base, project_name, "client_cert.pem") + + def _get_client_key_path(self, project_name) -> str: + return os.path.join(self._cert_folder_base, project_name, "client_key.pem") + + def _get_project_info(self, project_name): + """Makes a REST call to hopsworks to get all metadata of a project for the provided project. + + :param project_name: the name of the project + :type project_name: str + :return: JSON response with project info + :rtype: dict + """ + return self._send_request("GET", ["project", "getProjectInfo", project_name]) + + def _write_b64_cert_to_bytes(self, b64_string, path): + """Converts b64 encoded certificate to bytes file . + + :param b64_string: b64 encoded string of certificate + :type b64_string: str + :param path: path where file is saved, including file name. e.g. /path/key-store.jks + :type path: str + """ + + with open(path, "wb") as f: + cert_b64 = base64.b64decode(b64_string) + f.write(cert_b64) + + def _cleanup_file(self, file_path): + """Removes local files with `file_path`.""" + try: + os.remove(file_path) + except OSError: + pass + + def replace_public_host(self, url): + """no need to replace as we are already in external client""" + return url + + @property + def host(self): + return self._host diff --git a/python/hopsworks/client/hopsworks.py b/python/hopsworks/client/hopsworks.py index c360b8cb9..514e3fe48 100644 --- a/python/hopsworks/client/hopsworks.py +++ b/python/hopsworks/client/hopsworks.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,224 @@ # limitations under the License. # -from hopsworks_common.client.hopsworks import ( - Client, -) +import base64 +import os +import textwrap +from pathlib import Path +import requests +from hopsworks.client import auth, base -__all__ = [ - Client, -] + +try: + import jks +except ImportError: + pass + + +class Client(base.Client): + REQUESTS_VERIFY = "REQUESTS_VERIFY" + DOMAIN_CA_TRUSTSTORE_PEM = "DOMAIN_CA_TRUSTSTORE_PEM" + PROJECT_ID = "HOPSWORKS_PROJECT_ID" + PROJECT_NAME = "HOPSWORKS_PROJECT_NAME" + HADOOP_USER_NAME = "HADOOP_USER_NAME" + MATERIAL_DIRECTORY = "MATERIAL_DIRECTORY" + HDFS_USER = "HDFS_USER" + T_CERTIFICATE = "t_certificate" + K_CERTIFICATE = "k_certificate" + TRUSTSTORE_SUFFIX = "__tstore.jks" + KEYSTORE_SUFFIX = "__kstore.jks" + PEM_CA_CHAIN = "ca_chain.pem" + CERT_KEY_SUFFIX = "__cert.key" + MATERIAL_PWD = "material_passwd" + SECRETS_DIR = "SECRETS_DIR" + + def __init__(self): + """Initializes a client being run from a job/notebook directly on Hopsworks.""" + self._base_url = self._get_hopsworks_rest_endpoint() + self._host, self._port = self._get_host_port_pair() + self._secrets_dir = ( + os.environ[self.SECRETS_DIR] if self.SECRETS_DIR in os.environ else "" + ) + self._cert_key = self._get_cert_pw() + trust_store_path = self._get_trust_store_path() + hostname_verification = ( + os.environ[self.REQUESTS_VERIFY] + if self.REQUESTS_VERIFY in os.environ + else "true" + ) + self._project_id = os.environ[self.PROJECT_ID] + self._project_name = self._project_name() + try: + self._auth = auth.BearerAuth(self._read_jwt()) + except FileNotFoundError: + self._auth = auth.ApiKeyAuth(self._read_apikey()) + self._verify = self._get_verify(hostname_verification, trust_store_path) + self._session = requests.session() + + self._connected = True + + credentials = self._get_credentials(self._project_id) + + self._write_pem_file( + credentials["caChain"], self._get_ca_chain_path(self._project_name) + ) + self._write_pem_file( + credentials["clientCert"], self._get_client_cert_path(self._project_name) + ) + self._write_pem_file( + credentials["clientKey"], self._get_client_key_path(self._project_name) + ) + + def _get_hopsworks_rest_endpoint(self): + """Get the hopsworks REST endpoint for making requests to the REST API.""" + return os.environ[self.REST_ENDPOINT] + + def _get_trust_store_path(self): + """Convert truststore from jks to pem and return the location""" + ca_chain_path = Path(self.PEM_CA_CHAIN) + if not ca_chain_path.exists(): + self._write_ca_chain(ca_chain_path) + return str(ca_chain_path) + + def _get_ca_chain_path(self, project_name) -> str: + return os.path.join("/tmp", "ca_chain.pem") + + def _get_client_cert_path(self, project_name) -> str: + return os.path.join("/tmp", "client_cert.pem") + + def _get_client_key_path(self, project_name) -> str: + return os.path.join("/tmp", "client_key.pem") + + def _write_ca_chain(self, ca_chain_path): + """ + Converts JKS trustore file into PEM to be compatible with Python libraries + """ + keystore_pw = self._cert_key + keystore_ca_cert = self._convert_jks_to_pem( + self._get_jks_key_store_path(), keystore_pw + ) + truststore_ca_cert = self._convert_jks_to_pem( + self._get_jks_trust_store_path(), keystore_pw + ) + + with ca_chain_path.open("w") as f: + f.write(keystore_ca_cert + truststore_ca_cert) + + def _convert_jks_to_pem(self, jks_path, keystore_pw): + """ + Converts a keystore JKS that contains client private key, + client certificate and CA certificate that was used to + sign the certificate to PEM format and returns the CA certificate. + Args: + :jks_path: path to the JKS file + :pw: password for decrypting the JKS file + Returns: + strings: (ca_cert) + """ + # load the keystore and decrypt it with password + ks = jks.KeyStore.load(jks_path, keystore_pw, try_decrypt_keys=True) + ca_certs = "" + + # Convert CA Certificates into PEM format and append to string + for _alias, c in ks.certs.items(): + ca_certs = ca_certs + self._bytes_to_pem_str(c.cert, "CERTIFICATE") + return ca_certs + + def _bytes_to_pem_str(self, der_bytes, pem_type): + """ + Utility function for creating PEM files + + Args: + der_bytes: DER encoded bytes + pem_type: type of PEM, e.g Certificate, Private key, or RSA private key + + Returns: + PEM String for a DER-encoded certificate or private key + """ + pem_str = "" + pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n" + pem_str = ( + pem_str + + "\r\n".join( + textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64) + ) + + "\n" + ) + pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n" + return pem_str + + def _get_jks_trust_store_path(self): + """ + Get truststore location + + Returns: + truststore location + """ + t_certificate = Path(self.T_CERTIFICATE) + if t_certificate.exists(): + return str(t_certificate) + else: + username = os.environ[self.HADOOP_USER_NAME] + material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) + return str(material_directory.joinpath(username + self.TRUSTSTORE_SUFFIX)) + + def _get_jks_key_store_path(self): + """ + Get keystore location + + Returns: + keystore location + """ + k_certificate = Path(self.K_CERTIFICATE) + if k_certificate.exists(): + return str(k_certificate) + else: + username = os.environ[self.HADOOP_USER_NAME] + material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) + return str(material_directory.joinpath(username + self.KEYSTORE_SUFFIX)) + + def _project_name(self): + try: + return os.environ[self.PROJECT_NAME] + except KeyError: + pass + + hops_user = self._project_user() + hops_user_split = hops_user.split( + "__" + ) # project users have username project__user + project = hops_user_split[0] + return project + + def _project_user(self): + try: + hops_user = os.environ[self.HADOOP_USER_NAME] + except KeyError: + hops_user = os.environ[self.HDFS_USER] + return hops_user + + def _get_cert_pw(self): + """ + Get keystore password from local container + + Returns: + Certificate password + """ + pwd_path = Path(self.MATERIAL_PWD) + if not pwd_path.exists(): + username = os.environ[self.HADOOP_USER_NAME] + material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) + pwd_path = material_directory.joinpath(username + self.CERT_KEY_SUFFIX) + + with pwd_path.open() as f: + return f.read() + + def replace_public_host(self, url): + """replace hostname to public hostname set in HOPSWORKS_PUBLIC_HOST""" + ui_url = url._replace(netloc=os.environ[self.HOPSWORKS_PUBLIC_HOST]) + return ui_url + + @property + def host(self): + return self._host diff --git a/python/hopsworks/client/online_store_rest_client.py b/python/hopsworks/client/online_store_rest_client.py deleted file mode 100644 index c75be81b7..000000000 --- a/python/hopsworks/client/online_store_rest_client.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Copyright 2024 Hopsworks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from hopsworks_common.client.online_store_rest_client import ( - OnlineStoreRestClientSingleton, - get_instance, - init_or_reset_online_store_rest_client, -) - - -__all__ = [ - OnlineStoreRestClientSingleton, - get_instance, - init_or_reset_online_store_rest_client, -] diff --git a/python/hopsworks/connection.py b/python/hopsworks/connection.py index c43cfeeb9..61f2e3d6a 100644 --- a/python/hopsworks/connection.py +++ b/python/hopsworks/connection.py @@ -215,6 +215,12 @@ def _check_compatibility(self): ) sys.stderr.flush() + def _set_client_variables(self): + python_version = self._variable_api.get_variable( + "docker_base_image_python_version" + ) + client.set_python_version(python_version) + @not_connected def connect(self): """Instantiate the connection. @@ -265,6 +271,7 @@ def connect(self): ) self._check_compatibility() + self._set_client_variables() def close(self): """Close a connection gracefully. diff --git a/python/hopsworks/core/variable_api.py b/python/hopsworks/core/variable_api.py index 9d6e9765f..d4e8d188c 100644 --- a/python/hopsworks/core/variable_api.py +++ b/python/hopsworks/core/variable_api.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,104 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.core.variable_api import ( - VariableApi, -) +import re +from typing import Optional, Tuple +from hopsworks import client +from hopsworks.client.exceptions import RestAPIError -__all__ = [ - VariableApi, -] + +class VariableApi: + def __init__(self): + pass + + def get_variable(self, variable: str): + """Get the configured value of a variable. + + # Arguments + vairable: Name of the variable. + # Returns + The vairable's value + # Raises + `RestAPIError`: If unable to get the variable + """ + + _client = client.get_instance() + + path_params = ["variables", variable] + domain = _client._send_request("GET", path_params) + + return domain["successMessage"] + + def get_version(self, software: str) -> Optional[str]: + """Get version of a software component. + + # Arguments + software: Name of the software. + # Returns + The software's version, if the software is available, otherwise `None`. + # Raises + `RestAPIError`: If unable to get the version + """ + + _client = client.get_instance() + + path_params = ["variables", "versions"] + resp = _client._send_request("GET", path_params) + + for entry in resp: + if entry["software"] == software: + return entry["version"] + return None + + def parse_major_and_minor( + self, backend_version: str + ) -> Tuple[Optional[str], Optional[str]]: + """Extract major and minor version from full version. + + # Arguments + backend_version: The full version. + # Returns + (major, minor): The pair of major and minor parts of the version, or (None, None) if the version format is incorrect. + """ + + version_pattern = r"(\d+)\.(\d+)" + matches = re.match(version_pattern, backend_version) + + if matches is None: + return (None, None) + return matches.group(1), matches.group(2) + + def get_flyingduck_enabled(self) -> bool: + """Check if Flying Duck is enabled on the backend. + + # Returns + `True`: If flying duck is availalbe, `False` otherwise. + # Raises + `RestAPIError`: If unable to obtain the flag's value. + """ + return self.get_variable("enable_flyingduck") == "true" + + def get_loadbalancer_external_domain(self) -> str: + """Get domain of external loadbalancer. + + # Returns + `str`: The domain of external loadbalancer, if it is set up, otherwise empty string `""`. + """ + try: + return self.get_variable("loadbalancer_external_domain") + except RestAPIError: + return "" + + def get_service_discovery_domain(self) -> str: + """Get domain of service discovery server. + + # Returns + `str`: The domain of service discovery server, if it is set up, otherwise empty string `""`. + """ + try: + return self.get_variable("service_discovery_domain") + except RestAPIError: + return "" diff --git a/python/hopsworks/decorators.py b/python/hopsworks/decorators.py index 1165a2daa..51b7d635a 100644 --- a/python/hopsworks/decorators.py +++ b/python/hopsworks/decorators.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,21 +14,42 @@ # limitations under the License. # -from hopsworks_common.decorators import ( - HopsworksConnectionError, - NoHopsworksConnectionError, - connected, - not_connected, - typechecked, - uses_great_expectations, -) - - -__all__ = [ - HopsworksConnectionError, - NoHopsworksConnectionError, - connected, - not_connected, - typechecked, - uses_great_expectations, -] +import functools + + +def not_connected(fn): + @functools.wraps(fn) + def if_not_connected(inst, *args, **kwargs): + if inst._connected: + raise HopsworksConnectionError + return fn(inst, *args, **kwargs) + + return if_not_connected + + +def connected(fn): + @functools.wraps(fn) + def if_connected(inst, *args, **kwargs): + if not inst._connected: + raise NoHopsworksConnectionError + return fn(inst, *args, **kwargs) + + return if_connected + + +class HopsworksConnectionError(Exception): + """Thrown when attempted to change connection attributes while connected.""" + + def __init__(self): + super().__init__( + "Connection is currently in use. Needs to be closed for modification." + ) + + +class NoHopsworksConnectionError(Exception): + """Thrown when attempted to perform operation on connection while not connected.""" + + def __init__(self): + super().__init__( + "Connection is not active. Needs to be connected for hopsworks operations." + ) diff --git a/python/hopsworks_common/client/__init__.py b/python/hopsworks_common/client/__init__.py deleted file mode 100644 index 2cd86bb83..000000000 --- a/python/hopsworks_common/client/__init__.py +++ /dev/null @@ -1,72 +0,0 @@ -# -# Copyright 2022 Logical Clocks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -from typing import Literal, Optional, Union - -from hopsworks_common.client import external, hopsworks - - -_client: Union[hopsworks.Client, external.Client, None] = None - - -def init( - client_type: Union[Literal["hopsworks"], Literal["external"]], - host: Optional[str] = None, - port: Optional[int] = None, - project: Optional[str] = None, - engine: Optional[str] = None, - region_name: Optional[str] = None, - secrets_store=None, - hostname_verification: Optional[bool] = None, - trust_store_path: Optional[str] = None, - cert_folder: Optional[str] = None, - api_key_file: Optional[str] = None, - api_key_value: Optional[str] = None, -) -> None: - global _client - if not _client: - if client_type == "hopsworks": - _client = hopsworks.Client() - elif client_type == "external": - _client = external.Client( - host, - port, - project, - engine, - region_name, - secrets_store, - hostname_verification, - trust_store_path, - cert_folder, - api_key_file, - api_key_value, - ) - - -def get_instance() -> Union[hopsworks.Client, external.Client]: - global _client - if _client: - return _client - raise Exception("Couldn't find client. Try reconnecting to Hopsworks.") - - -def stop() -> None: - global _client - if _client: - _client._close() - _client = None diff --git a/python/hopsworks_common/client/auth.py b/python/hopsworks_common/client/auth.py deleted file mode 100644 index f90b06cf4..000000000 --- a/python/hopsworks_common/client/auth.py +++ /dev/null @@ -1,52 +0,0 @@ -# -# Copyright 2022 Logical Clocks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -import requests - - -class BearerAuth(requests.auth.AuthBase): - """Class to encapsulate a Bearer token.""" - - def __init__(self, token: str) -> None: - self._token = token.strip() - - def __call__(self, r: requests.Request) -> requests.Request: - r.headers["Authorization"] = "Bearer " + self._token - return r - - -class ApiKeyAuth(requests.auth.AuthBase): - """Class to encapsulate an API key.""" - - def __init__(self, token: str) -> None: - self._token = token.strip() - - def __call__(self, r: requests.Request) -> requests.Request: - r.headers["Authorization"] = "ApiKey " + self._token - return r - - -class OnlineStoreKeyAuth(requests.auth.AuthBase): - """Class to encapsulate an API key.""" - - def __init__(self, token): - self._token = token.strip() - - def __call__(self, r): - r.headers["X-API-KEY"] = self._token - return r diff --git a/python/hopsworks_common/client/base.py b/python/hopsworks_common/client/base.py deleted file mode 100644 index 7c7b4e602..000000000 --- a/python/hopsworks_common/client/base.py +++ /dev/null @@ -1,293 +0,0 @@ -# -# Copyright 2022 Logical Clocks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -import base64 -import os -import textwrap -import time -from pathlib import Path - -import furl -import requests -import urllib3 -from hopsworks_common.client import auth, exceptions -from hopsworks_common.decorators import connected - - -try: - import jks -except ImportError: - pass - - -urllib3.disable_warnings(urllib3.exceptions.SecurityWarning) -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - -class Client: - TOKEN_FILE = "token.jwt" - TOKEN_EXPIRED_RETRY_INTERVAL = 0.6 - TOKEN_EXPIRED_MAX_RETRIES = 10 - - APIKEY_FILE = "api.key" - REST_ENDPOINT = "REST_ENDPOINT" - DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV = "DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV" - HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST" - - def _get_verify(self, verify, trust_store_path): - """Get verification method for sending HTTP requests to Hopsworks. - - Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c - - :param verify: perform hostname verification, 'true' or 'false' - :type verify: str - :param trust_store_path: path of the truststore locally if it was uploaded manually to - the external environment such as AWS Sagemaker - :type trust_store_path: str - :return: if verify is true and the truststore is provided, then return the trust store location - if verify is true but the truststore wasn't provided, then return true - if verify is false, then return false - :rtype: str or boolean - """ - if verify == "true": - if trust_store_path is not None: - return trust_store_path - else: - return True - - return False - - def _get_host_port_pair(self): - """ - Removes "http or https" from the rest endpoint and returns a list - [endpoint, port], where endpoint is on the format /path.. without http:// - - :return: a list [endpoint, port] - :rtype: list - """ - endpoint = self._base_url - if "http" in endpoint: - last_index = endpoint.rfind("/") - endpoint = endpoint[last_index + 1 :] - host, port = endpoint.split(":") - return host, port - - def _read_jwt(self): - """Retrieve jwt from local container.""" - return self._read_file(self.TOKEN_FILE) - - def _read_apikey(self): - """Retrieve apikey from local container.""" - return self._read_file(self.APIKEY_FILE) - - def _read_file(self, secret_file): - """Retrieve secret from local container.""" - with open(os.path.join(self._secrets_dir, secret_file), "r") as secret: - return secret.read() - - def _get_credentials(self, project_id): - """Makes a REST call to hopsworks for getting the project user certificates needed to connect to services such as Hive - - :param project_id: id of the project - :type project_id: int - :return: JSON response with credentials - :rtype: dict - """ - return self._send_request("GET", ["project", project_id, "credentials"]) - - def _write_pem_file(self, content: str, path: str) -> None: - with open(path, "w") as f: - f.write(content) - - @connected - def _send_request( - self, - method, - path_params, - query_params=None, - headers=None, - data=None, - stream=False, - files=None, - with_base_path_params=True, - ): - """Send REST request to Hopsworks. - - Uses the client it is executed from. Path parameters are url encoded automatically. - - :param method: 'GET', 'PUT' or 'POST' - :type method: str - :param path_params: a list of path params to build the query url from starting after - the api resource, for example `["project", 119, "featurestores", 67]`. - :type path_params: list - :param query_params: A dictionary of key/value pairs to be added as query parameters, - defaults to None - :type query_params: dict, optional - :param headers: Additional header information, defaults to None - :type headers: dict, optional - :param data: The payload as a python dictionary to be sent as json, defaults to None - :type data: dict, optional - :param stream: Set if response should be a stream, defaults to False - :type stream: boolean, optional - :param files: dictionary for multipart encoding upload - :type files: dict, optional - :raises RestAPIError: Raised when request wasn't correctly received, understood or accepted - :return: Response json - :rtype: dict - """ - f_url = furl.furl(self._base_url) - if with_base_path_params: - base_path_params = ["hopsworks-api", "api"] - f_url.path.segments = base_path_params + path_params - else: - f_url.path.segments = path_params - url = str(f_url) - - request = requests.Request( - method, - url=url, - headers=headers, - data=data, - params=query_params, - auth=self._auth, - files=files, - ) - - prepped = self._session.prepare_request(request) - response = self._session.send(prepped, verify=self._verify, stream=stream) - - if response.status_code == 401 and self.REST_ENDPOINT in os.environ: - # refresh token and retry request - only on hopsworks - response = self._retry_token_expired( - request, stream, self.TOKEN_EXPIRED_RETRY_INTERVAL, 1 - ) - - if response.status_code // 100 != 2: - raise exceptions.RestAPIError(url, response) - - if stream: - return response - else: - # handle different success response codes - if len(response.content) == 0: - return None - return response.json() - - def _retry_token_expired(self, request, stream, wait, retries): - """Refresh the JWT token and retry the request. Only on Hopsworks. - As the token might take a while to get refreshed. Keep trying - """ - # Sleep the waited time before re-issuing the request - time.sleep(wait) - - self._auth = auth.BearerAuth(self._read_jwt()) - # Update request with the new token - request.auth = self._auth - prepped = self._session.prepare_request(request) - response = self._session.send(prepped, verify=self._verify, stream=stream) - - if response.status_code == 401 and retries < self.TOKEN_EXPIRED_MAX_RETRIES: - # Try again. - return self._retry_token_expired(request, stream, wait * 2, retries + 1) - else: - # If the number of retries have expired, the _send_request method - # will throw an exception to the user as part of the status_code validation. - return response - - def _close(self): - """Closes a client. Can be implemented for clean up purposes, not mandatory.""" - self._connected = False - - def _write_pem( - self, keystore_path, keystore_pw, truststore_path, truststore_pw, prefix - ): - ks = jks.KeyStore.load(Path(keystore_path), keystore_pw, try_decrypt_keys=True) - ts = jks.KeyStore.load( - Path(truststore_path), truststore_pw, try_decrypt_keys=True - ) - - ca_chain_path = os.path.join("/tmp", f"{prefix}_ca_chain.pem") - self._write_ca_chain(ks, ts, ca_chain_path) - - client_cert_path = os.path.join("/tmp", f"{prefix}_client_cert.pem") - self._write_client_cert(ks, client_cert_path) - - client_key_path = os.path.join("/tmp", f"{prefix}_client_key.pem") - self._write_client_key(ks, client_key_path) - - return ca_chain_path, client_cert_path, client_key_path - - def _write_ca_chain(self, ks, ts, ca_chain_path): - """ - Converts JKS keystore and truststore file into ca chain PEM to be compatible with Python libraries - """ - ca_chain = "" - for store in [ks, ts]: - for _, c in store.certs.items(): - ca_chain = ca_chain + self._bytes_to_pem_str(c.cert, "CERTIFICATE") - - with Path(ca_chain_path).open("w") as f: - f.write(ca_chain) - - def _write_client_cert(self, ks, client_cert_path): - """ - Converts JKS keystore file into client cert PEM to be compatible with Python libraries - """ - client_cert = "" - for _, pk in ks.private_keys.items(): - for c in pk.cert_chain: - client_cert = client_cert + self._bytes_to_pem_str(c[1], "CERTIFICATE") - - with Path(client_cert_path).open("w") as f: - f.write(client_cert) - - def _write_client_key(self, ks, client_key_path): - """ - Converts JKS keystore file into client key PEM to be compatible with Python libraries - """ - client_key = "" - for _, pk in ks.private_keys.items(): - client_key = client_key + self._bytes_to_pem_str( - pk.pkey_pkcs8, "PRIVATE KEY" - ) - - with Path(client_key_path).open("w") as f: - f.write(client_key) - - def _bytes_to_pem_str(self, der_bytes, pem_type): - """ - Utility function for creating PEM files - - Args: - der_bytes: DER encoded bytes - pem_type: type of PEM, e.g Certificate, Private key, or RSA private key - - Returns: - PEM String for a DER-encoded certificate or private key - """ - pem_str = "" - pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n" - pem_str = ( - pem_str - + "\r\n".join( - textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64) - ) - + "\n" - ) - pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n" - return pem_str diff --git a/python/hopsworks_common/client/exceptions.py b/python/hopsworks_common/client/exceptions.py deleted file mode 100644 index 4e8ba9b08..000000000 --- a/python/hopsworks_common/client/exceptions.py +++ /dev/null @@ -1,143 +0,0 @@ -# -# Copyright 2022 Logical Clocks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -from enum import Enum -from typing import Any, Union - -import requests - - -class RestAPIError(Exception): - """REST Exception encapsulating the response object and url.""" - - class FeatureStoreErrorCode(int, Enum): - FEATURE_GROUP_COMMIT_NOT_FOUND = 270227 - STATISTICS_NOT_FOUND = 270228 - - def __eq__(self, other: Union[int, Any]) -> bool: - if isinstance(other, int): - return self.value == other - if isinstance(other, self.__class__): - return self is other - return False - - def __init__(self, url: str, response: requests.Response) -> None: - try: - error_object = response.json() - if isinstance(error_object, str): - error_object = {"errorMsg": error_object} - except Exception: - error_object = {} - message = ( - "Metadata operation error: (url: {}). Server response: \n" - "HTTP code: {}, HTTP reason: {}, body: {}, error code: {}, error msg: {}, user " - "msg: {}".format( - url, - response.status_code, - response.reason, - response.content, - error_object.get("errorCode", ""), - error_object.get("errorMsg", ""), - error_object.get("usrMsg", ""), - ) - ) - super().__init__(message) - self.url = url - self.response = response - - -class UnknownSecretStorageError(Exception): - """This exception will be raised if an unused secrets storage is passed as a parameter.""" - - -class FeatureStoreException(Exception): - """Generic feature store exception""" - - -class VectorDatabaseException(Exception): - # reason - REQUESTED_K_TOO_LARGE = "REQUESTED_K_TOO_LARGE" - REQUESTED_NUM_RESULT_TOO_LARGE = "REQUESTED_NUM_RESULT_TOO_LARGE" - OTHERS = "OTHERS" - - # info - REQUESTED_K_TOO_LARGE_INFO_K = "k" - REQUESTED_NUM_RESULT_TOO_LARGE_INFO_N = "n" - - def __init__(self, reason: str, message: str, info: str) -> None: - super().__init__(message) - self._info = info - self._reason = reason - - @property - def reason(self) -> str: - return self._reason - - @property - def info(self) -> str: - return self._info - - -class DataValidationException(FeatureStoreException): - """Raised when data validation fails only when using "STRICT" validation ingestion policy.""" - - def __init__(self, message: str) -> None: - super().__init__(message) - - -class ExternalClientError(TypeError): - """Raised when external client cannot be initialized due to missing arguments.""" - - def __init__(self, missing_argument: str) -> None: - message = ( - "{0} cannot be of type NoneType, {0} is a non-optional " - "argument to connect to hopsworks from an external environment." - ).format(missing_argument) - super().__init__(message) - - -class GitException(Exception): - """Generic git exception""" - - -class JobException(Exception): - """Generic job exception""" - - -class EnvironmentException(Exception): - """Generic python environment exception""" - - -class KafkaException(Exception): - """Generic kafka exception""" - - -class DatasetException(Exception): - """Generic dataset exception""" - - -class ProjectException(Exception): - """Generic project exception""" - - -class OpenSearchException(Exception): - """Generic opensearch exception""" - - -class JobExecutionException(Exception): - """Generic job executions exception""" diff --git a/python/hopsworks_common/client/external.py b/python/hopsworks_common/client/external.py deleted file mode 100644 index c01045af8..000000000 --- a/python/hopsworks_common/client/external.py +++ /dev/null @@ -1,407 +0,0 @@ -# -# Copyright 2022 Logical Clocks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -import base64 -import json -import logging -import os - -import boto3 -import requests -from hopsworks_common.client import auth, base, exceptions -from hopsworks_common.client.exceptions import FeatureStoreException - - -try: - from pyspark.sql import SparkSession -except ImportError: - pass - - -_logger = logging.getLogger(__name__) - - -class Client(base.Client): - DEFAULT_REGION = "default" - SECRETS_MANAGER = "secretsmanager" - PARAMETER_STORE = "parameterstore" - LOCAL_STORE = "local" - - def __init__( - self, - host, - port, - project, - engine, - region_name, - secrets_store, - hostname_verification, - trust_store_path, - cert_folder, - api_key_file, - api_key_value, - ): - """Initializes a client in an external environment such as AWS Sagemaker.""" - _logger.info("Initializing external client") - if not host: - raise exceptions.ExternalClientError("host") - - self._host = host - self._port = port - self._base_url = "https://" + self._host + ":" + str(self._port) - _logger.info("Base URL: %s", self._base_url) - self._project_name = project - if project is not None: - project_info = self._get_project_info(project) - self._project_id = str(project_info["projectId"]) - _logger.debug("Setting Project ID: %s", self._project_id) - else: - self._project_id = None - _logger.debug("Project name: %s", self._project_name) - self._region_name = region_name or self.DEFAULT_REGION - _logger.debug("Region name: %s", self._region_name) - - if api_key_value is not None: - _logger.debug("Using provided API key value") - api_key = api_key_value - else: - _logger.debug("Querying secrets store for API key") - if secrets_store is None: - secrets_store = self.LOCAL_STORE - api_key = self._get_secret(secrets_store, "api-key", api_key_file) - - _logger.debug("Using api key to setup header authentification") - self._auth = auth.ApiKeyAuth(api_key) - - _logger.debug("Setting up requests session") - self._session = requests.session() - self._connected = True - - self._verify = self._get_verify(self._host, trust_store_path) - _logger.debug("Verify: %s", self._verify) - - self._cert_key = None - self._cert_folder_base = cert_folder - self._cert_folder = None - - if project is None: - return - - if engine == "python": - self.download_certs(project) - - elif engine == "spark": - # When using the Spark engine with metastore connection, the certificates - # are needed when the application starts (before user code is run) - # So in this case, we can't materialize the certificates on the fly. - _logger.debug("Running in Spark environment, initializing Spark session") - _spark_session = SparkSession.builder.enableHiveSupport().getOrCreate() - - self._validate_spark_configuration(_spark_session) - with open( - _spark_session.conf.get("spark.hadoop.hops.ssl.keystores.passwd.name"), - "r", - ) as f: - self._cert_key = f.read() - - self._trust_store_path = _spark_session.conf.get( - "spark.hadoop.hops.ssl.trustore.name" - ) - self._key_store_path = _spark_session.conf.get( - "spark.hadoop.hops.ssl.keystore.name" - ) - - elif engine == "spark-no-metastore": - _logger.debug( - "Running in Spark environment with no metastore, initializing Spark session" - ) - _spark_session = SparkSession.builder.getOrCreate() - self.download_certs(project) - - # Set credentials location in the Spark configuration - # Set other options in the Spark configuration - configuration_dict = { - "hops.ssl.trustore.name": self._trust_store_path, - "hops.ssl.keystore.name": self._key_store_path, - "hops.ssl.keystores.passwd.name": self._cert_key_path, - "fs.permissions.umask-mode": "0002", - "fs.hopsfs.impl": "io.hops.hopsfs.client.HopsFileSystem", - "hops.rpc.socket.factory.class.default": "io.hops.hadoop.shaded.org.apache.hadoop.net.HopsSSLSocketFactory", - "client.rpc.ssl.enabled.protocol": "TLSv1.2", - "hops.ssl.hostname.verifier": "ALLOW_ALL", - "hops.ipc.server.ssl.enabled": "true", - } - - for conf_key, conf_value in configuration_dict.items(): - _spark_session._jsc.hadoopConfiguration().set(conf_key, conf_value) - - def download_certs(self, project): - res = self._materialize_certs(self, project) - self._write_pem_file(res["caChain"], self._get_ca_chain_path()) - self._write_pem_file(res["clientCert"], self._get_client_cert_path()) - self._write_pem_file(res["clientKey"], self._get_client_key_path()) - return res - - def _materialize_certs(self, project): - if project != self._project_name: - self._project_name = project - _logger.debug("Project name: %s", self._project_name) - project_info = self._get_project_info(project) - self._project_id = str(project_info["projectId"]) - _logger.debug("Setting Project ID: %s", self._project_id) - - self._cert_folder = os.path.join(self._cert_folder_base, self._host, project) - self._trust_store_path = os.path.join(self._cert_folder, "trustStore.jks") - self._key_store_path = os.path.join(self._cert_folder, "keyStore.jks") - - if os.path.exists(self._cert_folder): - _logger.debug( - f"Running in Python environment, reading certificates from certificates folder {self._cert_folder_base}" - ) - _logger.debug("Found certificates: %s", os.listdir(self._cert_folder_base)) - else: - _logger.debug( - f"Running in Python environment, creating certificates folder {self._cert_folder_base}" - ) - os.makedirs(self._cert_folder, exist_ok=True) - - credentials = self._get_credentials(self._project_id) - self._write_b64_cert_to_bytes( - str(credentials["kStore"]), - path=self._get_jks_key_store_path(), - ) - self._write_b64_cert_to_bytes( - str(credentials["tStore"]), - path=self._get_jks_trust_store_path(), - ) - - self._cert_key = str(credentials["password"]) - self._cert_key_path = os.path.join(self._cert_folder, "material_passwd") - with open(self._cert_key_path, "w") as f: - f.write(str(credentials["password"])) - - # Return the credentials object for the Python engine to materialize the pem files. - return credentials - - def _validate_spark_configuration(self, _spark_session): - exception_text = "Spark is misconfigured for communication with Hopsworks, missing or invalid property: " - - configuration_dict = { - "spark.hadoop.hops.ssl.trustore.name": None, - "spark.hadoop.hops.rpc.socket.factory.class.default": "io.hops.hadoop.shaded.org.apache.hadoop.net.HopsSSLSocketFactory", - "spark.serializer": "org.apache.spark.serializer.KryoSerializer", - "spark.hadoop.hops.ssl.hostname.verifier": "ALLOW_ALL", - "spark.hadoop.hops.ssl.keystore.name": None, - "spark.hadoop.fs.hopsfs.impl": "io.hops.hopsfs.client.HopsFileSystem", - "spark.hadoop.hops.ssl.keystores.passwd.name": None, - "spark.hadoop.hops.ipc.server.ssl.enabled": "true", - "spark.hadoop.client.rpc.ssl.enabled.protocol": "TLSv1.2", - "spark.hadoop.hive.metastore.uris": None, - "spark.sql.hive.metastore.jars": None, - } - _logger.debug("Configuration dict: %s", configuration_dict) - - for key, value in configuration_dict.items(): - _logger.debug("Validating key: %s", key) - if not ( - _spark_session.conf.get(key, "not_found") != "not_found" - and (value is None or _spark_session.conf.get(key, None) == value) - ): - raise FeatureStoreException(exception_text + key) - - def _close(self): - """Closes a client and deletes certificates.""" - _logger.info("Closing external client and cleaning up certificates.") - self._connected = False - if self._cert_folder is None: - _logger.debug("No certificates to clean up.") - # On external Spark clients (Databricks, Spark Cluster), - # certificates need to be provided before the Spark application starts. - return - - # Clean up only on AWS - _logger.debug("Cleaning up certificates. AWS only.") - self._cleanup_file(self._get_jks_key_store_path()) - self._cleanup_file(self._get_jks_trust_store_path()) - self._cleanup_file(os.path.join(self._cert_folder, "material_passwd")) - self._cleanup_file(self._get_ca_chain_path()) - self._cleanup_file(self._get_client_cert_path()) - self._cleanup_file(self._get_client_key_path()) - - try: - # delete project level - os.rmdir(self._cert_folder) - # delete host level - os.rmdir(os.path.dirname(self._cert_folder)) - # on AWS base dir will be empty, and can be deleted otherwise raises OSError - os.rmdir(self._cert_folder_base) - except OSError: - pass - - self._cert_folder = None - - def _get_jks_trust_store_path(self): - _logger.debug("Getting trust store path: %s", self._trust_store_path) - return self._trust_store_path - - def _get_jks_key_store_path(self): - _logger.debug("Getting key store path: %s", self._key_store_path) - return self._key_store_path - - def _get_ca_chain_path(self, project_name=None) -> str: - if project_name is None: - project_name = self._project_name - path = os.path.join( - self._cert_folder_base, self._host, project_name, "ca_chain.pem" - ) - _logger.debug(f"Getting ca chain path {path}") - return path - - def _get_client_cert_path(self, project_name=None) -> str: - if project_name is None: - project_name = self._project_name - path = os.path.join( - self._cert_folder_base, self._host, project_name, "client_cert.pem" - ) - _logger.debug(f"Getting client cert path {path}") - return path - - def _get_client_key_path(self, project_name=None) -> str: - if project_name is None: - project_name = self._project_name - path = os.path.join( - self._cert_folder_base, self._host, project_name, "client_key.pem" - ) - _logger.debug(f"Getting client key path {path}") - return path - - def _get_secret(self, secrets_store, secret_key=None, api_key_file=None): - """Returns secret value from the AWS Secrets Manager or Parameter Store. - - :param secrets_store: the underlying secrets storage to be used, e.g. `secretsmanager` or `parameterstore` - :type secrets_store: str - :param secret_key: key for the secret value, e.g. `api-key`, `cert-key`, `trust-store`, `key-store`, defaults to None - :type secret_key: str, optional - :param api_key_file: path to a file containing an api key, defaults to None - :type api_key_file: str optional - :raises hsfs.client.exceptions.ExternalClientError: `api_key_file` needs to be set for local mode - :raises hsfs.client.exceptions.UnknownSecretStorageError: Provided secrets storage not supported - :return: secret - :rtype: str - """ - _logger.debug(f"Querying secrets store {secrets_store} for secret {secret_key}") - if secrets_store == self.SECRETS_MANAGER: - return self._query_secrets_manager(secret_key) - elif secrets_store == self.PARAMETER_STORE: - return self._query_parameter_store(secret_key) - elif secrets_store == self.LOCAL_STORE: - if not api_key_file: - raise exceptions.ExternalClientError( - "api_key_file needs to be set for local mode" - ) - _logger.debug(f"Reading api key from {api_key_file}") - with open(api_key_file) as f: - return f.readline().strip() - else: - raise exceptions.UnknownSecretStorageError( - "Secrets storage " + secrets_store + " is not supported." - ) - - def _query_secrets_manager(self, secret_key): - _logger.debug("Querying secrets manager for secret key: %s", secret_key) - secret_name = "hopsworks/role/" + self._assumed_role() - args = {"service_name": "secretsmanager"} - region_name = self._get_region() - if region_name: - args["region_name"] = region_name - client = boto3.client(**args) - get_secret_value_response = client.get_secret_value(SecretId=secret_name) - return json.loads(get_secret_value_response["SecretString"])[secret_key] - - def _assumed_role(self): - _logger.debug("Getting assumed role") - client = boto3.client("sts") - response = client.get_caller_identity() - # arns for assumed roles in SageMaker follow the following schema - # arn:aws:sts::123456789012:assumed-role/my-role-name/my-role-session-name - local_identifier = response["Arn"].split(":")[-1].split("/") - if len(local_identifier) != 3 or local_identifier[0] != "assumed-role": - raise Exception( - "Failed to extract assumed role from arn: " + response["Arn"] - ) - return local_identifier[1] - - def _get_region(self): - if self._region_name != self.DEFAULT_REGION: - _logger.debug(f"Region name is not default, returning {self._region_name}") - return self._region_name - else: - _logger.debug("Region name is default, returning None") - return None - - def _query_parameter_store(self, secret_key): - _logger.debug("Querying parameter store for secret key: %s", secret_key) - args = {"service_name": "ssm"} - region_name = self._get_region() - if region_name: - args["region_name"] = region_name - client = boto3.client(**args) - name = "/hopsworks/role/" + self._assumed_role() + "/type/" + secret_key - return client.get_parameter(Name=name, WithDecryption=True)["Parameter"][ - "Value" - ] - - def _get_project_info(self, project_name): - """Makes a REST call to hopsworks to get all metadata of a project for the provided project. - - :param project_name: the name of the project - :type project_name: str - :return: JSON response with project info - :rtype: dict - """ - _logger.debug("Getting project info for project: %s", project_name) - return self._send_request("GET", ["project", "getProjectInfo", project_name]) - - def _write_b64_cert_to_bytes(self, b64_string, path): - """Converts b64 encoded certificate to bytes file . - - :param b64_string: b64 encoded string of certificate - :type b64_string: str - :param path: path where file is saved, including file name. e.g. /path/key-store.jks - :type path: str - """ - _logger.debug(f"Writing b64 encoded certificate to {path}") - with open(path, "wb") as f: - cert_b64 = base64.b64decode(b64_string) - f.write(cert_b64) - - def _cleanup_file(self, file_path): - """Removes local files with `file_path`.""" - _logger.debug(f"Cleaning up file {file_path}") - try: - os.remove(file_path) - except OSError: - pass - - def replace_public_host(self, url): - """no need to replace as we are already in external client""" - return url - - @property - def host(self): - return self._host diff --git a/python/hopsworks_common/client/hopsworks.py b/python/hopsworks_common/client/hopsworks.py deleted file mode 100644 index ddc81fc20..000000000 --- a/python/hopsworks_common/client/hopsworks.py +++ /dev/null @@ -1,236 +0,0 @@ -# -# Copyright 2022 Logical Clocks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import base64 -import os -import textwrap -from pathlib import Path - -import requests -from hopsworks_common.client import auth, base - - -try: - import jks -except ImportError: - pass - - -class Client(base.Client): - REQUESTS_VERIFY = "REQUESTS_VERIFY" - DOMAIN_CA_TRUSTSTORE_PEM = "DOMAIN_CA_TRUSTSTORE_PEM" - PROJECT_ID = "HOPSWORKS_PROJECT_ID" - PROJECT_NAME = "HOPSWORKS_PROJECT_NAME" - HADOOP_USER_NAME = "HADOOP_USER_NAME" - MATERIAL_DIRECTORY = "MATERIAL_DIRECTORY" - HDFS_USER = "HDFS_USER" - T_CERTIFICATE = "t_certificate" - K_CERTIFICATE = "k_certificate" - TRUSTSTORE_SUFFIX = "__tstore.jks" - KEYSTORE_SUFFIX = "__kstore.jks" - PEM_CA_CHAIN = "ca_chain.pem" - CERT_KEY_SUFFIX = "__cert.key" - MATERIAL_PWD = "material_passwd" - SECRETS_DIR = "SECRETS_DIR" - - def __init__(self): - """Initializes a client being run from a job/notebook directly on Hopsworks.""" - self._base_url = self._get_hopsworks_rest_endpoint() - self._host, self._port = self._get_host_port_pair() - self._secrets_dir = ( - os.environ[self.SECRETS_DIR] if self.SECRETS_DIR in os.environ else "" - ) - self._cert_key = self._get_cert_pw() - trust_store_path = self._get_trust_store_path() - hostname_verification = ( - os.environ[self.REQUESTS_VERIFY] - if self.REQUESTS_VERIFY in os.environ - else "true" - ) - self._project_id = os.environ[self.PROJECT_ID] - self._project_name = self._project_name() - try: - self._auth = auth.BearerAuth(self._read_jwt()) - except FileNotFoundError: - self._auth = auth.ApiKeyAuth(self._read_apikey()) - self._verify = self._get_verify(hostname_verification, trust_store_path) - self._session = requests.session() - - self._connected = True - - credentials = self._get_credentials(self._project_id) - - self._write_pem_file(credentials["caChain"], self._get_ca_chain_path()) - self._write_pem_file(credentials["clientCert"], self._get_client_cert_path()) - self._write_pem_file(credentials["clientKey"], self._get_client_key_path()) - - def _get_hopsworks_rest_endpoint(self): - """Get the hopsworks REST endpoint for making requests to the REST API.""" - return os.environ[self.REST_ENDPOINT] - - def _get_trust_store_path(self): - """Convert truststore from jks to pem and return the location""" - ca_chain_path = Path(self.PEM_CA_CHAIN) - if not ca_chain_path.exists(): - ks = jks.KeyStore.load( - self._get_jks_key_store_path(), self._cert_key, try_decrypt_keys=True - ) - ts = jks.KeyStore.load( - self._get_jks_trust_store_path(), self._cert_key, try_decrypt_keys=True - ) - self._write_ca_chain(ks, ts, ca_chain_path) - return str(ca_chain_path) - - def _get_ca_chain_path(self, project_name=None) -> str: - return os.path.join("/tmp", "ca_chain.pem") - - def _get_client_cert_path(self, project_name=None) -> str: - return os.path.join("/tmp", "client_cert.pem") - - def _get_client_key_path(self, project_name=None) -> str: - return os.path.join("/tmp", "client_key.pem") - - def _write_ca_chain(self, ca_chain_path): - """ - Converts JKS trustore file into PEM to be compatible with Python libraries - """ - keystore_pw = self._cert_key - keystore_ca_cert = self._convert_jks_to_pem( - self._get_jks_key_store_path(), keystore_pw - ) - truststore_ca_cert = self._convert_jks_to_pem( - self._get_jks_trust_store_path(), keystore_pw - ) - - with ca_chain_path.open("w") as f: - f.write(keystore_ca_cert + truststore_ca_cert) - - def _convert_jks_to_pem(self, jks_path, keystore_pw): - """ - Converts a keystore JKS that contains client private key, - client certificate and CA certificate that was used to - sign the certificate to PEM format and returns the CA certificate. - Args: - :jks_path: path to the JKS file - :pw: password for decrypting the JKS file - Returns: - strings: (ca_cert) - """ - # load the keystore and decrypt it with password - ks = jks.KeyStore.load(jks_path, keystore_pw, try_decrypt_keys=True) - ca_certs = "" - - # Convert CA Certificates into PEM format and append to string - for _alias, c in ks.certs.items(): - ca_certs = ca_certs + self._bytes_to_pem_str(c.cert, "CERTIFICATE") - return ca_certs - - def _bytes_to_pem_str(self, der_bytes, pem_type): - """ - Utility function for creating PEM files - - Args: - der_bytes: DER encoded bytes - pem_type: type of PEM, e.g Certificate, Private key, or RSA private key - - Returns: - PEM String for a DER-encoded certificate or private key - """ - pem_str = "" - pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n" - pem_str = ( - pem_str - + "\r\n".join( - textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64) - ) - + "\n" - ) - pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n" - return pem_str - - def _get_jks_trust_store_path(self): - """ - Get truststore location - - Returns: - truststore location - """ - t_certificate = Path(self.T_CERTIFICATE) - if t_certificate.exists(): - return str(t_certificate) - else: - username = os.environ[self.HADOOP_USER_NAME] - material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) - return str(material_directory.joinpath(username + self.TRUSTSTORE_SUFFIX)) - - def _get_jks_key_store_path(self): - """ - Get keystore location - - Returns: - keystore location - """ - k_certificate = Path(self.K_CERTIFICATE) - if k_certificate.exists(): - return str(k_certificate) - else: - username = os.environ[self.HADOOP_USER_NAME] - material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) - return str(material_directory.joinpath(username + self.KEYSTORE_SUFFIX)) - - def _project_name(self): - try: - return os.environ[self.PROJECT_NAME] - except KeyError: - pass - - hops_user = self._project_user() - # project users have username project__user: - hops_user_split = hops_user.split("__") - project = hops_user_split[0] - return project - - def _project_user(self): - try: - hops_user = os.environ[self.HADOOP_USER_NAME] - except KeyError: - hops_user = os.environ[self.HDFS_USER] - return hops_user - - def _get_cert_pw(self): - """ - Get keystore password from local container - - Returns: - Certificate password - """ - pwd_path = Path(self.MATERIAL_PWD) - if not pwd_path.exists(): - username = os.environ[self.HADOOP_USER_NAME] - material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) - pwd_path = material_directory.joinpath(username + self.CERT_KEY_SUFFIX) - - with pwd_path.open() as f: - return f.read() - - def replace_public_host(self, url): - """replace hostname to public hostname set in HOPSWORKS_PUBLIC_HOST""" - ui_url = url._replace(netloc=os.environ[self.HOPSWORKS_PUBLIC_HOST]) - return ui_url - - @property - def host(self): - return self._host diff --git a/python/hopsworks_common/client/online_store_rest_client.py b/python/hopsworks_common/client/online_store_rest_client.py deleted file mode 100644 index 03d77471c..000000000 --- a/python/hopsworks_common/client/online_store_rest_client.py +++ /dev/null @@ -1,385 +0,0 @@ -# -# Copyright 2024 Hopsworks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -import logging -from typing import Any, Dict, List, Optional, Union -from warnings import warn - -import requests -import requests.adapters -from furl import furl -from hopsworks_common import client -from hopsworks_common.client.exceptions import FeatureStoreException -from hopsworks_common.core import variable_api - - -_logger = logging.getLogger(__name__) - -_online_store_rest_client = None - - -def init_or_reset_online_store_rest_client( - transport: Optional[ - Union[requests.adapters.HTTPAdapter, requests.adapters.BaseAdapter] - ] = None, - optional_config: Optional[Dict[str, Any]] = None, - reset_client: bool = False, -): - global _online_store_rest_client - if not _online_store_rest_client: - _online_store_rest_client = OnlineStoreRestClientSingleton( - transport=transport, optional_config=optional_config - ) - elif reset_client: - _online_store_rest_client.reset_client( - transport=transport, optional_config=optional_config - ) - else: - _logger.warning( - "Online Store Rest Client is already initialised. To reset connection or/and override configuration, " - + "use reset_online_store_rest_client flag.", - stacklevel=2, - ) - - -def get_instance() -> OnlineStoreRestClientSingleton: - global _online_store_rest_client - if _online_store_rest_client is None: - _logger.warning( - "Online Store Rest Client is not initialised. Initialising with default configuration." - ) - _online_store_rest_client = OnlineStoreRestClientSingleton() - _logger.debug("Accessing global Online Store Rest Client instance.") - return _online_store_rest_client - - -class OnlineStoreRestClientSingleton: - HOST = "host" - PORT = "port" - VERIFY_CERTS = "verify_certs" - USE_SSL = "use_ssl" - CA_CERTS = "ca_certs" - HTTP_AUTHORIZATION = "http_authorization" - TIMEOUT = "timeout" - SERVER_API_VERSION = "server_api_version" - API_KEY = "api_key" - _DEFAULT_ONLINE_STORE_REST_CLIENT_PORT = 4406 - _DEFAULT_ONLINE_STORE_REST_CLIENT_TIMEOUT_SECOND = 2 - _DEFAULT_ONLINE_STORE_REST_CLIENT_VERIFY_CERTS = True - _DEFAULT_ONLINE_STORE_REST_CLIENT_USE_SSL = True - _DEFAULT_ONLINE_STORE_REST_CLIENT_SERVER_API_VERSION = "0.1.0" - _DEFAULT_ONLINE_STORE_REST_CLIENT_HTTP_AUTHORIZATION = "X-API-KEY" - - def __init__( - self, - transport: Optional[ - Union[requests.adapaters.HTTPadapter, requests.adapters.BaseAdapter] - ] = None, - optional_config: Optional[Dict[str, Any]] = None, - ): - _logger.debug( - f"Initialising Online Store Rest Client {'with optional configuration' if optional_config else ''}." - ) - if optional_config: - _logger.debug(f"Optional Config: {optional_config!r}") - self._check_hopsworks_connection() - self.variable_api = variable_api.VariableApi() - self._auth: client.auth.OnlineStoreKeyAuth - self._session: requests.Session - self._current_config: Dict[str, Any] - self._base_url: furl - self._setup_rest_client( - transport=transport, - optional_config=optional_config, - use_current_config=False, - ) - self.is_connected() - - def reset_client( - self, - transport: Optional[ - Union[requests.adapters.HttpAdapter, requests.adapters.BaseAdapter] - ] = None, - optional_config: Optional[Dict[str, Any]] = None, - ): - _logger.debug( - f"Resetting Online Store Rest Client {'with optional configuration' if optional_config else ''}." - ) - if optional_config: - _logger.debug(f"Optional Config: {optional_config}") - self._check_hopsworks_connection() - if hasattr(self, "_session") and self._session: - _logger.debug("Closing existing session.") - self._session.close() - delattr(self, "_session") - self._setup_rest_client( - transport=transport, - optional_config=optional_config, - use_current_config=False if optional_config else True, - ) - - def _setup_rest_client( - self, - transport: Optional[ - Union[requests.adapters.HttpAdapter, requests.adapters.BaseAdapter] - ] = None, - optional_config: Optional[Dict[str, Any]] = None, - use_current_config: bool = True, - ): - _logger.debug("Setting up Online Store Rest Client.") - if optional_config and not isinstance(optional_config, dict): - raise ValueError( - "optional_config must be a dictionary. See documentation for allowed keys and values." - ) - _logger.debug("Optional Config: %s", optional_config) - if not use_current_config: - _logger.debug( - "Retrieving default configuration for Online Store REST Client." - ) - self._current_config = self._get_default_client_config() - if optional_config: - _logger.debug( - "Updating default configuration with provided optional configuration." - ) - self._current_config.update(optional_config) - - self._set_auth(optional_config) - if not hasattr(self, "_session") or not self._session: - _logger.debug("Initialising new requests session.") - self._session = requests.Session() - else: - raise ValueError( - "Use the init_or_reset_online_store_connection method with reset_connection flag set " - + "to True to reset the online_store_client_connection" - ) - if transport is not None: - _logger.debug("Setting custom transport adapter.") - self._session.mount("https://", transport) - self._session.mount("http://", transport) - - if not self._current_config[self.VERIFY_CERTS]: - _logger.warning( - "Disabling SSL certificate verification. This is not recommended for production environments." - ) - self._session.verify = False - else: - _logger.debug( - f"Setting SSL certificate verification using CA Certs path: {self._current_config[self.CA_CERTS]}" - ) - self._session.verify = self._current_config[self.CA_CERTS] - - # Set base_url - scheme = "https" if self._current_config[self.USE_SSL] else "http" - self._base_url = furl( - f"{scheme}://{self._current_config[self.HOST]}:{self._current_config[self.PORT]}/{self._current_config[self.SERVER_API_VERSION]}" - ) - - assert ( - self._session is not None - ), "Online Store REST Client failed to initialise." - assert ( - self._auth is not None - ), "Online Store REST Client Authentication failed to initialise. Check API Key." - assert ( - self._base_url is not None - ), "Online Store REST Client Base URL failed to initialise. Check host and port parameters." - assert ( - self._current_config is not None - ), "Online Store REST Client Configuration failed to initialise." - - def _get_default_client_config(self) -> Dict[str, Any]: - _logger.debug("Retrieving default configuration for Online Store REST Client.") - default_config = self._get_default_static_parameters_config() - default_config.update(self._get_default_dynamic_parameters_config()) - return default_config - - def _get_default_static_parameters_config(self) -> Dict[str, Any]: - _logger.debug( - "Retrieving default static configuration for Online Store REST Client." - ) - return { - self.TIMEOUT: self._DEFAULT_ONLINE_STORE_REST_CLIENT_TIMEOUT_SECOND, - self.VERIFY_CERTS: self._DEFAULT_ONLINE_STORE_REST_CLIENT_VERIFY_CERTS, - self.USE_SSL: self._DEFAULT_ONLINE_STORE_REST_CLIENT_USE_SSL, - self.SERVER_API_VERSION: self._DEFAULT_ONLINE_STORE_REST_CLIENT_SERVER_API_VERSION, - self.HTTP_AUTHORIZATION: self._DEFAULT_ONLINE_STORE_REST_CLIENT_HTTP_AUTHORIZATION, - } - - def _get_default_dynamic_parameters_config( - self, - ) -> Dict[str, Any]: - _logger.debug( - "Retrieving default dynamic configuration for Online Store REST Client." - ) - url = furl(self._get_rondb_rest_server_endpoint()) - _logger.debug(f"Default RonDB Rest Server host and port: {url.host}:{url.port}") - _logger.debug( - f"Using CA Certs from Hopsworks Client: {client.get_instance()._get_ca_chain_path()}" - ) - return { - self.HOST: url.host, - self.PORT: url.port, - self.CA_CERTS: client.get_instance()._get_ca_chain_path(), - } - - def _get_rondb_rest_server_endpoint(self) -> str: - """Retrieve RonDB Rest Server endpoint based on whether the client is running internally or externally. - - If the client is running externally, the endpoint is retrieved via the loadbalancer. - If the client is running internally, the endpoint is retrieved via (consul) service discovery. - The default port for the RonDB Rest Server is 4406 and always used unless specifying a different port - in the configuration. - - Returns: - str: RonDB Rest Server endpoint with default port. - """ - if client.get_instance()._is_external(): - _logger.debug( - "External Online Store REST Client : Retrieving RonDB Rest Server endpoint via loadbalancer." - ) - external_domain = self.variable_api.get_loadbalancer_external_domain() - if external_domain == "": - _logger.debug( - "External Online Store REST Client : Loadbalancer external domain is not set. Using client host as endpoint." - ) - external_domain = client.get_instance().host - default_url = f"https://{external_domain}:{self._DEFAULT_ONLINE_STORE_REST_CLIENT_PORT}" - _logger.debug( - f"External Online Store REST Client : Default RonDB Rest Server endpoint: {default_url}" - ) - return default_url - else: - _logger.debug( - "Internal Online Store REST Client : Retrieving RonDB Rest Server endpoint via service discovery." - ) - service_discovery_domain = self.variable_api.get_service_discovery_domain() - if service_discovery_domain == "": - raise FeatureStoreException("Service discovery domain is not set.") - default_url = f"https://rdrs.service.{service_discovery_domain}:{self._DEFAULT_ONLINE_STORE_REST_CLIENT_PORT}" - _logger.debug( - f"Internal Online Store REST Client : Default RonDB Rest Server endpoint: {default_url}" - ) - return default_url - - def send_request( - self, - method: str, - path_params: List[str], - headers: Optional[Dict[str, Any]] = None, - data: Optional[str] = None, - ) -> requests.Response: - url = self._base_url.copy() - url.path.segments.extend(path_params) - _logger.debug(f"Sending {method} request to {url.url}.") - _logger.debug(f"Provided Data: {data}") - _logger.debug(f"Provided Headers: {headers}") - prepped_request = self._session.prepare_request( - requests.Request( - method, url=url.url, headers=headers, data=data, auth=self.auth - ) - ) - timeout = self._current_config[self.TIMEOUT] - return self._session.send( - prepped_request, - # compatibility with 3.7 - timeout=timeout if timeout < 500 else timeout / 1000, - ) - - def _check_hopsworks_connection(self) -> None: - _logger.debug("Checking Hopsworks connection.") - assert ( - client.get_instance() is not None and client.get_instance()._connected - ), """Hopsworks Client is not connected. Please connect to Hopsworks cluster - via hopsworks.login or hsfs.connection before initialising the Online Store REST Client. - """ - _logger.debug("Hopsworks connection is active.") - - def _set_auth(self, optional_config: Optional[Dict[str, Any]] = None) -> None: - """Set authentication object for the Online Store REST Client. - - RonDB Rest Server uses Hopsworks Api Key to authenticate requests via the X-API-KEY header by default. - The api key determines the permissions of the user making the request for access to a given Feature Store. - """ - _logger.debug("Setting authentication for Online Store REST Client.") - if client.get_instance()._is_external(): - assert hasattr( - client.get_instance()._auth, "_token" - ), "External client must use API Key authentication. Contact your system administrator." - _logger.debug( - "External Online Store REST Client : Setting authentication using Hopsworks Client API Key." - ) - self._auth = client.auth.OnlineStoreKeyAuth( - client.get_instance()._auth._token - ) - elif isinstance(optional_config, dict) and optional_config.get( - self.API_KEY, False - ): - _logger.debug( - "Setting authentication using provided API Key from optional configuration." - ) - self._auth = client.auth.OnlineStoreKeyAuth(optional_config[self.API_KEY]) - elif hasattr(self, "_auth") and self._auth is not None: - _logger.debug( - "Authentication for Online Store REST Client is already set. Using existing authentication api key." - ) - else: - raise FeatureStoreException( - "RonDB Rest Server uses Hopsworks Api Key to authenticate request." - + f"Provide a configuration with the {self.API_KEY} key." - ) - - def is_connected(self): - """If Online Store Rest Client is initialised, ping RonDB Rest Server to ensure connection is active.""" - if self._session is None: - _logger.debug( - "Checking Online Store REST Client is connected. Session is not initialised." - ) - raise FeatureStoreException("Online Store REST Client is not initialised.") - - _logger.debug( - "Checking Online Store REST Client is connected. Pinging RonDB Rest Server." - ) - if not self.send_request("GET", ["ping"]): - warn("Ping failed, RonDB Rest Server is not reachable.", stacklevel=2) - return False - return True - - @property - def session(self) -> requests.Session: - """Requests session object used to send requests to the Online Store REST API.""" - return self._session - - @property - def base_url(self) -> furl: - """Base URL for the Online Store REST API. - - This the url of the RonDB REST Server and should not be confused with the Opensearch Vector DB which also serves as an Online Store for features belonging to Feature Group containing embeddings.""" - return self._base_url - - @property - def current_config(self) -> Dict[str, Any]: - """Current configuration of the Online Store REST Client.""" - return self._current_config - - @property - def auth(self) -> "client.auth.OnlineStoreKeyAuth": - """Authentication object used to authenticate requests to the Online Store REST API. - - Extends the requests.auth.AuthBase class. - """ - return self._auth diff --git a/python/hopsworks_common/core/constants.py b/python/hopsworks_common/core/constants.py deleted file mode 100644 index 4e522de6a..000000000 --- a/python/hopsworks_common/core/constants.py +++ /dev/null @@ -1,51 +0,0 @@ -# -# Copyright 2024 Hopsworks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import importlib.util - - -# Avro -HAS_FAST_AVRO: bool = importlib.util.find_spec("fastavro") is not None -HAS_AVRO: bool = importlib.util.find_spec("avro") is not None - -# Confluent Kafka -HAS_CONFLUENT_KAFKA: bool = importlib.util.find_spec("confluent_kafka") is not None -confluent_kafka_not_installed_message = ( - "Confluent Kafka package not found. " - "If you want to use Kafka with Hopsworks you can install the corresponding extras " - """`pip install hopsworks[python]` or `pip install "hopsworks[python]"` if using zsh. """ - "You can also install confluent-kafka directly in your environment e.g `pip install confluent-kafka`. " - "You will need to restart your kernel if applicable." -) -# Data Validation / Great Expectations -HAS_GREAT_EXPECTATIONS: bool = ( - importlib.util.find_spec("great_expectations") is not None -) -great_expectations_not_installed_message = ( - "Great Expectations package not found. " - "If you want to use data validation with Hopsworks you can install the corresponding extras " - """`pip install hopsworks[great_expectations]` or `pip install "hopsworks[great_expectations]"` if using zsh. """ - "You can also install great-expectations directly in your environment e.g `pip install great-expectations`. " - "You will need to restart your kernel if applicable." -) -initialise_expectation_suite_for_single_expectation_api_message = "Initialize Expectation Suite by attaching to a Feature Group to enable single expectation API" - -# Numpy -HAS_NUMPY: bool = importlib.util.find_spec("numpy") is not None - -# SQL packages -HAS_SQLALCHEMY: bool = importlib.util.find_spec("sqlalchemy") is not None -HAS_AIOMYSQL: bool = importlib.util.find_spec("aiomysql") is not None diff --git a/python/hopsworks_common/core/variable_api.py b/python/hopsworks_common/core/variable_api.py deleted file mode 100644 index 7b3c74575..000000000 --- a/python/hopsworks_common/core/variable_api.py +++ /dev/null @@ -1,117 +0,0 @@ -# -# Copyright 2022 Hopsworks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -import re -from typing import Optional, Tuple - -from hopsworks_common import client -from hopsworks_common.client.exceptions import RestAPIError - - -class VariableApi: - def __init__(self): - pass - - def get_variable(self, variable: str): - """Get the configured value of a variable. - - # Arguments - vairable: Name of the variable. - # Returns - The vairable's value - # Raises - `RestAPIError`: If unable to get the variable - """ - - _client = client.get_instance() - - path_params = ["variables", variable] - domain = _client._send_request("GET", path_params) - - return domain["successMessage"] - - def get_version(self, software: str) -> Optional[str]: - """Get version of a software component. - - # Arguments - software: Name of the software. - # Returns - The software's version, if the software is available, otherwise `None`. - # Raises - `RestAPIError`: If unable to get the version - """ - - _client = client.get_instance() - - path_params = ["variables", "versions"] - resp = _client._send_request("GET", path_params) - - for entry in resp: - if entry["software"] == software: - return entry["version"] - return None - - def parse_major_and_minor( - self, backend_version: str - ) -> Tuple[Optional[str], Optional[str]]: - """Extract major and minor version from full version. - - # Arguments - backend_version: The full version. - # Returns - (major, minor): The pair of major and minor parts of the version, or (None, None) if the version format is incorrect. - """ - - version_pattern = r"(\d+)\.(\d+)" - matches = re.match(version_pattern, backend_version) - - if matches is None: - return (None, None) - return matches.group(1), matches.group(2) - - def get_flyingduck_enabled(self) -> bool: - """Check if Flying Duck is enabled on the backend. - - # Returns - `True`: If flying duck is availalbe, `False` otherwise. - # Raises - `RestAPIError`: If unable to obtain the flag's value. - """ - return self.get_variable("enable_flyingduck") == "true" - - def get_loadbalancer_external_domain(self) -> str: - """Get domain of external loadbalancer. - - # Returns - `str`: The domain of external loadbalancer, if it is set up, otherwise empty string `""`. - """ - try: - return self.get_variable("loadbalancer_external_domain") - except RestAPIError: - return "" - - def get_service_discovery_domain(self) -> str: - """Get domain of service discovery server. - - # Returns - `str`: The domain of service discovery server, if it is set up, otherwise empty string `""`. - """ - try: - return self.get_variable("service_discovery_domain") - except RestAPIError: - return "" diff --git a/python/hopsworks_common/decorators.py b/python/hopsworks_common/decorators.py deleted file mode 100644 index fd83f290d..000000000 --- a/python/hopsworks_common/decorators.py +++ /dev/null @@ -1,86 +0,0 @@ -# -# Copyright 2022 Logical Clocks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -import functools -import os - -from hopsworks_common.core.constants import ( - HAS_GREAT_EXPECTATIONS, - great_expectations_not_installed_message, -) - - -def not_connected(fn): - @functools.wraps(fn) - def if_not_connected(inst, *args, **kwargs): - if inst._connected: - raise HopsworksConnectionError - return fn(inst, *args, **kwargs) - - return if_not_connected - - -def connected(fn): - @functools.wraps(fn) - def if_connected(inst, *args, **kwargs): - if not inst._connected: - raise NoHopsworksConnectionError - return fn(inst, *args, **kwargs) - - return if_connected - - -class HopsworksConnectionError(Exception): - """Thrown when attempted to change connection attributes while connected.""" - - def __init__(self): - super().__init__( - "Connection is currently in use. Needs to be closed for modification." - ) - - -class NoHopsworksConnectionError(Exception): - """Thrown when attempted to perform operation on connection while not connected.""" - - def __init__(self): - super().__init__( - "Connection is not active. Needs to be connected for hopsworks operations." - ) - - -if os.environ.get("HOPSWORKS_RUN_WITH_TYPECHECK", False): - from typeguard import typechecked -else: - from typing import TypeVar - - _T = TypeVar("_T") - - def typechecked( - target: _T, - ) -> _T: - return target if target else typechecked - - -def uses_great_expectations(f): - @functools.wraps(f) - def g(*args, **kwds): - if not HAS_GREAT_EXPECTATIONS: - raise ModuleNotFoundError(great_expectations_not_installed_message) - return f(*args, **kwds) - - return g diff --git a/python/hsfs/client/__init__.py b/python/hsfs/client/__init__.py index 19e0feb8d..736b2006f 100644 --- a/python/hsfs/client/__init__.py +++ b/python/hsfs/client/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2020 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,28 +13,58 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.client import ( - auth, - base, - exceptions, - external, - get_instance, - hopsworks, - init, - online_store_rest_client, - stop, -) - - -__all__ = [ - auth, - base, - exceptions, - external, - get_instance, - hopsworks, - init, - online_store_rest_client, - stop, -] +from typing import Literal, Optional, Union + +from hsfs.client import external, hopsworks + + +_client = None + + +def init( + client_type: Union[Literal["hopsworks"], Literal["external"]], + host: Optional[str] = None, + port: Optional[int] = None, + project: Optional[str] = None, + engine: Optional[str] = None, + region_name: Optional[str] = None, + secrets_store=None, + hostname_verification: Optional[bool] = None, + trust_store_path: Optional[str] = None, + cert_folder: Optional[str] = None, + api_key_file: Optional[str] = None, + api_key_value: Optional[str] = None, +) -> None: + global _client + if not _client: + if client_type == "hopsworks": + _client = hopsworks.Client() + elif client_type == "external": + _client = external.Client( + host, + port, + project, + engine, + region_name, + secrets_store, + hostname_verification, + trust_store_path, + cert_folder, + api_key_file, + api_key_value, + ) + + +def get_instance() -> Union[hopsworks.Client, external.Client]: + global _client + if _client: + return _client + raise Exception("Couldn't find client. Try reconnecting to Hopsworks.") + + +def stop() -> None: + global _client + _client._close() + _client = None diff --git a/python/hsfs/client/auth.py b/python/hsfs/client/auth.py index e912b1daf..1556a5b4c 100644 --- a/python/hsfs/client/auth.py +++ b/python/hsfs/client/auth.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2020 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,16 +13,39 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.client.auth import ( - ApiKeyAuth, - BearerAuth, - OnlineStoreKeyAuth, -) +import requests -__all__ = [ - ApiKeyAuth, - BearerAuth, - OnlineStoreKeyAuth, -] +class BearerAuth(requests.auth.AuthBase): + """Class to encapsulate a Bearer token.""" + + def __init__(self, token: str) -> None: + self._token = token.strip() + + def __call__(self, r: requests.Request) -> requests.Request: + r.headers["Authorization"] = "Bearer " + self._token + return r + + +class ApiKeyAuth(requests.auth.AuthBase): + """Class to encapsulate an API key.""" + + def __init__(self, token: str) -> None: + self._token = token.strip() + + def __call__(self, r: requests.Request) -> requests.Request: + r.headers["Authorization"] = "ApiKey " + self._token + return r + + +class OnlineStoreKeyAuth(requests.auth.AuthBase): + """Class to encapsulate an API key.""" + + def __init__(self, token): + self._token = token.strip() + + def __call__(self, r): + r.headers["X-API-KEY"] = self._token + return r diff --git a/python/hsfs/client/base.py b/python/hsfs/client/base.py index 3ff35d800..eeb6eb369 100644 --- a/python/hsfs/client/base.py +++ b/python/hsfs/client/base.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2020 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,276 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.client.base import ( - Client, -) +import base64 +import os +import textwrap +import time +from pathlib import Path +import furl +import requests +import urllib3 +from hsfs.client import auth, exceptions +from hsfs.decorators import connected -__all__ = [ - Client, -] + +try: + import jks +except ImportError: + pass + + +urllib3.disable_warnings(urllib3.exceptions.SecurityWarning) +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +class Client: + TOKEN_FILE = "token.jwt" + TOKEN_EXPIRED_RETRY_INTERVAL = 0.6 + TOKEN_EXPIRED_MAX_RETRIES = 10 + + APIKEY_FILE = "api.key" + REST_ENDPOINT = "REST_ENDPOINT" + DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV = "DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV" + HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST" + + def _get_verify(self, verify, trust_store_path): + """Get verification method for sending HTTP requests to Hopsworks. + + Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c + + :param verify: perform hostname verification, 'true' or 'false' + :type verify: str + :param trust_store_path: path of the truststore locally if it was uploaded manually to + the external environment such as AWS Sagemaker + :type trust_store_path: str + :return: if verify is true and the truststore is provided, then return the trust store location + if verify is true but the truststore wasn't provided, then return true + if verify is false, then return false + :rtype: str or boolean + """ + if verify == "true": + if trust_store_path is not None: + return trust_store_path + else: + return True + + return False + + def _get_host_port_pair(self): + """ + Removes "http or https" from the rest endpoint and returns a list + [endpoint, port], where endpoint is on the format /path.. without http:// + + :return: a list [endpoint, port] + :rtype: list + """ + endpoint = self._base_url + if "http" in endpoint: + last_index = endpoint.rfind("/") + endpoint = endpoint[last_index + 1 :] + host, port = endpoint.split(":") + return host, port + + def _read_jwt(self): + """Retrieve jwt from local container.""" + return self._read_file(self.TOKEN_FILE) + + def _read_apikey(self): + """Retrieve apikey from local container.""" + return self._read_file(self.APIKEY_FILE) + + def _read_file(self, secret_file): + """Retrieve secret from local container.""" + with open(os.path.join(self._secrets_dir, secret_file), "r") as secret: + return secret.read() + + def _get_credentials(self, project_id): + """Makes a REST call to hopsworks for getting the project user certificates needed to connect to services such as Hive + + :param project_id: id of the project + :type project_id: int + :return: JSON response with credentials + :rtype: dict + """ + return self._send_request("GET", ["project", project_id, "credentials"]) + + def _write_pem_file(self, content: str, path: str) -> None: + with open(path, "w") as f: + f.write(content) + + @connected + def _send_request( + self, + method, + path_params, + query_params=None, + headers=None, + data=None, + stream=False, + files=None, + ): + """Send REST request to Hopsworks. + + Uses the client it is executed from. Path parameters are url encoded automatically. + + :param method: 'GET', 'PUT' or 'POST' + :type method: str + :param path_params: a list of path params to build the query url from starting after + the api resource, for example `["project", 119, "featurestores", 67]`. + :type path_params: list + :param query_params: A dictionary of key/value pairs to be added as query parameters, + defaults to None + :type query_params: dict, optional + :param headers: Additional header information, defaults to None + :type headers: dict, optional + :param data: The payload as a python dictionary to be sent as json, defaults to None + :type data: dict, optional + :param stream: Set if response should be a stream, defaults to False + :type stream: boolean, optional + :param files: dictionary for multipart encoding upload + :type files: dict, optional + :raises hsfs.client.exceptions.RestAPIError: Raised when request wasn't correctly received, understood or accepted + :return: Response json + :rtype: dict + """ + base_path_params = ["hopsworks-api", "api"] + f_url = furl.furl(self._base_url) + f_url.path.segments = base_path_params + path_params + url = str(f_url) + + request = requests.Request( + method, + url=url, + headers=headers, + data=data, + params=query_params, + auth=self._auth, + files=files, + ) + + prepped = self._session.prepare_request(request) + response = self._session.send(prepped, verify=self._verify, stream=stream) + + if response.status_code == 401 and self.REST_ENDPOINT in os.environ: + # refresh token and retry request - only on hopsworks + response = self._retry_token_expired( + request, stream, self.TOKEN_EXPIRED_RETRY_INTERVAL, 1 + ) + + if response.status_code // 100 != 2: + raise exceptions.RestAPIError(url, response) + + if stream: + return response + else: + # handle different success response codes + if len(response.content) == 0: + return None + return response.json() + + def _retry_token_expired(self, request, stream, wait, retries): + """Refresh the JWT token and retry the request. Only on Hopsworks. + As the token might take a while to get refreshed. Keep trying + """ + # Sleep the waited time before re-issuing the request + time.sleep(wait) + + self._auth = auth.BearerAuth(self._read_jwt()) + # Update request with the new token + request.auth = self._auth + prepped = self._session.prepare_request(request) + response = self._session.send(prepped, verify=self._verify, stream=stream) + + if response.status_code == 401 and retries < self.TOKEN_EXPIRED_MAX_RETRIES: + # Try again. + return self._retry_token_expired(request, stream, wait * 2, retries + 1) + else: + # If the number of retries have expired, the _send_request method + # will throw an exception to the user as part of the status_code validation. + return response + + def _close(self): + """Closes a client. Can be implemented for clean up purposes, not mandatory.""" + self._connected = False + + def _write_pem( + self, keystore_path, keystore_pw, truststore_path, truststore_pw, prefix + ): + ks = jks.KeyStore.load(Path(keystore_path), keystore_pw, try_decrypt_keys=True) + ts = jks.KeyStore.load( + Path(truststore_path), truststore_pw, try_decrypt_keys=True + ) + + ca_chain_path = os.path.join("/tmp", f"{prefix}_ca_chain.pem") + self._write_ca_chain(ks, ts, ca_chain_path) + + client_cert_path = os.path.join("/tmp", f"{prefix}_client_cert.pem") + self._write_client_cert(ks, client_cert_path) + + client_key_path = os.path.join("/tmp", f"{prefix}_client_key.pem") + self._write_client_key(ks, client_key_path) + + return ca_chain_path, client_cert_path, client_key_path + + def _write_ca_chain(self, ks, ts, ca_chain_path): + """ + Converts JKS keystore and truststore file into ca chain PEM to be compatible with Python libraries + """ + ca_chain = "" + for store in [ks, ts]: + for _, c in store.certs.items(): + ca_chain = ca_chain + self._bytes_to_pem_str(c.cert, "CERTIFICATE") + + with Path(ca_chain_path).open("w") as f: + f.write(ca_chain) + + def _write_client_cert(self, ks, client_cert_path): + """ + Converts JKS keystore file into client cert PEM to be compatible with Python libraries + """ + client_cert = "" + for _, pk in ks.private_keys.items(): + for c in pk.cert_chain: + client_cert = client_cert + self._bytes_to_pem_str(c[1], "CERTIFICATE") + + with Path(client_cert_path).open("w") as f: + f.write(client_cert) + + def _write_client_key(self, ks, client_key_path): + """ + Converts JKS keystore file into client key PEM to be compatible with Python libraries + """ + client_key = "" + for _, pk in ks.private_keys.items(): + client_key = client_key + self._bytes_to_pem_str( + pk.pkey_pkcs8, "PRIVATE KEY" + ) + + with Path(client_key_path).open("w") as f: + f.write(client_key) + + def _bytes_to_pem_str(self, der_bytes, pem_type): + """ + Utility function for creating PEM files + + Args: + der_bytes: DER encoded bytes + pem_type: type of PEM, e.g Certificate, Private key, or RSA private key + + Returns: + PEM String for a DER-encoded certificate or private key + """ + pem_str = "" + pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n" + pem_str = ( + pem_str + + "\r\n".join( + textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64) + ) + + "\n" + ) + pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n" + return pem_str diff --git a/python/hsfs/client/exceptions.py b/python/hsfs/client/exceptions.py index b34ef198f..7a7f67d5c 100644 --- a/python/hsfs/client/exceptions.py +++ b/python/hsfs/client/exceptions.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2020 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,38 +13,98 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.client.exceptions import ( - DatasetException, - DataValidationException, - EnvironmentException, - ExternalClientError, - FeatureStoreException, - GitException, - JobException, - JobExecutionException, - KafkaException, - OpenSearchException, - ProjectException, - RestAPIError, - UnknownSecretStorageError, - VectorDatabaseException, -) - - -__all__ = [ - DatasetException, - DataValidationException, - EnvironmentException, - ExternalClientError, - FeatureStoreException, - GitException, - JobException, - JobExecutionException, - KafkaException, - OpenSearchException, - ProjectException, - RestAPIError, - UnknownSecretStorageError, - VectorDatabaseException, -] +from enum import Enum +from typing import Any, Union + +import requests + + +class RestAPIError(Exception): + """REST Exception encapsulating the response object and url.""" + + class FeatureStoreErrorCode(int, Enum): + FEATURE_GROUP_COMMIT_NOT_FOUND = 270227 + STATISTICS_NOT_FOUND = 270228 + + def __eq__(self, other: Union[int, Any]) -> bool: + if isinstance(other, int): + return self.value == other + if isinstance(other, self.__class__): + return self is other + return False + + def __init__(self, url: str, response: requests.Response) -> None: + try: + error_object = response.json() + if isinstance(error_object, str): + error_object = {"errorMsg": error_object} + except Exception: + error_object = {} + message = ( + "Metadata operation error: (url: {}). Server response: \n" + "HTTP code: {}, HTTP reason: {}, body: {}, error code: {}, error msg: {}, user " + "msg: {}".format( + url, + response.status_code, + response.reason, + response.content, + error_object.get("errorCode", ""), + error_object.get("errorMsg", ""), + error_object.get("usrMsg", ""), + ) + ) + super().__init__(message) + self.url = url + self.response = response + + +class UnknownSecretStorageError(Exception): + """This exception will be raised if an unused secrets storage is passed as a parameter.""" + + +class FeatureStoreException(Exception): + """Generic feature store exception""" + + +class VectorDatabaseException(Exception): + # reason + REQUESTED_K_TOO_LARGE = "REQUESTED_K_TOO_LARGE" + REQUESTED_NUM_RESULT_TOO_LARGE = "REQUESTED_NUM_RESULT_TOO_LARGE" + OTHERS = "OTHERS" + + # info + REQUESTED_K_TOO_LARGE_INFO_K = "k" + REQUESTED_NUM_RESULT_TOO_LARGE_INFO_N = "n" + + def __init__(self, reason: str, message: str, info: str) -> None: + super().__init__(message) + self._info = info + self._reason = reason + + @property + def reason(self) -> str: + return self._reason + + @property + def info(self) -> str: + return self._info + + +class DataValidationException(FeatureStoreException): + """Raised when data validation fails only when using "STRICT" validation ingestion policy.""" + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class ExternalClientError(TypeError): + """Raised when external client cannot be initialized due to missing arguments.""" + + def __init__(self, missing_argument: str) -> None: + message = ( + "{0} cannot be of type NoneType, {0} is a non-optional " + "argument to connect to hopsworks from an external environment." + ).format(missing_argument) + super().__init__(message) diff --git a/python/hsfs/client/external.py b/python/hsfs/client/external.py index 1384b1c20..e99fc20b4 100644 --- a/python/hsfs/client/external.py +++ b/python/hsfs/client/external.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2020 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,370 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.client.external import ( - Client, -) +import base64 +import json +import logging +import os +import boto3 +import requests -__all__ = [ - Client, -] + +try: + from pyspark.sql import SparkSession +except ImportError: + pass + +from hsfs.client import auth, base, exceptions +from hsfs.client.exceptions import FeatureStoreException + + +_logger = logging.getLogger(__name__) + + +class Client(base.Client): + DEFAULT_REGION = "default" + SECRETS_MANAGER = "secretsmanager" + PARAMETER_STORE = "parameterstore" + LOCAL_STORE = "local" + + def __init__( + self, + host, + port, + project, + engine, + region_name, + secrets_store, + hostname_verification, + trust_store_path, + cert_folder, + api_key_file, + api_key_value, + ): + """Initializes a client in an external environment such as AWS Sagemaker.""" + _logger.info("Initializing external client") + if not host: + raise exceptions.ExternalClientError("host") + if not project: + raise exceptions.ExternalClientError("project") + + self._host = host + self._port = port + self._base_url = "https://" + self._host + ":" + str(self._port) + _logger.info("Base URL: %s", self._base_url) + self._project_name = project + _logger.debug("Project name: %s", self._project_name) + self._region_name = region_name or self.DEFAULT_REGION + _logger.debug("Region name: %s", self._region_name) + + if api_key_value is not None: + _logger.debug("Using provided API key value") + api_key = api_key_value + else: + _logger.debug("Querying secrets store for API key") + api_key = self._get_secret(secrets_store, "api-key", api_key_file) + + _logger.debug("Using api key to setup header authentification") + self._auth = auth.ApiKeyAuth(api_key) + + _logger.debug("Setting up requests session") + self._session = requests.session() + self._connected = True + + self._verify = self._get_verify(self._host, trust_store_path) + _logger.debug("Verify: %s", self._verify) + + project_info = self._get_project_info(self._project_name) + + self._project_id = str(project_info["projectId"]) + _logger.debug("Setting Project ID: %s", self._project_id) + + self._cert_key = None + self._cert_folder_base = None + + if engine == "python": + credentials = self._materialize_certs(cert_folder, host, project) + + self._write_pem_file(credentials["caChain"], self._get_ca_chain_path()) + self._write_pem_file( + credentials["clientCert"], self._get_client_cert_path() + ) + self._write_pem_file(credentials["clientKey"], self._get_client_key_path()) + + elif engine == "spark": + # When using the Spark engine with metastore connection, the certificates + # are needed when the application starts (before user code is run) + # So in this case, we can't materialize the certificates on the fly. + _logger.debug("Running in Spark environment, initializing Spark session") + _spark_session = SparkSession.builder.enableHiveSupport().getOrCreate() + + self._validate_spark_configuration(_spark_session) + with open( + _spark_session.conf.get("spark.hadoop.hops.ssl.keystores.passwd.name"), + "r", + ) as f: + self._cert_key = f.read() + + self._trust_store_path = _spark_session.conf.get( + "spark.hadoop.hops.ssl.trustore.name" + ) + self._key_store_path = _spark_session.conf.get( + "spark.hadoop.hops.ssl.keystore.name" + ) + elif engine == "spark-no-metastore": + _logger.debug( + "Running in Spark environment with no metastore, initializing Spark session" + ) + _spark_session = SparkSession.builder.getOrCreate() + self._materialize_certs(cert_folder, host, project) + + # Set credentials location in the Spark configuration + # Set other options in the Spark configuration + configuration_dict = { + "hops.ssl.trustore.name": self._trust_store_path, + "hops.ssl.keystore.name": self._key_store_path, + "hops.ssl.keystores.passwd.name": self._cert_key_path, + "fs.permissions.umask-mode": "0002", + "fs.hopsfs.impl": "io.hops.hopsfs.client.HopsFileSystem", + "hops.rpc.socket.factory.class.default": "io.hops.hadoop.shaded.org.apache.hadoop.net.HopsSSLSocketFactory", + "client.rpc.ssl.enabled.protocol": "TLSv1.2", + "hops.ssl.hostname.verifier": "ALLOW_ALL", + "hops.ipc.server.ssl.enabled": "true", + } + + for conf_key, conf_value in configuration_dict.items(): + _spark_session._jsc.hadoopConfiguration().set(conf_key, conf_value) + + def _materialize_certs(self, cert_folder, host, project): + self._cert_folder_base = cert_folder + self._cert_folder = os.path.join(cert_folder, host, project) + self._trust_store_path = os.path.join(self._cert_folder, "trustStore.jks") + self._key_store_path = os.path.join(self._cert_folder, "keyStore.jks") + + if os.path.exists(self._cert_folder): + _logger.debug( + f"Running in Python environment, reading certificates from certificates folder {cert_folder}" + ) + _logger.debug("Found certificates: %s", os.listdir(cert_folder)) + else: + _logger.debug( + f"Running in Python environment, creating certificates folder {cert_folder}" + ) + os.makedirs(self._cert_folder, exist_ok=True) + + credentials = self._get_credentials(self._project_id) + self._write_b64_cert_to_bytes( + str(credentials["kStore"]), + path=self._get_jks_key_store_path(), + ) + self._write_b64_cert_to_bytes( + str(credentials["tStore"]), + path=self._get_jks_trust_store_path(), + ) + self._cert_key = str(credentials["password"]) + self._cert_key_path = os.path.join(self._cert_folder, "material_passwd") + with open(self._cert_key_path, "w") as f: + f.write(str(credentials["password"])) + + # Return the credentials object for the Python engine to materialize the pem files. + return credentials + + def _validate_spark_configuration(self, _spark_session): + exception_text = "Spark is misconfigured for communication with Hopsworks, missing or invalid property: " + + configuration_dict = { + "spark.hadoop.hops.ssl.trustore.name": None, + "spark.hadoop.hops.rpc.socket.factory.class.default": "io.hops.hadoop.shaded.org.apache.hadoop.net.HopsSSLSocketFactory", + "spark.serializer": "org.apache.spark.serializer.KryoSerializer", + "spark.hadoop.hops.ssl.hostname.verifier": "ALLOW_ALL", + "spark.hadoop.hops.ssl.keystore.name": None, + "spark.hadoop.fs.hopsfs.impl": "io.hops.hopsfs.client.HopsFileSystem", + "spark.hadoop.hops.ssl.keystores.passwd.name": None, + "spark.hadoop.hops.ipc.server.ssl.enabled": "true", + "spark.hadoop.client.rpc.ssl.enabled.protocol": "TLSv1.2", + "spark.hadoop.hive.metastore.uris": None, + "spark.sql.hive.metastore.jars": None, + } + _logger.debug("Configuration dict: %s", configuration_dict) + + for key, value in configuration_dict.items(): + _logger.debug("Validating key: %s", key) + if not ( + _spark_session.conf.get(key, "not_found") != "not_found" + and (value is None or _spark_session.conf.get(key, None) == value) + ): + raise FeatureStoreException(exception_text + key) + + def _close(self): + """Closes a client and deletes certificates.""" + _logger.info("Closing external client and cleaning up certificates.") + if self._cert_folder_base is None: + _logger.debug("No certificates to clean up.") + # On external Spark clients (Databricks, Spark Cluster), + # certificates need to be provided before the Spark application starts. + return + + # Clean up only on AWS + _logger.debug("Cleaning up certificates. AWS only.") + self._cleanup_file(self._get_jks_key_store_path()) + self._cleanup_file(self._get_jks_trust_store_path()) + self._cleanup_file(os.path.join(self._cert_folder, "material_passwd")) + self._cleanup_file(self._get_ca_chain_path()) + self._cleanup_file(self._get_client_cert_path()) + self._cleanup_file(self._get_client_key_path()) + + try: + # delete project level + os.rmdir(self._cert_folder) + # delete host level + os.rmdir(os.path.dirname(self._cert_folder)) + # on AWS base dir will be empty, and can be deleted otherwise raises OSError + os.rmdir(self._cert_folder_base) + except OSError: + pass + self._connected = False + + def _get_jks_trust_store_path(self): + _logger.debug("Getting trust store path: %s", self._trust_store_path) + return self._trust_store_path + + def _get_jks_key_store_path(self): + _logger.debug("Getting key store path: %s", self._key_store_path) + return self._key_store_path + + def _get_ca_chain_path(self) -> str: + path = os.path.join(self._cert_folder, "ca_chain.pem") + _logger.debug(f"Getting ca chain path {path}") + return path + + def _get_client_cert_path(self) -> str: + path = os.path.join(self._cert_folder, "client_cert.pem") + _logger.debug(f"Getting client cert path {path}") + return path + + def _get_client_key_path(self) -> str: + path = os.path.join(self._cert_folder, "client_key.pem") + _logger.debug(f"Getting client key path {path}") + return path + + def _get_secret(self, secrets_store, secret_key=None, api_key_file=None): + """Returns secret value from the AWS Secrets Manager or Parameter Store. + + :param secrets_store: the underlying secrets storage to be used, e.g. `secretsmanager` or `parameterstore` + :type secrets_store: str + :param secret_key: key for the secret value, e.g. `api-key`, `cert-key`, `trust-store`, `key-store`, defaults to None + :type secret_key: str, optional + :param api_key_file: path to a file containing an api key, defaults to None + :type api_key_file: str optional + :raises hsfs.client.exceptions.ExternalClientError: `api_key_file` needs to be set for local mode + :raises hsfs.client.exceptions.UnknownSecretStorageError: Provided secrets storage not supported + :return: secret + :rtype: str + """ + _logger.debug(f"Querying secrets store {secrets_store} for secret {secret_key}") + if secrets_store == self.SECRETS_MANAGER: + return self._query_secrets_manager(secret_key) + elif secrets_store == self.PARAMETER_STORE: + return self._query_parameter_store(secret_key) + elif secrets_store == self.LOCAL_STORE: + if not api_key_file: + raise exceptions.ExternalClientError( + "api_key_file needs to be set for local mode" + ) + _logger.debug(f"Reading api key from {api_key_file}") + with open(api_key_file) as f: + return f.readline().strip() + else: + raise exceptions.UnknownSecretStorageError( + "Secrets storage " + secrets_store + " is not supported." + ) + + def _query_secrets_manager(self, secret_key): + _logger.debug("Querying secrets manager for secret key: %s", secret_key) + secret_name = "hopsworks/role/" + self._assumed_role() + args = {"service_name": "secretsmanager"} + region_name = self._get_region() + if region_name: + args["region_name"] = region_name + client = boto3.client(**args) + get_secret_value_response = client.get_secret_value(SecretId=secret_name) + return json.loads(get_secret_value_response["SecretString"])[secret_key] + + def _assumed_role(self): + _logger.debug("Getting assumed role") + client = boto3.client("sts") + response = client.get_caller_identity() + # arns for assumed roles in SageMaker follow the following schema + # arn:aws:sts::123456789012:assumed-role/my-role-name/my-role-session-name + local_identifier = response["Arn"].split(":")[-1].split("/") + if len(local_identifier) != 3 or local_identifier[0] != "assumed-role": + raise Exception( + "Failed to extract assumed role from arn: " + response["Arn"] + ) + return local_identifier[1] + + def _get_region(self): + if self._region_name != self.DEFAULT_REGION: + _logger.debug(f"Region name is not default, returning {self._region_name}") + return self._region_name + else: + _logger.debug("Region name is default, returning None") + return None + + def _query_parameter_store(self, secret_key): + _logger.debug("Querying parameter store for secret key: %s", secret_key) + args = {"service_name": "ssm"} + region_name = self._get_region() + if region_name: + args["region_name"] = region_name + client = boto3.client(**args) + name = "/hopsworks/role/" + self._assumed_role() + "/type/" + secret_key + return client.get_parameter(Name=name, WithDecryption=True)["Parameter"][ + "Value" + ] + + def _get_project_info(self, project_name): + """Makes a REST call to hopsworks to get all metadata of a project for the provided project. + + :param project_name: the name of the project + :type project_name: str + :return: JSON response with project info + :rtype: dict + """ + _logger.debug("Getting project info for project: %s", project_name) + return self._send_request("GET", ["project", "getProjectInfo", project_name]) + + def _write_b64_cert_to_bytes(self, b64_string, path): + """Converts b64 encoded certificate to bytes file . + + :param b64_string: b64 encoded string of certificate + :type b64_string: str + :param path: path where file is saved, including file name. e.g. /path/key-store.jks + :type path: str + """ + _logger.debug(f"Writing b64 encoded certificate to {path}") + with open(path, "wb") as f: + cert_b64 = base64.b64decode(b64_string) + f.write(cert_b64) + + def _cleanup_file(self, file_path): + """Removes local files with `file_path`.""" + _logger.debug(f"Cleaning up file {file_path}") + try: + os.remove(file_path) + except OSError: + pass + + def replace_public_host(self, url): + """no need to replace as we are already in external client""" + return url + + def _is_external(self) -> bool: + return True + + @property + def host(self) -> str: + return self._host diff --git a/python/hsfs/client/hopsworks.py b/python/hsfs/client/hopsworks.py index 1384b1c20..2134756b1 100644 --- a/python/hsfs/client/hopsworks.py +++ b/python/hsfs/client/hopsworks.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2020 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,173 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.client.external import ( - Client, -) +import os +from pathlib import Path +import requests +from hsfs.client import auth, base -__all__ = [ - Client, -] + +try: + import jks +except ImportError: + pass + + +class Client(base.Client): + REQUESTS_VERIFY = "REQUESTS_VERIFY" + DOMAIN_CA_TRUSTSTORE_PEM = "DOMAIN_CA_TRUSTSTORE_PEM" + PROJECT_ID = "HOPSWORKS_PROJECT_ID" + PROJECT_NAME = "HOPSWORKS_PROJECT_NAME" + HADOOP_USER_NAME = "HADOOP_USER_NAME" + MATERIAL_DIRECTORY = "MATERIAL_DIRECTORY" + HDFS_USER = "HDFS_USER" + T_CERTIFICATE = "t_certificate" + K_CERTIFICATE = "k_certificate" + TRUSTSTORE_SUFFIX = "__tstore.jks" + KEYSTORE_SUFFIX = "__kstore.jks" + PEM_CA_CHAIN = "ca_chain.pem" + CERT_KEY_SUFFIX = "__cert.key" + MATERIAL_PWD = "material_passwd" + SECRETS_DIR = "SECRETS_DIR" + + def __init__(self): + """Initializes a client being run from a job/notebook directly on Hopsworks.""" + self._base_url = self._get_hopsworks_rest_endpoint() + self._host, self._port = self._get_host_port_pair() + self._secrets_dir = ( + os.environ[self.SECRETS_DIR] if self.SECRETS_DIR in os.environ else "" + ) + self._cert_key = self._get_cert_pw() + trust_store_path = self._get_trust_store_path() + hostname_verification = ( + os.environ[self.REQUESTS_VERIFY] + if self.REQUESTS_VERIFY in os.environ + else "true" + ) + self._project_id = os.environ[self.PROJECT_ID] + self._project_name = self._project_name() + try: + self._auth = auth.BearerAuth(self._read_jwt()) + except FileNotFoundError: + self._auth = auth.ApiKeyAuth(self._read_apikey()) + self._verify = self._get_verify(hostname_verification, trust_store_path) + self._session = requests.session() + + self._connected = True + + credentials = self._get_credentials(self._project_id) + + self._write_pem_file(credentials["caChain"], self._get_ca_chain_path()) + self._write_pem_file(credentials["clientCert"], self._get_client_cert_path()) + self._write_pem_file(credentials["clientKey"], self._get_client_key_path()) + + def _get_hopsworks_rest_endpoint(self): + """Get the hopsworks REST endpoint for making requests to the REST API.""" + return os.environ[self.REST_ENDPOINT] + + def _get_trust_store_path(self): + """Convert truststore from jks to pem and return the location""" + ca_chain_path = Path(self.PEM_CA_CHAIN) + if not ca_chain_path.exists(): + ks = jks.KeyStore.load( + self._get_jks_key_store_path(), self._cert_key, try_decrypt_keys=True + ) + ts = jks.KeyStore.load( + self._get_jks_trust_store_path(), self._cert_key, try_decrypt_keys=True + ) + self._write_ca_chain( + ks, + ts, + ca_chain_path, + ) + return str(ca_chain_path) + + def _get_ca_chain_path(self) -> str: + return os.path.join("/tmp", "ca_chain.pem") + + def _get_client_cert_path(self) -> str: + return os.path.join("/tmp", "client_cert.pem") + + def _get_client_key_path(self) -> str: + return os.path.join("/tmp", "client_key.pem") + + def _get_jks_trust_store_path(self): + """ + Get truststore location + + Returns: + truststore location + """ + t_certificate = Path(self.T_CERTIFICATE) + if t_certificate.exists(): + return str(t_certificate) + else: + username = os.environ[self.HADOOP_USER_NAME] + material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) + return str(material_directory.joinpath(username + self.TRUSTSTORE_SUFFIX)) + + def _get_jks_key_store_path(self): + """ + Get keystore location + + Returns: + keystore location + """ + k_certificate = Path(self.K_CERTIFICATE) + if k_certificate.exists(): + return str(k_certificate) + else: + username = os.environ[self.HADOOP_USER_NAME] + material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) + return str(material_directory.joinpath(username + self.KEYSTORE_SUFFIX)) + + def _project_name(self): + try: + return os.environ[self.PROJECT_NAME] + except KeyError: + pass + + hops_user = self._project_user() + hops_user_split = hops_user.split( + "__" + ) # project users have username project__user + project = hops_user_split[0] + return project + + def _project_user(self): + try: + hops_user = os.environ[self.HADOOP_USER_NAME] + except KeyError: + hops_user = os.environ[self.HDFS_USER] + return hops_user + + def _get_cert_pw(self): + """ + Get keystore password from local container + + Returns: + Certificate password + """ + pwd_path = Path(self.MATERIAL_PWD) + if not pwd_path.exists(): + username = os.environ[self.HADOOP_USER_NAME] + material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) + pwd_path = material_directory.joinpath(username + self.CERT_KEY_SUFFIX) + + with pwd_path.open() as f: + return f.read() + + def replace_public_host(self, url): + """replace hostname to public hostname set in HOPSWORKS_PUBLIC_HOST""" + ui_url = url._replace(netloc=os.environ[self.HOPSWORKS_PUBLIC_HOST]) + return ui_url + + def _is_external(self): + return False + + @property + def host(self): + return self._host diff --git a/python/hsfs/client/online_store_rest_client.py b/python/hsfs/client/online_store_rest_client.py index c75be81b7..b733269a1 100644 --- a/python/hsfs/client/online_store_rest_client.py +++ b/python/hsfs/client/online_store_rest_client.py @@ -13,16 +13,372 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.client.online_store_rest_client import ( - OnlineStoreRestClientSingleton, - get_instance, - init_or_reset_online_store_rest_client, -) +import logging +from typing import Any, Dict, List, Optional, Union +from warnings import warn +import requests +import requests.adapters +from furl import furl +from hsfs import client +from hsfs.client.exceptions import FeatureStoreException +from hsfs.core import variable_api -__all__ = [ - OnlineStoreRestClientSingleton, - get_instance, - init_or_reset_online_store_rest_client, -] + +_logger = logging.getLogger(__name__) + +_online_store_rest_client = None + + +def init_or_reset_online_store_rest_client( + transport: Optional[ + Union[requests.adapters.HTTPAdapter, requests.adapters.BaseAdapter] + ] = None, + optional_config: Optional[Dict[str, Any]] = None, + reset_client: bool = False, +): + global _online_store_rest_client + if not _online_store_rest_client: + _online_store_rest_client = OnlineStoreRestClientSingleton( + transport=transport, optional_config=optional_config + ) + elif reset_client: + _online_store_rest_client.reset_client( + transport=transport, optional_config=optional_config + ) + else: + _logger.warning( + "Online Store Rest Client is already initialised. To reset connection or/and override configuration, " + + "use reset_online_store_rest_client flag.", + stacklevel=2, + ) + + +def get_instance() -> OnlineStoreRestClientSingleton: + global _online_store_rest_client + if _online_store_rest_client is None: + _logger.warning( + "Online Store Rest Client is not initialised. Initialising with default configuration." + ) + _online_store_rest_client = OnlineStoreRestClientSingleton() + _logger.debug("Accessing global Online Store Rest Client instance.") + return _online_store_rest_client + + +class OnlineStoreRestClientSingleton: + HOST = "host" + PORT = "port" + VERIFY_CERTS = "verify_certs" + USE_SSL = "use_ssl" + CA_CERTS = "ca_certs" + HTTP_AUTHORIZATION = "http_authorization" + TIMEOUT = "timeout" + SERVER_API_VERSION = "server_api_version" + API_KEY = "api_key" + _DEFAULT_ONLINE_STORE_REST_CLIENT_PORT = 4406 + _DEFAULT_ONLINE_STORE_REST_CLIENT_TIMEOUT_SECOND = 2 + _DEFAULT_ONLINE_STORE_REST_CLIENT_VERIFY_CERTS = True + _DEFAULT_ONLINE_STORE_REST_CLIENT_USE_SSL = True + _DEFAULT_ONLINE_STORE_REST_CLIENT_SERVER_API_VERSION = "0.1.0" + _DEFAULT_ONLINE_STORE_REST_CLIENT_HTTP_AUTHORIZATION = "X-API-KEY" + + def __init__( + self, + transport: Optional[ + Union[requests.adapaters.HTTPadapter, requests.adapters.BaseAdapter] + ] = None, + optional_config: Optional[Dict[str, Any]] = None, + ): + _logger.debug( + f"Initialising Online Store Rest Client {'with optional configuration' if optional_config else ''}." + ) + if optional_config: + _logger.debug(f"Optional Config: {optional_config!r}") + self._check_hopsworks_connection() + self.variable_api = variable_api.VariableApi() + self._auth: client.auth.OnlineStoreKeyAuth + self._session: requests.Session + self._current_config: Dict[str, Any] + self._base_url: furl + self._setup_rest_client( + transport=transport, + optional_config=optional_config, + use_current_config=False, + ) + self.is_connected() + + def reset_client( + self, + transport: Optional[ + Union[requests.adapters.HttpAdapter, requests.adapters.BaseAdapter] + ] = None, + optional_config: Optional[Dict[str, Any]] = None, + ): + _logger.debug( + f"Resetting Online Store Rest Client {'with optional configuration' if optional_config else ''}." + ) + if optional_config: + _logger.debug(f"Optional Config: {optional_config}") + self._check_hopsworks_connection() + if hasattr(self, "_session") and self._session: + _logger.debug("Closing existing session.") + self._session.close() + delattr(self, "_session") + self._setup_rest_client( + transport=transport, + optional_config=optional_config, + use_current_config=False if optional_config else True, + ) + + def _setup_rest_client( + self, + transport: Optional[ + Union[requests.adapters.HttpAdapter, requests.adapters.BaseAdapter] + ] = None, + optional_config: Optional[Dict[str, Any]] = None, + use_current_config: bool = True, + ): + _logger.debug("Setting up Online Store Rest Client.") + if optional_config and not isinstance(optional_config, dict): + raise ValueError( + "optional_config must be a dictionary. See documentation for allowed keys and values." + ) + _logger.debug("Optional Config: %s", optional_config) + if not use_current_config: + _logger.debug( + "Retrieving default configuration for Online Store REST Client." + ) + self._current_config = self._get_default_client_config() + if optional_config: + _logger.debug( + "Updating default configuration with provided optional configuration." + ) + self._current_config.update(optional_config) + + self._set_auth(optional_config) + if not hasattr(self, "_session") or not self._session: + _logger.debug("Initialising new requests session.") + self._session = requests.Session() + else: + raise ValueError( + "Use the init_or_reset_online_store_connection method with reset_connection flag set " + + "to True to reset the online_store_client_connection" + ) + if transport is not None: + _logger.debug("Setting custom transport adapter.") + self._session.mount("https://", transport) + self._session.mount("http://", transport) + + if not self._current_config[self.VERIFY_CERTS]: + _logger.warning( + "Disabling SSL certificate verification. This is not recommended for production environments." + ) + self._session.verify = False + else: + _logger.debug( + f"Setting SSL certificate verification using CA Certs path: {self._current_config[self.CA_CERTS]}" + ) + self._session.verify = self._current_config[self.CA_CERTS] + + # Set base_url + scheme = "https" if self._current_config[self.USE_SSL] else "http" + self._base_url = furl( + f"{scheme}://{self._current_config[self.HOST]}:{self._current_config[self.PORT]}/{self._current_config[self.SERVER_API_VERSION]}" + ) + + assert ( + self._session is not None + ), "Online Store REST Client failed to initialise." + assert ( + self._auth is not None + ), "Online Store REST Client Authentication failed to initialise. Check API Key." + assert ( + self._base_url is not None + ), "Online Store REST Client Base URL failed to initialise. Check host and port parameters." + assert ( + self._current_config is not None + ), "Online Store REST Client Configuration failed to initialise." + + def _get_default_client_config(self) -> Dict[str, Any]: + _logger.debug("Retrieving default configuration for Online Store REST Client.") + default_config = self._get_default_static_parameters_config() + default_config.update(self._get_default_dynamic_parameters_config()) + return default_config + + def _get_default_static_parameters_config(self) -> Dict[str, Any]: + _logger.debug( + "Retrieving default static configuration for Online Store REST Client." + ) + return { + self.TIMEOUT: self._DEFAULT_ONLINE_STORE_REST_CLIENT_TIMEOUT_SECOND, + self.VERIFY_CERTS: self._DEFAULT_ONLINE_STORE_REST_CLIENT_VERIFY_CERTS, + self.USE_SSL: self._DEFAULT_ONLINE_STORE_REST_CLIENT_USE_SSL, + self.SERVER_API_VERSION: self._DEFAULT_ONLINE_STORE_REST_CLIENT_SERVER_API_VERSION, + self.HTTP_AUTHORIZATION: self._DEFAULT_ONLINE_STORE_REST_CLIENT_HTTP_AUTHORIZATION, + } + + def _get_default_dynamic_parameters_config( + self, + ) -> Dict[str, Any]: + _logger.debug( + "Retrieving default dynamic configuration for Online Store REST Client." + ) + url = furl(self._get_rondb_rest_server_endpoint()) + _logger.debug(f"Default RonDB Rest Server host and port: {url.host}:{url.port}") + _logger.debug( + f"Using CA Certs from Hopsworks Client: {client.get_instance()._get_ca_chain_path()}" + ) + return { + self.HOST: url.host, + self.PORT: url.port, + self.CA_CERTS: client.get_instance()._get_ca_chain_path(), + } + + def _get_rondb_rest_server_endpoint(self) -> str: + """Retrieve RonDB Rest Server endpoint based on whether the client is running internally or externally. + + If the client is running externally, the endpoint is retrieved via the loadbalancer. + If the client is running internally, the endpoint is retrieved via (consul) service discovery. + The default port for the RonDB Rest Server is 4406 and always used unless specifying a different port + in the configuration. + + Returns: + str: RonDB Rest Server endpoint with default port. + """ + if client.get_instance()._is_external(): + _logger.debug( + "External Online Store REST Client : Retrieving RonDB Rest Server endpoint via loadbalancer." + ) + external_domain = self.variable_api.get_loadbalancer_external_domain() + if external_domain == "": + _logger.debug( + "External Online Store REST Client : Loadbalancer external domain is not set. Using client host as endpoint." + ) + external_domain = client.get_instance().host + default_url = f"https://{external_domain}:{self._DEFAULT_ONLINE_STORE_REST_CLIENT_PORT}" + _logger.debug( + f"External Online Store REST Client : Default RonDB Rest Server endpoint: {default_url}" + ) + return default_url + else: + _logger.debug( + "Internal Online Store REST Client : Retrieving RonDB Rest Server endpoint via service discovery." + ) + service_discovery_domain = self.variable_api.get_service_discovery_domain() + if service_discovery_domain == "": + raise FeatureStoreException("Service discovery domain is not set.") + default_url = f"https://rdrs.service.{service_discovery_domain}:{self._DEFAULT_ONLINE_STORE_REST_CLIENT_PORT}" + _logger.debug( + f"Internal Online Store REST Client : Default RonDB Rest Server endpoint: {default_url}" + ) + return default_url + + def send_request( + self, + method: str, + path_params: List[str], + headers: Optional[Dict[str, Any]] = None, + data: Optional[str] = None, + ) -> requests.Response: + url = self._base_url.copy() + url.path.segments.extend(path_params) + _logger.debug(f"Sending {method} request to {url.url}.") + _logger.debug(f"Provided Data: {data}") + _logger.debug(f"Provided Headers: {headers}") + prepped_request = self._session.prepare_request( + requests.Request( + method, url=url.url, headers=headers, data=data, auth=self.auth + ) + ) + timeout = self._current_config[self.TIMEOUT] + return self._session.send( + prepped_request, + # compatibility with 3.7 + timeout=timeout if timeout < 500 else timeout / 1000, + ) + + def _check_hopsworks_connection(self) -> None: + _logger.debug("Checking Hopsworks connection.") + assert ( + client.get_instance() is not None and client.get_instance()._connected + ), """Hopsworks Client is not connected. Please connect to Hopsworks cluster + via hopsworks.login or hsfs.connection before initialising the Online Store REST Client. + """ + _logger.debug("Hopsworks connection is active.") + + def _set_auth(self, optional_config: Optional[Dict[str, Any]] = None) -> None: + """Set authentication object for the Online Store REST Client. + + RonDB Rest Server uses Hopsworks Api Key to authenticate requests via the X-API-KEY header by default. + The api key determines the permissions of the user making the request for access to a given Feature Store. + """ + _logger.debug("Setting authentication for Online Store REST Client.") + if client.get_instance()._is_external(): + assert hasattr( + client.get_instance()._auth, "_token" + ), "External client must use API Key authentication. Contact your system administrator." + _logger.debug( + "External Online Store REST Client : Setting authentication using Hopsworks Client API Key." + ) + self._auth = client.auth.OnlineStoreKeyAuth( + client.get_instance()._auth._token + ) + elif isinstance(optional_config, dict) and optional_config.get( + self.API_KEY, False + ): + _logger.debug( + "Setting authentication using provided API Key from optional configuration." + ) + self._auth = client.auth.OnlineStoreKeyAuth(optional_config[self.API_KEY]) + elif hasattr(self, "_auth") and self._auth is not None: + _logger.debug( + "Authentication for Online Store REST Client is already set. Using existing authentication api key." + ) + else: + raise FeatureStoreException( + "RonDB Rest Server uses Hopsworks Api Key to authenticate request." + + f"Provide a configuration with the {self.API_KEY} key." + ) + + def is_connected(self): + """If Online Store Rest Client is initialised, ping RonDB Rest Server to ensure connection is active.""" + if self._session is None: + _logger.debug( + "Checking Online Store REST Client is connected. Session is not initialised." + ) + raise FeatureStoreException("Online Store REST Client is not initialised.") + + _logger.debug( + "Checking Online Store REST Client is connected. Pinging RonDB Rest Server." + ) + if not self.send_request("GET", ["ping"]): + warn("Ping failed, RonDB Rest Server is not reachable.", stacklevel=2) + return False + return True + + @property + def session(self) -> requests.Session: + """Requests session object used to send requests to the Online Store REST API.""" + return self._session + + @property + def base_url(self) -> furl: + """Base URL for the Online Store REST API. + + This the url of the RonDB REST Server and should not be confused with the Opensearch Vector DB which also serves as an Online Store for features belonging to Feature Group containing embeddings.""" + return self._base_url + + @property + def current_config(self) -> Dict[str, Any]: + """Current configuration of the Online Store REST Client.""" + return self._current_config + + @property + def auth(self) -> "client.auth.OnlineStoreKeyAuth": + """Authentication object used to authenticate requests to the Online Store REST API. + + Extends the requests.auth.AuthBase class. + """ + return self._auth diff --git a/python/hsfs/core/constants.py b/python/hsfs/core/constants.py index a9bc0b1df..d6af38018 100644 --- a/python/hsfs/core/constants.py +++ b/python/hsfs/core/constants.py @@ -1,40 +1,35 @@ -# -# Copyright 2024 Hopsworks AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# +import importlib.util -from hopsworks_common.core.constants import ( - HAS_AIOMYSQL, - HAS_AVRO, - HAS_CONFLUENT_KAFKA, - HAS_FAST_AVRO, - HAS_GREAT_EXPECTATIONS, - HAS_NUMPY, - HAS_SQLALCHEMY, - great_expectations_not_installed_message, - initialise_expectation_suite_for_single_expectation_api_message, + +# Avro +HAS_FAST_AVRO: bool = importlib.util.find_spec("fastavro") is not None +HAS_AVRO: bool = importlib.util.find_spec("avro") is not None + +# Confluent Kafka +HAS_CONFLUENT_KAFKA: bool = importlib.util.find_spec("confluent_kafka") is not None +confluent_kafka_not_installed_message = ( + "Confluent Kafka package not found. " + "If you want to use Kafka with Hopsworks you can install the corresponding extras " + """`pip install hopsworks[python]` or `pip install "hopsworks[python]"` if using zsh. """ + "You can also install confluent-kafka directly in your environment e.g `pip install confluent-kafka`. " + "You will need to restart your kernel if applicable." +) +# Data Validation / Great Expectations +HAS_GREAT_EXPECTATIONS: bool = ( + importlib.util.find_spec("great_expectations") is not None +) +great_expectations_not_installed_message = ( + "Great Expectations package not found. " + "If you want to use data validation with Hopsworks you can install the corresponding extras " + """`pip install hopsworks[great_expectations]` or `pip install "hopsworks[great_expectations]"` if using zsh. """ + "You can also install great-expectations directly in your environment e.g `pip install great-expectations`. " + "You will need to restart your kernel if applicable." ) +initialise_expectation_suite_for_single_expectation_api_message = "Initialize Expectation Suite by attaching to a Feature Group to enable single expectation API" +# Numpy +HAS_NUMPY: bool = importlib.util.find_spec("numpy") is not None -__all__ = [ - HAS_AIOMYSQL, - HAS_AVRO, - HAS_CONFLUENT_KAFKA, - HAS_FAST_AVRO, - HAS_GREAT_EXPECTATIONS, - HAS_NUMPY, - HAS_SQLALCHEMY, - great_expectations_not_installed_message, - initialise_expectation_suite_for_single_expectation_api_message, -] +# SQL packages +HAS_SQLALCHEMY: bool = importlib.util.find_spec("sqlalchemy") is not None +HAS_AIOMYSQL: bool = importlib.util.find_spec("aiomysql") is not None diff --git a/python/hsfs/core/variable_api.py b/python/hsfs/core/variable_api.py index 9d6e9765f..b499bd9b4 100644 --- a/python/hsfs/core/variable_api.py +++ b/python/hsfs/core/variable_api.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2022 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,66 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.core.variable_api import ( - VariableApi, -) +import re +from hsfs import client +from hsfs.client.exceptions import RestAPIError -__all__ = [ - VariableApi, -] + +class VariableApi: + def get_version(self, software: str): + _client = client.get_instance() + path_params = [ + "variables", + "versions", + ] + + resp = _client._send_request("GET", path_params) + for entry in resp: + if entry["software"] == software: + return entry["version"] + return None + + def parse_major_and_minor(self, backend_version): + version_pattern = r"(\d+)\.(\d+)" + matches = re.match(version_pattern, backend_version) + + return matches.group(1), matches.group(2) + + def get_flyingduck_enabled(self): + _client = client.get_instance() + path_params = [ + "variables", + "enable_flyingduck", + ] + + resp = _client._send_request("GET", path_params) + return resp["successMessage"] == "true" + + def get_loadbalancer_external_domain(self): + _client = client.get_instance() + path_params = [ + "variables", + "loadbalancer_external_domain", + ] + + try: + resp = _client._send_request("GET", path_params) + return resp["successMessage"] + except RestAPIError: + return "" + + def get_service_discovery_domain(self): + _client = client.get_instance() + path_params = [ + "variables", + "service_discovery_domain", + ] + + try: + resp = _client._send_request("GET", path_params) + return resp["successMessage"] + except RestAPIError: + return "" diff --git a/python/hsfs/decorators.py b/python/hsfs/decorators.py index 1165a2daa..3ce15277f 100644 --- a/python/hsfs/decorators.py +++ b/python/hsfs/decorators.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2020 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,22 +13,73 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations -from hopsworks_common.decorators import ( - HopsworksConnectionError, - NoHopsworksConnectionError, - connected, - not_connected, - typechecked, - uses_great_expectations, +import functools +import os + +from hsfs.core.constants import ( + HAS_GREAT_EXPECTATIONS, + great_expectations_not_installed_message, ) -__all__ = [ - HopsworksConnectionError, - NoHopsworksConnectionError, - connected, - not_connected, - typechecked, - uses_great_expectations, -] +def not_connected(fn): + @functools.wraps(fn) + def if_not_connected(inst, *args, **kwargs): + if inst._connected: + raise HopsworksConnectionError + return fn(inst, *args, **kwargs) + + return if_not_connected + + +def connected(fn): + @functools.wraps(fn) + def if_connected(inst, *args, **kwargs): + if not inst._connected: + raise NoHopsworksConnectionError + return fn(inst, *args, **kwargs) + + return if_connected + + +class HopsworksConnectionError(Exception): + """Thrown when attempted to change connection attributes while connected.""" + + def __init__(self): + super().__init__( + "Connection is currently in use. Needs to be closed for modification." + ) + + +class NoHopsworksConnectionError(Exception): + """Thrown when attempted to perform operation on connection while not connected.""" + + def __init__(self): + super().__init__( + "Connection is not active. Needs to be connected for feature store operations." + ) + + +if os.environ.get("HOPSWORKS_RUN_WITH_TYPECHECK", False): + from typeguard import typechecked +else: + from typing import TypeVar + + _T = TypeVar("_T") + + def typechecked( + target: _T, + ) -> _T: + return target if target else typechecked + + +def uses_great_expectations(f): + @functools.wraps(f) + def g(*args, **kwds): + if not HAS_GREAT_EXPECTATIONS: + raise ModuleNotFoundError(great_expectations_not_installed_message) + return f(*args, **kwds) + + return g diff --git a/python/hsml/decorators.py b/python/hsml/decorators.py index 1165a2daa..826fd5aa2 100644 --- a/python/hsml/decorators.py +++ b/python/hsml/decorators.py @@ -1,5 +1,5 @@ # -# Copyright 2024 Hopsworks AB +# Copyright 2021 Logical Clocks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,21 +14,42 @@ # limitations under the License. # -from hopsworks_common.decorators import ( - HopsworksConnectionError, - NoHopsworksConnectionError, - connected, - not_connected, - typechecked, - uses_great_expectations, -) - - -__all__ = [ - HopsworksConnectionError, - NoHopsworksConnectionError, - connected, - not_connected, - typechecked, - uses_great_expectations, -] +import functools + + +def not_connected(fn): + @functools.wraps(fn) + def if_not_connected(inst, *args, **kwargs): + if inst._connected: + raise HopsworksConnectionError + return fn(inst, *args, **kwargs) + + return if_not_connected + + +def connected(fn): + @functools.wraps(fn) + def if_connected(inst, *args, **kwargs): + if not inst._connected: + raise NoHopsworksConnectionError + return fn(inst, *args, **kwargs) + + return if_connected + + +class HopsworksConnectionError(Exception): + """Thrown when attempted to change connection attributes while connected.""" + + def __init__(self): + super().__init__( + "Connection is currently in use. Needs to be closed for modification." + ) + + +class NoHopsworksConnectionError(Exception): + """Thrown when attempted to perform operation on connection while not connected.""" + + def __init__(self): + super().__init__( + "Connection is not active. Needs to be connected for model registry operations." + ) diff --git a/python/tests/core/test_online_store_rest_client.py b/python/tests/core/test_online_store_rest_client.py index 39ed1f640..90d368dfd 100644 --- a/python/tests/core/test_online_store_rest_client.py +++ b/python/tests/core/test_online_store_rest_client.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import hopsworks_common +import hsfs import pytest from furl import furl -from hopsworks_common.client import auth, exceptions, online_store_rest_client +from hsfs.client import auth, exceptions, online_store_rest_client class MockExternalClient: @@ -50,15 +50,13 @@ def test_setup_rest_client_external(self, mocker, monkeypatch): def client_get_instance(): return MockExternalClient() - monkeypatch.setattr( - hopsworks_common.client, "get_instance", client_get_instance - ) + monkeypatch.setattr(hsfs.client, "get_instance", client_get_instance) variable_api_mock = mocker.patch( - "hopsworks_common.core.variable_api.VariableApi.get_loadbalancer_external_domain", + "hsfs.core.variable_api.VariableApi.get_loadbalancer_external_domain", return_value="app.hopsworks.ai", ) ping_rdrs_mock = mocker.patch( - "hopsworks_common.client.online_store_rest_client.OnlineStoreRestClientSingleton.is_connected", + "hsfs.client.online_store_rest_client.OnlineStoreRestClientSingleton.is_connected", ) # Act @@ -88,16 +86,14 @@ def test_setup_online_store_rest_client_internal(self, mocker, monkeypatch): def client_get_instance(): return MockInternalClient() - monkeypatch.setattr( - hopsworks_common.client, "get_instance", client_get_instance - ) + monkeypatch.setattr(hsfs.client, "get_instance", client_get_instance) variable_api_mock = mocker.patch( - "hopsworks_common.core.variable_api.VariableApi.get_service_discovery_domain", + "hsfs.core.variable_api.VariableApi.get_service_discovery_domain", return_value="consul", ) optional_config = {"api_key": "provided_api_key"} ping_rdrs_mock = mocker.patch( - "hopsworks_common.client.online_store_rest_client.OnlineStoreRestClientSingleton.is_connected", + "hsfs.client.online_store_rest_client.OnlineStoreRestClientSingleton.is_connected", ) # Act