There is a summary() API on Dataset which computes basic statistics in the format below:
    ds.summary("count", "min", "25%", "75%", "max").show()
   
    // output:
    // summary age   height
    // count   10.0  10.0
    // min     18.0  163.0
    // 25%     24.0  176.0
    // 75%     32.0  180.0
    // max     92.0  192.0
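Calling summary() with no arguments falls back to the built-in defaults, i.e. count, mean, stddev, min, the 25%/50%/75% percentiles, and max:

    ds.summary().show()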
Similarly, you can enrich the DataFrame API to get the stats in the format you require, as below.
Define RichDataFrame and the implicits to use it:
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{NumericType, StringType, StructField, StructType}
import scala.language.implicitConversions

class RichDataFrame(ds: DataFrame) {
  def statSummary(statistics: String*): DataFrame = {
    // statistics computed when the caller passes no arguments
    val defaultStatistics = Seq("max", "min", "mean", "std", "skewness", "kurtosis")
    val statFunctions = if (statistics.nonEmpty) statistics else defaultStatistics

    // only numeric and string columns are aggregated
    val selectedCols = ds.schema
      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
      .map(_.name)

    // parse "25%"-style entries into fractions for percentile_approx
    val percentiles = statFunctions.filter(_.endsWith("%")).map { p =>
      try {
        p.stripSuffix("%").toDouble / 100.0
      } catch {
        case e: NumberFormatException =>
          throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
      }
    }
    require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")

    // one aggregate expression per (column, statistic) pair; all percentiles for a
    // column come from a single percentile_approx call, indexed by position.
    // Column names are backtick-quoted so names with special characters still work.
    val aggExprs = selectedCols.flatMap { c =>
      var percentileIndex = 0
      statFunctions.map { stats =>
        if (stats.endsWith("%")) {
          val index = percentileIndex
          percentileIndex += 1
          expr(s"cast(percentile_approx(`$c`, array(${percentiles.mkString(", ")}))[$index] as string)")
        } else {
          expr(s"cast($stats(`$c`) as string)")
        }
      }
    }

    // a single pass over the data yields one wide row; regroup it so that
    // each output row holds the statistics of one input column
    val aggResult = ds.select(aggExprs: _*).head()
    val r = aggResult.toSeq.grouped(statFunctions.length).toArray
      .zip(selectedCols)
      .map { case (seq, column) => column +: seq }
      .map(Row.fromSeq)

    val output = StructField("columns", StringType) +: statFunctions.map(c => StructField(c, StringType))
    val spark = ds.sparkSession
    spark.createDataFrame(spark.sparkContext.parallelize(r), StructType(output))
  }
}
object RichDataFrame {
  trait Enrichment {
    implicit def enrichMetadata(ds: DataFrame): RichDataFrame =
      new RichDataFrame(ds)
  }
  object implicits extends Enrichment
}
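As a side note, the same enrichment can be written more compactly with an implicit class, avoiding the separate trait/object wiring. A minimal sketch of that alternative (the names RichDataFrameSyntax and RichDataFrameOps are hypothetical), delegating to the RichDataFrame above:

import org.apache.spark.sql.DataFrame

object RichDataFrameSyntax {
  // import RichDataFrameSyntax._ to enable df.statSummary(...)
  implicit class RichDataFrameOps(val ds: DataFrame) extends AnyVal {
    def statSummary(statistics: String*): DataFrame =
      new RichDataFrame(ds).statSummary(statistics: _*)
  }
}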
Test with the provided test data as below:
// toDF on a local Seq needs the session implicits in scope
// (assuming the active SparkSession is available as `spark`)
import spark.implicits._

val df = Seq(
  (10, 20, 30, 40, 50),
  (100, 200, 300, 400, 500),
  (111, 222, 333, 444, 555),
  (1123, 2123, 3123, 4123, 5123),
  (1321, 2321, 3321, 4321, 5321)
).toDF("col_1", "col_2", "col_3", "col_4", "col_5")

val columnsToCalculate = Seq("col_2", "col_3", "col_4")

import com.som.spark.shared.RichDataFrame.implicits._

df.selectExpr(columnsToCalculate: _*)
  .statSummary("mean", "count", "25%", "75%", "90%")
  .show(false)

/**
  * +-------+------+-----+---+----+----+
  * |columns|mean  |count|25%|75% |90% |
  * +-------+------+-----+---+----+----+
  * |col_2  |977.2 |5    |200|2123|2321|
  * |col_3  |1421.4|5    |300|3123|3321|
  * |col_4  |1865.6|5    |400|4123|4321|
  * +-------+------+-----+---+----+----+
  */
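Since statSummary() falls back to defaultStatistics when called without arguments, the same enriched DataFrame gives you max, min, mean, std, skewness and kurtosis in a single call:

df.selectExpr(columnsToCalculate: _*)
  .statSummary() // uses the defaults: max, min, mean, std, skewness, kurtosis
  .show(false)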