from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml import Pipeline
import numpy as np
# Create a Spark session
spark = SparkSession.builder.appName("MLlibExample").getOrCreate()
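# (On Databricks a SparkSession already exists, so getOrCreate() simply returns it)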
# Generate a toy dataset for illustration
np.random.seed(42)
num_samples = 1000
# Features: number of bedrooms and square footage; price depends on both, plus noise
def make_row():
    bedrooms = np.random.randint(1, 5)
    sqft = 100 + 50 * np.random.rand()
    price = 150 + 75 * bedrooms + 0.1 * sqft + 10 * np.random.randn()
    return (int(bedrooms), float(sqft), float(price))  # plain Python types for createDataFrame
data = [make_row() for _ in range(num_samples)]
# Create a DataFrame
df = spark.createDataFrame(data, ["bedrooms", "square_footage", "price"])
# Create a feature vector
feature_cols = ["bedrooms", "square_footage"]
vector_assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
df = vector_assembler.transform(df)
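# Note: df now already contains a "features" column produced by this transform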
# Split the data into training and testing sets
(train_data, test_data) = df.randomSplit([0.8, 0.2], seed=42)
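# Note: randomSplit fractions are approximate; the seed keeps the split reproducible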
# Build a Linear Regression model
lr = LinearRegression(featuresCol="features", labelCol="price")
# Create a pipeline
pipeline = Pipeline(stages=[vector_assembler, lr])
# Train the model
model = pipeline.fit(train_data)  # Fails at this line; see the traceback below
# Make predictions on the test set
predictions = model.transform(test_data)
# Evaluate the model
evaluator = RegressionEvaluator(labelCol="price", predictionCol="prediction", metricName="mse")
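# "mse" is one of RegressionEvaluator's supported metrics ("rmse", "mse", "r2", "mae", "var")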
mse = evaluator.evaluate(predictions)
print(f"Mean Squared Error on Test Set: {mse}")
========
IllegalArgumentException                  Traceback (most recent call last)
File <command-814210928066392>:38
     35 # Train the model
     36 model = pipeline.fit(train_data)
---> 38 # Make predictions on the test set
     39 predictions = model.transform(test_data)
     41 # Evaluate the model

File /databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py:30, in _create_patch_function.<locals>.patched_method(self, *args, **kwargs)
     28 call_succeeded = False
     29 try:
---> 30     result = original_method(self, *args, **kwargs)
     31     call_succeeded = True
     32     return result

File /databricks/spark/python/pyspark/ml/base.py:205, in Estimator.fit(self, dataset, params)
    203         return self.copy(params)._fit(dataset)
    204     else:
--> 205         return self._fit(dataset)
    206 else:
    207     raise TypeError(
    208         "Params must be either a param map or a list/tuple of param maps, "
    209         "but got %s." % type(params)
    210     )

File /databricks/spark/python/pyspark/ml/pipeline.py:132, in Pipeline._fit(self, dataset)
    130 if isinstance(stage, Transformer):
    131     transformers.append(stage)
--> 132     dataset = stage.transform(dataset)
    133 else:  # must be an Estimator
    134     model = stage.fit(dataset)

File /databricks/spark/python/pyspark/ml/base.py:262, in Transformer.transform(self, dataset, params)
    260         return self.copy(params)._transform(dataset)
    261     else:
--> 262         return self._transform(dataset)
    263 else:
    264     raise TypeError("Params must be a param map but got %s." % type(params))

File /databricks/spark/python/pyspark/ml/wrapper.py:400, in JavaTransformer._transform(self, dataset)
    397 assert self._java_obj is not None
    399 self._transfer_params_to_java()
--> 400 return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sparkSession)

File /databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/java_gateway.py:1321, in JavaMember.__call__(self, *args)
   1315 command = proto.CALL_COMMAND_NAME +\
   1316     self.command_header +\
   1317     args_command +\
   1318     proto.END_COMMAND_PART
   1320 answer = self.gateway_client.send_command(command)
-> 1321 return_value = get_return_value(
   1322     answer, self.gateway_client, self.target_id, self.name)
   1324 for temp_arg in temp_args:
   1325     temp_arg._detach()

File /databricks/spark/python/pyspark/errors/exceptions.py:234, in capture_sql_exception.<locals>.deco(*a, **kw)
    230 converted = convert_exception(e.java_exception)
    231 if not isinstance(converted, UnknownException):
    232     # Hide where the exception came from that shows a non-Pythonic
    233     # JVM exception message.
--> 234     raise converted from None
    235 else:
    236     raise

IllegalArgumentException: Output column features already exists.
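The traceback shows what goes wrong: Pipeline._fit calls stage.transform(dataset) on the VectorAssembler stage, and the assembler cannot add its output column because the DataFrame already contains "features" from the manual vector_assembler.transform(df) call above. The assembler effectively runs twice, once outside the pipeline and once inside it. Below is a minimal sketch of one fix, reusing the variable names from the code above: skip the manual transform and let the pipeline's first stage do the assembling.

# Fix sketch: do NOT call vector_assembler.transform(df) by hand; the
# pipeline's first stage adds the "features" column during fit().
df = spark.createDataFrame(data, ["bedrooms", "square_footage", "price"])
(train_data, test_data) = df.randomSplit([0.8, 0.2], seed=42)
pipeline = Pipeline(stages=[vector_assembler, lr])
model = pipeline.fit(train_data)            # assembler now runs exactly once
predictions = model.transform(test_data)    # adds "features" and "prediction"

The mirror-image fix also works: keep the manual transform call and shrink the pipeline to Pipeline(stages=[lr]). Keeping the assembler inside the pipeline is usually preferable, because the fitted PipelineModel then applies the identical feature assembly to raw data at prediction time.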