PySpark pandas_udf slower than single thread

twotwoiscute

I used `@pandas_udf` to write a function that speeds up the process (parsing XML files) and then compared its speed with single-threaded code. Surprisingly, the `@pandas_udf` version is two times slower than the single-threaded code. The number of XML files I need to parse is around 20000. The code below shows exactly what I did:

```python
import xml.etree.ElementTree as ET

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

spark = SparkSession.builder.appName("EDA").getOrCreate()
# Enables Arrow for toPandas(); Spark 3.x prefers "spark.sql.execution.arrow.pyspark.enabled"
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
# Each call to the pandas_udf receives at most 64 rows
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")


@pandas_udf(ArrayType(ArrayType(IntegerType())))
def parse_xml(xml_names: pd.Series) -> pd.Series:
    # Rows in the batch are still parsed one at a time
    results = []
    for xml_name in xml_names:
        tree = ET.parse(xml_name)
        root = tree.getroot()
        keep_boxes = []
        for obj in root.iter("object"):
            class_id = int(obj.find("name").text)
            boxes = obj.find("bndbox")
            xmin = int(boxes.find("xmin").text)
            ymin = int(boxes.find("ymin").text)
            xmax = int(boxes.find("xmax").text)
            ymax = int(boxes.find("ymax").text)
            keep_boxes.append([class_id, xmin, ymin, xmax, ymax])
        results.append(keep_boxes)
    return pd.Series(results)


# Collect all data from different folders (get_data() is not shown here);
# transpose so each row is (img_name, xml_name), matching the schema below
datas = np.array(get_data()).T.tolist()
schema = StructType([
    StructField("img_name", StringType(), True),
    StructField("xml_name", StringType(), True),
])
num_cores = 20  # number of cores I have
muls = 3
df = spark.createDataFrame(datas, schema).repartition(muls * num_cores)
pdf_box = df.select(col("img_name"), parse_xml(col("xml_name")).alias("boxes")).toPandas()
```
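The single-threaded version used for the comparison isn't shown above. A minimal sketch of what such a baseline might look like, assuming it runs the same parsing logic sequentially over a plain Python list; `parse_one_xml`, `parse_all_single_thread` and the sample XML are hypothetical, for illustration only. Factoring the per-file logic into one function means both versions time exactly the same work.

```python
import io
import time
import xml.etree.ElementTree as ET
from typing import List, Sequence, Tuple


def parse_one_xml(xml_source) -> List[List[int]]:
    """Return [[class_id, xmin, ymin, xmax, ymax], ...] for one annotation file.

    xml_source may be a path or a binary file object (anything ET.parse accepts).
    """
    root = ET.parse(xml_source).getroot()
    keep_boxes = []
    for obj in root.iter("object"):
        class_id = int(obj.find("name").text)
        bndbox = obj.find("bndbox")
        coords = [int(bndbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")]
        keep_boxes.append([class_id] + coords)
    return keep_boxes


def parse_all_single_thread(xml_sources: Sequence) -> Tuple[List[List[List[int]]], float]:
    """Parse every file sequentially in one process; return results and wall-clock seconds."""
    start = time.perf_counter()
    results = [parse_one_xml(src) for src in xml_sources]
    return results, time.perf_counter() - start


# Hypothetical usage with an in-memory sample instead of real files on disk
SAMPLE = b"""<annotation>
  <object>
    <name>3</name>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>220</ymax></bndbox>
  </object>
</annotation>"""

all_boxes, seconds = parse_all_single_thread([io.BytesIO(SAMPLE)])
print(all_boxes, f"{seconds:.6f}s")
```

Inside the UDF, the same helper could be applied as `xml_names.map(parse_one_xml)`; that is shorter, but it still processes one row at a time, just like the explicit loop.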

As far as I know, because I use a for loop, the advantage of `pandas_udf` is gone, since it can't really process the whole batch at once. However, I still expect it to be faster than single-threaded code, because Spark breaks the data into partitions and processes them in parallel. If the concept I described above is wrong, please correct me.
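The partitioning argument can be made concrete with some back-of-the-envelope arithmetic. A minimal sketch, assuming exactly 20000 files (the post says "around"), the settings from the code (`repartition(3 * 20)`, `maxRecordsPerBatch = 64`, 20 cores), an even split of rows across partitions, and one running task per core; `partition_plan` is a hypothetical helper, not a Spark API.

```python
import math


def partition_plan(num_files: int, num_partitions: int,
                   max_records_per_batch: int, num_cores: int) -> dict:
    """Rough layout of the job, assuming rows are split evenly across
    partitions and each core runs one task (partition) at a time."""
    base, extra = divmod(num_files, num_partitions)
    # `extra` partitions get one more row than the others
    sizes = [base + 1] * extra + [base] * (num_partitions - extra)
    # A scalar pandas_udf is called once per Arrow batch, and a batch holds
    # at most max_records_per_batch rows from a single partition
    udf_calls = sum(math.ceil(size / max_records_per_batch) for size in sizes)
    return {
        "largest_partition_rows": max(sizes),
        "total_udf_calls": udf_calls,
        # only num_cores partitions can be processed at the same time
        "task_waves": math.ceil(num_partitions / num_cores),
    }


plan = partition_plan(num_files=20000, num_partitions=3 * 20,
                      max_records_per_batch=64, num_cores=20)
print(plan)
```

Under these assumptions each core works through about 3 partitions of roughly 334 files (about 1000 files per core), so the ideal case is roughly a 20x speed-up over one thread. This count covers only how the work is laid out; it ignores Spark scheduling, Python worker, and Arrow serialization overhead.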

So I would like to know why it's even slower than the single-threaded code. Is it because of the code I wrote, or is there some important idea that I'm just missing? Thanks!