I am using the groupBy function to remove duplicates from a Spark DataFrame. For each group I simply want to take the first row, which will be the most recent one.
I don't want to perform a max() aggregation, because I know the results are already stored sorted in Cassandra and I want to avoid unnecessary computation. I've seen this done in pandas, and it's exactly what I'm after, except in Spark.
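Roughly, the pandas version looks like this (a minimal sketch; the data and column names are made up for illustration):

    import pandas as pd

    # Toy frame: several rows per key, already sorted newest-first within each key.
    pdf = pd.DataFrame({
        "key":   ["a", "a", "b", "b"],
        "value": [40, 30, 20, 10],
    })

    # Keep the first (most recent) row of each group; no max() or re-sort needed.
    deduped = pdf.groupby("key", as_index=False).first()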
df = sqlContext.read \
    .format("org.apache.spark.sql.cassandra") \
    .options(table="table", keyspace="keyspace") \
    .load() \
    .groupBy("key")
    # what goes here?
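For comparison, the kind of aggregation I'm trying to avoid would be something along these lines (a sketch only; "ts" is a placeholder for my timestamp/clustering column):

    from pyspark.sql import functions as F

    cassandra_df = sqlContext.read \
        .format("org.apache.spark.sql.cassandra") \
        .options(table="table", keyspace="keyspace") \
        .load()

    # Recompute the max timestamp per key and join back -- redundant work,
    # since Cassandra already returns each partition sorted by "ts".
    latest = cassandra_df.groupBy("key").agg(F.max("ts").alias("ts"))
    deduped = cassandra_df.join(latest, on=["key", "ts"])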