Machine Learning
Dive into the world of machine learning on the Databricks platform. Explore discussions on algorithms, model training, deployment, and more. Connect with ML enthusiasts and experts.
UDF LLM Databricks pickle error

llmnerd
New Contributor

Hi there,

I am trying to parallelize a text extraction task using a Databricks foundation model.

Any pointers, suggestions, or examples are welcome.

The code and the error are below.

from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, BooleanType
from langchain_community.chat_models import ChatDatabricks  # or: from databricks_langchain import ChatDatabricks

model = "databricks-meta-llama-3-1-70b-instruct"
temperature = 0.0
max_tokens = 1024

# Spark return type for the UDF
schema_llm = StructType([
    StructField("contains_vulnerability", BooleanType(), True),
])

chat_model = ChatDatabricks(
    endpoint=model,
    temperature=temperature,
    max_tokens=max_tokens,
)

# chat_prompt and VulnerabilityReport (the structured-output schema) are defined earlier (not shown)
chain_llm = chat_prompt | chat_model.with_structured_output(VulnerabilityReport)

@udf(returnType=schema_llm)
def CheckContent(text: str):
    out = chain_llm.invoke({"content": text})
    return out["contains_vulnerability"]

# sample_df with a file_content column is defined earlier (not shown)
expand_df = sample_df.withColumn("content_check", CheckContent("file_content"))
display(expand_df)

And I am getting a pickle error:

Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/serializers.py", line 559, in dumps
    return cloudpickle.dumps(obj, pickle_protocol)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/databricks/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 632, in dump
    return Pickler.dump(self, obj)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/core/context.py", line 525, in __getnewargs__
    raise PySparkRuntimeError(
pyspark.errors.exceptions.base.PySparkRuntimeError: [CONTEXT_ONLY_VALID_ON_DRIVER] It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

 


Ayushi_Suthar
Databricks Employee

Hi @llmnerd, hope you are doing well!

Upon reviewing the details provided, we have a few observations regarding the SparkContext serialization error. Please find our analysis and recommendations below:
==== Analysis ====
Error Encountered: The error indicates that the SparkContext object could not be serialized. This typically occurs when SparkContext is referenced from a broadcast variable, action, or transformation; SparkContext can only be used on the driver, not on the worker nodes.
Analysis of the Problematic Code:

1. Broadcast Variable Initialization: broadcast_var = spark.sparkContext.broadcast([cloudTrailSchema, parquetOutputPath])
This line broadcasts cloudTrailSchema and parquetOutputPath to all worker nodes, which is a valid approach for making configuration data available cluster-wide.

2. RDD Creation: rdd = spark.sparkContext.parallelize([cloudTrailSchema, parquetOutputPath])
Here the intent seems to be to distribute these objects for parallel processing, which is conceptually incorrect. Instead, create an RDD of the actual file paths: rdd = spark.sparkContext.parallelize(file_paths)

3. Data Processing: result = rdd.mapPartitions(process_partition).collect()
This line processes the RDD created in step 2 using mapPartitions, which applies a function to each partition of the RDD. The function process_partition then attempts to process the data using the broadcast variable.

Proposed Correction:
def process_partition(iterator):
    # Look up the broadcast configuration on the worker
    broadcasted_values = broadcast_var.value
    schema, output_path = broadcasted_values
    for file_path in iterator:
        # Yield results so mapPartitions receives an iterable back
        yield process_file(file_path, schema, output_path)
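For completeness, a rough sketch of the driver-side setup that pairs with this corrected function (assuming, as above, that file_paths is the list of input files and process_file is your existing per-file routine):

# Driver side: broadcast the configuration and distribute only the file paths
broadcast_var = spark.sparkContext.broadcast([cloudTrailSchema, parquetOutputPath])
file_paths = [...]  # the actual list of input file paths
rdd = spark.sparkContext.parallelize(file_paths)

# Each partition of file paths is handled by process_partition on the workers
result = rdd.mapPartitions(process_partition).collect()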

However, process_partition calls process_file, in which the SparkSession (and therefore the SparkContext) is used:

df = spark.read.schema(cloudTrailSchema).json(file_path)

This is not a valid approach: SparkContext can only be used on the driver node and cannot be serialized or accessed on the worker nodes.

==== Root Cause Analysis ====
Referencing SparkContext within actions or transformations leads to serialization errors, as these operations execute on worker nodes where SparkContext is unavailable.
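To illustrate the failure mode, here is a minimal, hypothetical example that triggers the same [CONTEXT_ONLY_VALID_ON_DRIVER] error: the lambda captures the driver-side spark object, and Spark fails when it tries to pickle it into the task closure.

rdd = spark.sparkContext.parallelize(["a.json", "b.json"])
# Fails: spark (and its SparkContext) cannot be serialized for the executors
rdd.map(lambda p: spark.read.json(p).count()).collect()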

==== Solution ====
Revise the process_file logic so that it does not access SparkContext on the workers. Consider using Python's ThreadPoolExecutor for concurrency instead, which keeps all Spark operations on the driver and avoids SparkContext use on worker nodes.
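As a rough sketch only (assuming file_paths, cloudTrailSchema, and parquetOutputPath are defined as above; the worker count and the append write mode are arbitrary choices here):

from concurrent.futures import ThreadPoolExecutor

def process_file(file_path):
    # All Spark calls stay on the driver, so nothing Spark-related gets pickled
    df = spark.read.schema(cloudTrailSchema).json(file_path)
    df.write.mode("append").parquet(parquetOutputPath)

# Run several per-file reads/writes concurrently from driver threads
with ThreadPoolExecutor(max_workers=8) as pool:
    list(pool.map(process_file, file_paths))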

Please let me know if this helps, and leave a like if this information is useful; follow-ups are appreciated.

Kudos

Ayushi
