I have a Spark DataFrame consisting of three columns:
+---+----+----+
| id|col1|col2|
+---+----+----+
|  x|  p1|  a1|
|  x|  p2|  b1|
|  y|  p2|  b2|
|  y|  p2|  b3|
|  y|  p3|  c1|
+---+----+----+
After applying df.groupBy("id").pivot("col1").agg(collect_list("col2")), I get the following DataFrame (aggDF):
+---+----+--------+----+
| id|  p1|      p2|  p3|
+---+----+--------+----+
|  x|[a1]|    [b1]|  []|
|  y|  []|[b2, b3]|[c1]|
+---+----+--------+----+
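For reproducibility, the setup above can be sketched roughly as follows. The object name PivotRepro, the helper buildAggDF, and the local SparkSession settings are placeholders chosen for illustration, not part of my actual job.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.collect_list

object PivotRepro {
  // Builds the pivoted DataFrame (aggDF) from the five sample rows shown above.
  def buildAggDF(spark: SparkSession): DataFrame = {
    import spark.implicits._

    val df = Seq(
      ("x", "p1", "a1"),
      ("x", "p2", "b1"),
      ("y", "p2", "b2"),
      ("y", "p2", "b3"),
      ("y", "p3", "c1")
    ).toDF("id", "col1", "col2")

    // One output column per distinct value of col1; missing combinations
    // come back as empty arrays from collect_list.
    df.groupBy("id").pivot("col1").agg(collect_list("col2"))
  }

  def main(args: Array[String]): Unit = {
    // Assumption: a local session, for experimentation only.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("pivot-repro")
      .getOrCreate()

    buildAggDF(spark).show()
    spark.stop()
  }
}
```

The row order printed by show() is not guaranteed after a groupBy, so it may differ from the table above.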
Then I collect the names of all columns except the id column:
val cols = aggDF.columns.filter(x => x != "id")
After that, I use cols.foldLeft(aggDF)((df, x) => df.withColumn(x, when(size(col(x)) > 0, col(x)).otherwise(lit(null)))) to replace each empty array with null. The performance of this code degrades as the number of columns increases.

Additionally, I have the names of the string columns, val stringColumns = Array("p1","p3"), whose values should end up as plain strings rather than arrays. I want to get the following final DataFrame:
+---+----+--------+----+
| id|  p1|      p2|  p3|
+---+----+--------+----+
|  x|  a1|    [b1]|null|
|  y|null|[b2, b3]|  c1|
+---+----+--------+----+
Is there a better way to produce this final DataFrame?
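One direction I have been considering, as a rough sketch only: the Spark documentation for withColumn notes that each call adds a projection and that calling it in a loop can produce very large plans, recommending a single select instead. The version below builds every column expression up front and applies them in one select. The names EmptyArraysToNull, normalize, and idColumn are hypothetical, and taking only the first element of each string column is my assumption (it silently drops any further values).

```scala
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{col, size, when}

object EmptyArraysToNull {
  // Replaces empty arrays with null in a single projection. Columns listed in
  // stringColumns are additionally reduced to their first element.
  def normalize(
      aggDF: DataFrame,
      stringColumns: Seq[String],
      idColumn: String = "id"
  ): DataFrame = {
    val exprs: Seq[Column] = aggDF.columns.toSeq.map {
      // Leave the grouping key untouched.
      case c if c == idColumn =>
        col(c)
      // Assumption: a string column holds at most one value per id, so the
      // first element is enough. `when` without `otherwise` yields null.
      case c if stringColumns.contains(c) =>
        when(size(col(c)) > 0, col(c).getItem(0)).as(c)
      // Keep non-empty arrays as they are; empty arrays become null.
      case c =>
        when(size(col(c)) > 0, col(c)).as(c)
    }

    // One select instead of one withColumn per column.
    aggDF.select(exprs: _*)
  }
}
```

With the data above, this would be called as EmptyArraysToNull.normalize(aggDF, stringColumns). I have not measured whether it is actually faster on wide DataFrames.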