diff --git a/python/hopsworks_common/connection.py b/python/hopsworks_common/connection.py
index ea954b79a..552fbd132 100644
--- a/python/hopsworks_common/connection.py
+++ b/python/hopsworks_common/connection.py
@@ -24,7 +24,7 @@
 import weakref
 from typing import Any, Optional
 
-from hopsworks_common import client, usage, util, version
+from hopsworks_common import client, constants, usage, util, version
 from hopsworks_common.core import (
     hosts_api,
     project_api,
@@ -99,8 +99,8 @@ class Connection:
             defaults to the project from where the client is run from.
             Defaults to `None`.
         engine: Specifies the engine to use. Possible options are "spark", "python", "training", "spark-no-metastore", or "spark-delta". The default value is None, which automatically selects the engine based on the environment:
-            "spark": Used if Spark is available, such as in Hopsworks or Databricks environments.
-            "python": Used in local Python environments or AWS SageMaker when Spark is not available.
+            "spark": Used if Spark is available and the connection is not to serverless Hopsworks, such as in Hopsworks or Databricks environments.
+            "python": Used in local Python environments or AWS SageMaker when Spark is not available or when connecting to serverless Hopsworks.
             "training": Used when only feature store metadata is needed, such as for obtaining training dataset locations and label information during Hopsworks training experiments.
             "spark-no-metastore": Functions like "spark" but does not rely on the Hive metastore.
             "spark-delta": Minimizes dependencies further by avoiding both Hive metastore and HopsFS.
@@ -337,26 +337,26 @@ def connect(self) -> None:
         self._connected = True
         finalizer = weakref.finalize(self, self.close)
         try:
+            internal = client.base.Client.REST_ENDPOINT in os.environ
+            serverless = self._host == constants.HOSTS.APP_HOST
             # determine engine, needed to init client
-            if (self._engine is not None and self._engine.lower() == "spark") or (
-                self._engine is None and importlib.util.find_spec("pyspark")
+            if (
+                self._engine is None
+                and importlib.util.find_spec("pyspark")
+                and (internal or not serverless)
             ):
                 self._engine = "spark"
-            elif (self._engine is not None and self._engine.lower() == "python") or (
-                self._engine is None and not importlib.util.find_spec("pyspark")
-            ):
+            elif self._engine is None:
+                self._engine = "python"
+            elif self._engine.lower() == "spark":
+                self._engine = "spark"
+            elif self._engine.lower() == "python":
                 self._engine = "python"
-            elif self._engine is not None and self._engine.lower() == "training":
+            elif self._engine.lower() == "training":
                 self._engine = "training"
-            elif (
-                self._engine is not None
-                and self._engine.lower() == "spark-no-metastore"
-            ):
+            elif self._engine.lower() == "spark-no-metastore":
                 self._engine = "spark-no-metastore"
-            elif (
-                self._engine is not None
-                and self._engine.lower() == "spark-delta"
-            ):
+            elif self._engine.lower() == "spark-delta":
                 self._engine = "spark-delta"
             else:
                 raise ConnectionError(
@@ -365,7 +365,12 @@ def connect(self) -> None:
             )
 
         # init client
-        if client.base.Client.REST_ENDPOINT not in os.environ:
+        if internal:
+            client.init(
+                "hopsworks",
+                hostname_verification=self._hostname_verification,
+            )
+        else:
             client.init(
                 "external",
                 self._host,
@@ -378,11 +383,6 @@ def connect(self) -> None:
                 self._api_key_file,
                 self._api_key_value,
             )
-        else:
-            client.init(
-                "hopsworks",
-                hostname_verification=self._hostname_verification,
-            )
         client.set_connection(self)
 
         from hsfs.core import feature_store_api
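
For reviewers, here is a minimal standalone sketch of the engine auto-selection that the rewritten `connect()` implements. It is illustrative only and not part of the patch: `pick_engine`, `pyspark_available`, and the hardcoded `app_host` default are hypothetical stand-ins for the real `importlib.util.find_spec("pyspark")` check, the `REST_ENDPOINT` environment check, and `constants.HOSTS.APP_HOST`.

```python
from typing import Optional


def pick_engine(
    engine: Optional[str],
    pyspark_available: bool,
    internal: bool,
    host: str,
    app_host: str = "c.app.hopsworks.ai",  # assumed value of constants.HOSTS.APP_HOST
) -> str:
    # True when the client is pointed at serverless Hopsworks.
    serverless = host == app_host
    # Auto-selection: Spark only when it is installed and we are either
    # running inside Hopsworks or talking to a non-serverless cluster.
    if engine is None and pyspark_available and (internal or not serverless):
        return "spark"
    if engine is None:
        return "python"
    # Explicitly requested engines are normalized via a lowercase comparison.
    normalized = engine.lower()
    if normalized in ("spark", "python", "training", "spark-no-metastore", "spark-delta"):
        return normalized
    raise ConnectionError(f"Unsupported engine: {engine}")


# An external client with pyspark installed, pointed at serverless Hopsworks,
# now auto-selects the Python engine instead of Spark; any other host still
# resolves to Spark when pyspark is available.
assert pick_engine(None, pyspark_available=True, internal=False, host="c.app.hopsworks.ai") == "python"
assert pick_engine(None, pyspark_available=True, internal=False, host="my-cluster.example.com") == "spark"
```

Handling the `engine is None` cases first is what lets the diff collapse each explicit-engine branch to a single lowercase comparison, replacing the repeated `self._engine is not None and ...` guards of the old chain.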