I want to calculate the Jaro-Winkler distance between two columns of a PySpark DataFrame. Jaro-Winkler distance is available through the pyjarowinkler package, which is installed on all nodes.
pyjarowinkler works as follows:
from pyjarowinkler import distance
distance.get_jaro_distance("A", "A", winkler=True, scaling=0.1)
Output:
1.0
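I am trying to write a Pandas UDF that receives the two columns as Series and calculates the distance with a lambda function. Here's how I am doing it:
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyjarowinkler import distance
@pandas_udf("float", PandasUDFType.SCALAR)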
def get_distance(col1, col2):
    import pandas as pd
    distance_df  = pd.DataFrame({'column_A': col1, 'column_B': col2})
    distance_df['distance'] = distance_df.apply(lambda x: distance.get_jaro_distance(str(distance_df['column_A']), str(distance_df['column_B']), winkler = True, scaling = 0.1))
    return distance_df['distance']
temp = temp.withColumn('jaro_distance', get_distance(temp.x, temp.x))
The function should accept any two string columns. Instead, I get null for every row:
+---+---+---+-------------+
|  x|  y|  z|jaro_distance|
+---+---+---+-------------+
|  A|  1|  2|         null|
|  B|  3|  4|         null|
|  C|  5|  6|         null|
|  D|  7|  8|         null|
+---+---+---+-------------+
Expected Output:
+---+---+---+-------------+
|  x|  y|  z|jaro_distance|
+---+---+---+-------------+
|  A|  1|  2|          1.0|
|  B|  3|  4|          1.0|
|  C|  5|  6|          1.0|
|  D|  7|  8|          1.0|
+---+---+---+-------------+
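I suspect the problem is str(distance_df['column_A']): it converts the entire Series into one string (the printed representation of every row in the batch) rather than the value in the current row. The lambda also ignores its argument x, so no per-row value ever reaches get_jaro_distance.
By contrast, this single-column version works for me: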
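To make the intent concrete, here is a minimal sketch of the row-by-row pairing I am after, written in plain pandas so it runs without Spark. The names pairwise_distance and exact_match_score are hypothetical; exact_match_score is only a stand-in for distance.get_jaro_distance so the snippet does not need pyjarowinkler. I have not verified this inside a Pandas UDF on the cluster.

```python
import pandas as pd


def exact_match_score(a: str, b: str) -> float:
    # Hypothetical stand-in for distance.get_jaro_distance:
    # 1.0 for identical strings, otherwise 0.0.
    return 1.0 if a == b else 0.0


def pairwise_distance(col1: pd.Series, col2: pd.Series, scorer=exact_match_score) -> pd.Series:
    # Walk both Series in lockstep so the scorer sees one value from each
    # column at a time, instead of the string form of a whole column.
    scores = [scorer(str(a), str(b)) for a, b in zip(col1, col2)]
    # Reuse the input index so the result lines up with the incoming batch.
    return pd.Series(scores, index=col1.index, dtype="float64")


x = pd.Series(["A", "B", "C", "D"])
y = pd.Series(["A", "X", "C", "Y"])

# str() on a whole Series yields its printed representation, not one value.
print(str(x) == "A")  # False
print(pairwise_distance(x, y).tolist())  # [1.0, 0.0, 1.0, 0.0]
```

Presumably the body of pairwise_distance could go inside the two-argument Pandas UDF, with the scorer replaced by a call to distance.get_jaro_distance(a, b, winkler=True, scaling=0.1), but that part is an assumption on my side.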
@pandas_udf("float", PandasUDFType.SCALAR)
def get_distance(col):
    return col.apply(lambda x: distance.get_jaro_distance(x, "A", winkler = True, scaling = 0.1))
temp = temp.withColumn('jaro_distance', get_distance(temp.x))
Output:
+---+---+---+-------------+
|  x|  y|  z|jaro_distance|
+---+---+---+-------------+
|  A|  1|  2|          1.0|
|  B|  3|  4|          0.0|
|  C|  5|  6|          0.0|
|  D|  7|  8|          0.0|
+---+---+---+-------------+
Is there a way to compute this across two columns with a Pandas UDF? I'm dealing with millions of records, so a UDF will be expensive, but that is acceptable as long as it works. Thanks.
 
     
    