I've created the following graph:
from pyspark.sql import SparkSession
from graphframes import GraphFrame

spark = SparkSession.builder.appName('aggregate').getOrCreate()
vertices = spark.createDataFrame([('1', 'foo', 99),
                                  ('2', 'bar', 10),
                                  ('3', 'baz', 25),
                                  ('4', 'spam', 7)],
                                 ['id', 'name', 'value'])
edges = spark.createDataFrame([('1', '2'),
                               ('1', '3'),
                               ('3', '4')],
                              ['src', 'dst'])
g = GraphFrame(vertices, edges)
I would like to aggregate messages so that each vertex ends up with a list of the values of all its descendants, all the way down to the leaves. For example, vertex 1 has an edge to child vertex 3, which in turn has an edge to child vertex 4. Vertex 1 also has an edge to child vertex 2. That is:
(1) --> (3) --> (4)
  \
   \--> (2)
From 1 I'd like to collect all values along these paths: [99, 10, 25, 7], where 99 is the value of vertex 1, 10 is the value of child vertex 2, 25 is the value of vertex 3, and 7 is the value of vertex 4.
From 3 we'd have the values [25, 7], etc.
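To pin down the expected output, here is a minimal plain-Python sketch (no Spark involved) that computes the desired list for every vertex. The names `values`, `collect_values`, and `expected` are illustrative only and not part of GraphFrames, and the recursion assumes the graph is a tree with no cycles:

```python
# Plain-Python reference for the desired result; not a GraphFrames solution.
values = {'1': 99, '2': 10, '3': 25, '4': 7}
edges = [('1', '2'), ('1', '3'), ('3', '4')]


def collect_values(values, edges):
    """Map each vertex id to its own value followed by all descendant values."""
    # Build a parent -> children adjacency list from the edge pairs.
    children = {}
    for src, dst in edges:
        children.setdefault(src, []).append(dst)

    def walk(vertex):
        # Start with the vertex's own value, then append each subtree in turn.
        collected = [values[vertex]]
        for child in children.get(vertex, []):
            collected.extend(walk(child))
        return collected

    return {vertex: walk(vertex) for vertex in values}


expected = collect_values(values, edges)
```

Leaf vertices (2 and 4 here) end up with a list containing only their own value.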
I can approximate this with aggregateMessages:
from pyspark.sql.functions import collect_list
from graphframes.lib import AggregateMessages as AM

agg = g.aggregateMessages(collect_list(AM.msg).alias('allValues'),
                          sendToSrc=AM.dst['value'],
                          sendToDst=None)
agg.show()
Which produces:
+---+---------+
| id|allValues|
+---+---------+
|  3|      [7]|
|  1| [25, 10]|
+---+---------+
At 1 we have [25, 10], which are the immediate child values, but we are missing 7 and the "self" value of 99.
Similarly, vertex 3 is missing its own value of 25.
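As far as I understand it, a single aggregateMessages call sends one message per edge, so each parent only ever hears from its direct children. The plain-Python sketch below imitates that one pass to show where the gap comes from. `one_round` and `single_hop` are hypothetical names for illustration, and the order of values within each list is not guaranteed to match Spark's:

```python
# Plain-Python imitation of one message-passing round; not GraphFrames code.
values = {'1': 99, '2': 10, '3': 25, '4': 7}
edges = [('1', '2'), ('1', '3'), ('3', '4')]


def one_round(values, edges):
    """For each edge, send the destination's value to the source vertex."""
    inbox = {}
    for src, dst in edges:
        inbox.setdefault(src, []).append(values[dst])
    return inbox


single_hop = one_round(values, edges)
```

Vertex 1 receives only 10 and 25, vertex 3 receives only 7, and vertices 2 and 4 receive nothing, which matches the table above up to ordering.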
How can I aggregate messages "recursively", such that allValues from child vertices are aggregated at the parent?