Databricks Connect and Spark session best practice

thibault
Contributor III

Hi all!

I am using databricks-connect to develop and test PySpark code in pure Python (not notebook) files in my local IDE, executing against a Databricks cluster. These files are part of a dbx deployment setup, so they run as tasks in a workflow.

Everything works fine, except for one piece: deciding whether to create a Databricks Connect Spark session or to reuse the Spark session that already exists when the code runs on Databricks as part of a job:

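try:
    # Local IDE: databricks-connect is installed, so build a remote session
    from databricks.connect import DatabricksSession
    spark = DatabricksSession.builder.getOrCreate()
except ImportError:
    # On the cluster (job run): databricks-connect is assumed absent, so fall
    # back to the regular SparkSession, which returns the job's existing session
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.getOrCreate()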

That feels like a code smell. Is there a nicer way you would recommend to handle the Spark session, whether running locally via databricks-connect or directly on Databricks?
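One way to contain the smell, offered as a sketch rather than an official recommendation: keep the check in a single shared helper so task files never repeat it. The names `has_databricks_connect` and `get_spark` are hypothetical. Instead of relying on a failed import, this variant asks `importlib.util.find_spec` whether the `databricks.connect` package is installed, and it only imports PySpark when a session is actually requested. Like the try/except version, it assumes databricks-connect is installed locally but not on the job cluster.

```python
import importlib.util


def has_databricks_connect() -> bool:
    """Return True if the databricks.connect module can be found.

    Assumption: databricks-connect is installed in the local dev environment
    but not on the job cluster.
    """
    # find_spec on a dotted name imports the parent package ("databricks")
    # first, which raises ModuleNotFoundError if that parent is missing.
    try:
        return importlib.util.find_spec("databricks.connect") is not None
    except ModuleNotFoundError:
        return False


def get_spark():
    """Hypothetical helper: return a Spark session for the current environment."""
    if has_databricks_connect():
        # Local IDE: build a remote session through Databricks Connect
        from databricks.connect import DatabricksSession
        return DatabricksSession.builder.getOrCreate()
    # On the cluster: getOrCreate() returns the session the job already has
    from pyspark.sql import SparkSession
    return SparkSession.builder.getOrCreate()
```

Each task file would then start with `spark = get_spark()`, and the environment check lives in one place.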

thibault
Contributor III

Thanks for your response, @Retired_mod. Can you elaborate on the difference between your suggestion and the code I provided? That is, what would your if-else look like?

bartoszmalec
New Contributor II

By using the following command:

sc.version

you can get the Spark version and branch on it. Spark Connect is available since Spark 3.4, so you can apply custom logic based on the Spark version to decide whether to use Spark Connect or a regular Spark session.
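A minimal sketch of that version check. One caveat: `sc` is only predefined in notebooks, and a SparkContext is not available over Spark Connect, so this assumes you read the installed client version from `pyspark.__version__`, which needs no session. The helper name `supports_spark_connect` is hypothetical. Note that the version only tells you whether Spark Connect is possible, not whether the code is running locally.

```python
def supports_spark_connect(version: str) -> bool:
    """Return True if a Spark version string is 3.4 or later.

    Spark Connect was introduced in Spark 3.4. Only the major and minor
    components are compared, so suffixes such as "3.5.0.dev0" are tolerated.
    """
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= (3, 4)
```

Usage would be along the lines of `import pyspark` followed by `supports_spark_connect(pyspark.__version__)`.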

FedeRaimondi
Contributor II

I personally do something similar by checking an environment variable. The example below was written for a notebook, but it should work in a Python file as well:

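import os

if not os.environ.get("DATABRICKS_RUNTIME_VERSION"):
    # DATABRICKS_RUNTIME_VERSION is only set on a Databricks cluster,
    # so we are running locally: connect through Databricks Connect
    from databricks.connect import DatabricksSession
    print("This code is running outside of Databricks. Using Databricks Connect...")
    spark = DatabricksSession.builder.getOrCreate()
else:
    # Running on Databricks: `spark` is predefined only in notebooks, so a
    # Python file must call getOrCreate(), which returns the existing session
    from pyspark.sql import SparkSession
    print("This code is running inside Databricks. Using the default Spark session...")
    spark = SparkSession.builder.getOrCreate()

print("Runtime:", spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion"))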
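The same environment-variable check can be pulled into one helper that every task file shares. This is a sketch: `is_running_on_databricks` and `get_spark` are hypothetical names, and the optional `environ` parameter exists only so the detection logic can be exercised without a cluster.

```python
import os
from typing import Mapping, Optional


def is_running_on_databricks(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Return True when DATABRICKS_RUNTIME_VERSION is set to a non-empty value."""
    env = os.environ if environ is None else environ
    return bool(env.get("DATABRICKS_RUNTIME_VERSION"))


def get_spark(environ: Optional[Mapping[str, str]] = None):
    """Return the cluster session on Databricks, a Databricks Connect session locally."""
    if is_running_on_databricks(environ):
        # On the cluster: getOrCreate() hands back the existing session
        from pyspark.sql import SparkSession
        return SparkSession.builder.getOrCreate()
    # Local IDE: connect to the remote cluster through Databricks Connect
    from databricks.connect import DatabricksSession
    return DatabricksSession.builder.getOrCreate()
```

Compared with the try/except approach, this keys off where the code runs rather than which packages happen to be installed, so it keeps working even if databricks-connect ends up installed on the cluster.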