I would like to collapse the rows in a dataframe based on an ID column and count the number of records per ID using window functions. In doing so, I would like to avoid partitioning the window by ID, because that would result in a very large number of partitions.
I have a dataframe of the form
+----+-----------+-----------+-----------+
| ID | timestamp | metadata1 | metadata2 |
+----+-----------+-----------+-----------+
|  1 | 09:00     | ABC       | apple     |
|  1 | 08:00     | NULL      | NULL      |
|  1 | 18:00     | XYZ       | apple     |
|  2 | 07:00     | NULL      | banana    |
|  5 | 23:00     | ABC       | cherry    |
+----+-----------+-----------+-----------+
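For reproducibility, this sample data can be built with something like the following (a minimal sketch; I am assuming ID is an integer and keeping timestamp as a plain string for the example):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample rows matching the table above; None stands in for NULL.
df = spark.createDataFrame(
    [
        (1, '09:00', 'ABC', 'apple'),
        (1, '08:00', None, None),
        (1, '18:00', 'XYZ', 'apple'),
        (2, '07:00', None, 'banana'),
        (5, '23:00', 'ABC', 'cherry'),
    ],
    ['ID', 'timestamp', 'metadata1', 'metadata2'],
)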
where I would like to keep only the records with the most recent timestamp per ID, such that I have
+----+-----------+-----------+-----------+-------+
| ID | timestamp | metadata1 | metadata2 | count |
+----+-----------+-----------+-----------+-------+
|  1 | 18:00     | XYZ       | apple     |     3 |
|  2 | 07:00     | NULL      | banana    |     1 |
|  5 | 23:00     | ABC       | cherry    |     1 |
+----+-----------+-----------+-----------+-------+
I have tried:
import sys
from pyspark.sql import Window
from pyspark.sql.functions import asc, col, count, desc, first, row_number

# Windows ordered by ID and descending timestamp, without partitionBy
window = Window.orderBy([asc('ID'), desc('timestamp')])
window_count = Window.orderBy([asc('ID'), desc('timestamp')]).rowsBetween(-sys.maxsize, sys.maxsize)
columns_metadata = ['metadata1', 'metadata2']

df = df.select(
    'ID',
    'timestamp',
    *(first(col_name, ignorenulls=True).over(window).alias(col_name) for col_name in columns_metadata),
    count(col('ID')).over(window_count).alias('count')
)
df = df.withColumn('row_tmp', row_number().over(window)).filter(col('row_tmp') == 1).drop('row_tmp')
which is in part based on How to select the first row of each group?
However, without the use of pyspark.sql.Window.partitionBy, this does not give the desired output.
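For reference, a partitionBy-based version along the lines of the linked answer would presumably look roughly like the sketch below (variable names are mine); this is exactly what I would like to avoid:

from pyspark.sql import Window
from pyspark.sql.functions import col, count, desc, row_number

# Per-ID windows: one ordered for picking the most recent row, one unordered for the count.
w_latest = Window.partitionBy('ID').orderBy(desc('timestamp'))
w_total = Window.partitionBy('ID')

df_partitioned = (
    df.withColumn('count', count(col('ID')).over(w_total))
      .withColumn('row_tmp', row_number().over(w_latest))
      .filter(col('row_tmp') == 1)
      .drop('row_tmp')
)

Is there a way to get the same result without partitioning the window by ID?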