Skip to content

Commit

Permalink
Fix engine choice in case of connection to serverless
Browse files Browse the repository at this point in the history
  • Loading branch information
aversey committed Dec 5, 2024
1 parent 47be364 commit 58445f6
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions python/hopsworks_common/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import weakref
from typing import Any, Optional

from hopsworks_common import client, usage, util, version
from hopsworks_common import client, constants, usage, util, version
from hopsworks_common.core import (
hosts_api,
project_api,
Expand Down Expand Up @@ -99,8 +99,8 @@ class Connection:
defaults to the project from where the client is run from.
Defaults to `None`.
engine: Specifies the engine to use. Possible options are "spark", "python", "training", "spark-no-metastore", or "spark-delta". The default value is None, which automatically selects the engine based on the environment:
"spark": Used if Spark is available, such as in Hopsworks or Databricks environments.
"python": Used in local Python environments or AWS SageMaker when Spark is not available.
"spark": Used if Spark is available and the connection is not to serverless Hopsworks, such as in Hopsworks or Databricks environments.
"python": Used in local Python environments or AWS SageMaker when Spark is not available or the connection is made to serverless Hopsworks.
"training": Used when only feature store metadata is needed, such as for obtaining training dataset locations and label information during Hopsworks training experiments.
"spark-no-metastore": Functions like "spark" but does not rely on the Hive metastore.
"spark-delta": Minimizes dependencies further by avoiding both Hive metastore and HopsFS.
Expand Down Expand Up @@ -337,26 +337,26 @@ def connect(self) -> None:
self._connected = True
finalizer = weakref.finalize(self, self.close)
try:
internal = client.base.Client.REST_ENDPOINT in os.environ
serverless = self._host != constants.HOSTS.APP_HOST
# determine engine, needed to init client
if (self._engine is not None and self._engine.lower() == "spark") or (
self._engine is None and importlib.util.find_spec("pyspark")
if (
self._engine is None
and importlib.util.find_spec("pyspark")
and (internal or serverless)
):
self._engine = "spark"
elif (self._engine is not None and self._engine.lower() == "python") or (
self._engine is None and not importlib.util.find_spec("pyspark")
):
elif self._engine is None:
self._engine = "python"
elif self._engine.lower() == "spark":
self._engine = "spark"
elif self._engine.lower() == "python":
self._engine = "python"
elif self._engine is not None and self._engine.lower() == "training":
elif self._engine.lower() == "training":
self._engine = "training"
elif (
self._engine is not None
and self._engine.lower() == "spark-no-metastore"
):
elif self._engine.lower() == "spark-no-metastore":
self._engine = "spark-no-metastore"
elif (
self._engine is not None
and self._engine.lower() == "spark-delta"
):
elif self._engine.lower() == "spark-delta":
self._engine = "spark-delta"
else:
raise ConnectionError(
Expand All @@ -365,7 +365,12 @@ def connect(self) -> None:
)

# init client
if client.base.Client.REST_ENDPOINT not in os.environ:
if internal:
client.init(
"hopsworks",
hostname_verification=self._hostname_verification,
)
else:
client.init(
"external",
self._host,
Expand All @@ -378,11 +383,6 @@ def connect(self) -> None:
self._api_key_file,
self._api_key_value,
)
else:
client.init(
"hopsworks",
hostname_verification=self._hostname_verification,
)

client.set_connection(self)
from hsfs.core import feature_store_api
Expand Down

0 comments on commit 58445f6

Please sign in to comment.