Hi there,
I need some help with this example. We're trying to fit LinearRegression models in parallel for thousands of symbols per date. When we run the code, we get a PicklingError.
Any suggestions would be much appreciated!
```
PicklingError: Could not serialize object: RuntimeError: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.
```
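As we understand it, PySpark pickles (with cloudpickle) the function passed to `mapPartitions`, together with the driver-side objects it references (in our code `spark`, `data_df` and `assembler`), and ships it to the executors. A SparkContext deliberately refuses to be pickled, and PySpark reports that as the PicklingError above. Below is a stdlib-only sketch of that mechanism; `DriverOnlyContext` and `can_ship` are hypothetical names for illustration, not PySpark APIs.

```python
import pickle


class DriverOnlyContext:
    """Hypothetical stand-in for SparkContext, which refuses to be pickled."""

    def __reduce__(self):
        # PySpark's real SparkContext raises a similar RuntimeError (SPARK-5063)
        raise RuntimeError("context can only be used on the driver")


def can_ship(captured: dict) -> bool:
    """Return True if everything a worker function references can be pickled."""
    try:
        pickle.dumps(captured)
        return True
    except RuntimeError:
        return False


driver_ctx = DriverOnlyContext()
print(can_ship({"columns": ["Symbol", "Feature1"], "n_features": 3}))  # True
print(can_ship({"spark": driver_ctx, "columns": ["Symbol"]}))          # False
```

So the function given to `mapPartitions` may only close over plain data (lists, numbers, strings) and libraries installed on the executors.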
```python
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler

# Create a SparkSession
spark = SparkSession.builder.getOrCreate()

# Create an RDD with sample data
data_rdd = spark.sparkContext.parallelize([
    ("symbol1", 1, 2, 3),
    ("symbol2", 4, 5, 6),
    ("symbol3", 7, 8, 9),
])

# Convert the RDD to a DataFrame
data_df = data_rdd.toDF(["Symbol", "Feature1", "Feature2", "Feature3"])

# Define the features column
assembler = VectorAssembler(inputCols=["Feature1", "Feature2", "Feature3"], outputCol="features")

# Function that mapPartitions runs on each partition (on the executors).
# Note: a partition is an arbitrary chunk of rows, not one symbol.
def fit_model(partition):
    # Create a new linear regression model
    # (note: labelCol must be numeric, but "Symbol" is a string column)
    model = LinearRegression(featuresCol="features", labelCol="Symbol")
    # Convert the partition iterator to a list
    data_list = list(partition)
    # Convert the list to a DataFrame.
    # `spark`, `data_df` and `assembler` live on the driver and cannot be
    # pickled, which is what raises the PicklingError; pyspark.ml estimators
    # such as LinearRegression also only work on the driver.
    data_partition_df = spark.createDataFrame(data_list, data_df.columns)
    # Perform vector assembly
    data_partition_df = assembler.transform(data_partition_df)
    # Fit the model on the partition data
    fitted_model = model.fit(data_partition_df)
    # Yield the model weights as plain Python floats
    yield fitted_model.coefficients.toArray().tolist()

# Fit models on each partition and collect the weights
partition_weights = data_df.rdd.mapPartitions(fit_model).collect()

# Create a DataFrame with the collected weights
weights_df = spark.createDataFrame(partition_weights, ["Weight1", "Weight2", "Weight3"])

# Show the weights DataFrame
weights_df.show()
```
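One direction we are considering (a sketch, not a tested solution): keep every Spark and pyspark.ml call on the driver, hash-partition with `data_df.repartition("Symbol")` so all rows of a symbol land in the same partition, and do the per-symbol fit in plain Python inside `mapPartitions`. Because our example has no numeric label column, the sketch assumes hypothetical numeric columns `x` (a single feature) and `y` (the label) and fits a simple one-feature regression; with several features, the closed-form slope would be replaced by something like `numpy.linalg.lstsq`.

```python
from collections import defaultdict
from typing import Iterable, Iterator, Tuple


def fit_symbols_in_partition(
    rows: Iterable[Tuple[str, float, float]],
) -> Iterator[Tuple[str, float, float]]:
    """Fit y = intercept + slope * x separately for each symbol in one partition.

    Uses only plain Python (no SparkSession, DataFrame or pyspark.ml object),
    so it can be pickled and run on executors via mapPartitions.
    Rows are assumed to be (symbol, x, y); x and y are hypothetical columns.
    """
    points_by_symbol = defaultdict(list)
    for symbol, x, y in rows:
        points_by_symbol[symbol].append((float(x), float(y)))

    for symbol, points in points_by_symbol.items():
        n = len(points)
        mean_x = sum(x for x, _ in points) / n
        mean_y = sum(y for _, y in points) / n
        sxx = sum((x - mean_x) ** 2 for x, _ in points)
        if sxx == 0:
            continue  # slope undefined: a single point or all x values equal
        sxy = sum((x - mean_x) * (y - mean_y) for x, y in points)
        slope = sxy / sxx
        # Yield (symbol, intercept, slope)
        yield symbol, mean_y - slope * mean_x, slope


# Local check without Spark: symbol "C" has one point and is skipped
rows = [("A", 0, 1), ("A", 1, 3), ("A", 2, 5), ("B", 0, 0), ("B", 2, 4), ("C", 5, 5)]
print(list(fit_symbols_in_partition(rows)))
# [('A', 1.0, 2.0), ('B', 0.0, 2.0)]
```

On the driver, this would be wired up roughly as `data_df.repartition("Symbol").select("Symbol", "x", "y").rdd.mapPartitions(fit_symbols_in_partition).toDF(["Symbol", "Intercept", "Slope"])`. For one model per symbol and date, the date column would need to be part of both the `repartition` call and the grouping key. `GroupedData.applyInPandas` (Spark 3.0+) is another way to run one plain-Python fit per group.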