Scala - how to interpret RDD.treeAggregate

I ran into this line in the Apache Spark source code:

    val (gradientSum, lossSum, miniBatchSize) = data
      .sample(false, miniBatchFraction, 42 + i)
      .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
        seqOp = (c, v) => {
          // c: (grad, loss, count), v: (label, features)
          val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
          (c._1, c._2 + l, c._3 + 1)
        },
        combOp = (c1, c2) => {
          // c: (grad, loss, count)
          (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
        })

I have several problems reading it:

  • First, I can't find anything on the web that explains how treeAggregate works and what its parameters mean.
  • Second, .treeAggregate here seems to have two sets of parentheses, ()(), after the method name. What does that mean? Is it some special Scala syntax I don't understand?
  • Finally, I see that both seqOp and combOp return a 3-element tuple matching the variables on the left-hand side. Which one actually gets returned?

This statement must be quite advanced; I can't even begin to decipher it.

treeAggregate is a specialized implementation of aggregate that iteratively applies the combine function to subsets of partitions. This is done to avoid sending all partial results back to the driver, where a single-pass reduce would take place, as the classic aggregate does.
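To make the difference concrete, here is a simplified single-machine model in plain Scala, with no Spark dependency. "Partitions" are ordinary Seqs, and the object and method names are hypothetical. The level-by-level merging loosely follows the scale/depth heuristic used by Spark's treeAggregate, but it is a sketch, not Spark's actual implementation; it also returns how many partial results remain for the final merge, which Spark's method does not.

```scala
// Simplified, single-machine model of aggregate vs. treeAggregate.
// "Partitions" are plain Seqs; nothing here uses Spark.
object TreeAggregateModel {

  // Classic aggregate: fold each partition with seqOp, then merge *all*
  // partial results in a single pass (in Spark, this merge runs on the driver).
  def aggregate[T, U](partitions: Seq[Seq[T]], zero: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U): U =
    partitions.map(_.foldLeft(zero)(seqOp)).foldLeft(zero)(combOp)

  // Tree-style aggregate: after the per-partition fold, repeatedly merge
  // groups of partial results until few enough remain for a final merge.
  // Returns (result, number of partials left for the final merge).
  def treeAggregate[T, U](partitions: Seq[Seq[T]], zero: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U,
      depth: Int = 2): (U, Int) = {
    require(depth >= 1, s"depth must be >= 1 but got $depth")
    var partials: Seq[U] = partitions.map(_.foldLeft(zero)(seqOp))
    var n = partials.size
    // Roughly the n^(1/depth) fan-in per level, at least 2.
    val scale = math.max(math.ceil(math.pow(n, 1.0 / depth)).toInt, 2)
    // Keep adding levels only while another level reduces the work left.
    while (n > scale + math.ceil(n.toDouble / scale).toInt) {
      n /= scale
      val groups = n
      // Partial result i goes to group (i % groups); each group is merged with combOp.
      partials = partials.zipWithIndex
        .groupBy { case (_, i) => i % groups }
        .toSeq
        .sortBy(_._1)
        .map { case (_, group) => group.map(_._1).foldLeft(zero)(combOp) }
    }
    (partials.foldLeft(zero)(combOp), partials.size)
  }
}
```

For example, with 100 single-element partitions and the default depth = 2, scale is 10, so this model first merges the 100 partials into 10 groups and the final merge only sees 10 values, whereas aggregate merges all 100 at once.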

For all practical purposes, treeAggregate follows the same principle as aggregate, explained in this answer: Explain the aggregate functionality in Python, with the exception that it takes an extra parameter indicating the depth of the partial aggregation levels.

Let me try to explain what's going on here specifically:

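For aggregate, we need a zero value, a combiner function and a reduce function. aggregate uses currying (multiple parameter lists) to specify the zero value independently of the combine and reduce functions. That is what the two sets of parentheses, ()(), are: the first list takes the zero value, the second takes seqOp and combOp.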
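A small, self-contained sketch of the same parameter-list shape on an ordinary Seq follows. The implicit class and the aggregateLike method are hypothetical, not a Spark or standard-library API, and fixed-size chunks stand in for partitions. Because the zero value sits in its own parameter list, the compiler infers U from it first, so the lambdas in the second list need no type annotations.

```scala
object MultipleParameterLists {

  // Hypothetical helper with the shape aggregate(zero)(seqOp, combOp).
  implicit class AggregateSyntax[T](data: Seq[T]) {
    def aggregateLike[U](zero: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U =
      // Chunks of 2 elements play the role of partitions.
      data.grouped(2).map(_.foldLeft(zero)(seqOp)).foldLeft(zero)(combOp)
  }

  // Counts words and sums their lengths in one pass.
  // U = (Int, Int) is inferred from the first list, so c, w, c1 and c2
  // get their types without annotations.
  def countAndLength(words: Seq[String]): (Int, Int) =
    words.aggregateLike((0, 0))(
      seqOp = (c, w) => (c._1 + 1, c._2 + w.length),
      combOp = (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2))

  def main(args: Array[String]): Unit = {
    val (count, totalLength) = countAndLength(Seq("a", "bb", "ccc"))
    println(s"count=$count totalLength=$totalLength") // prints count=3 totalLength=6
  }
}
```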

We can dissect the above function like this; hopefully it helps understanding:

    val zero: (BDV[Double], Double, Long) = (BDV.zeros[Double](n), 0.0, 0L)

    val combinerFunction: ((BDV[Double], Double, Long), (Double, Vector)) => (BDV[Double], Double, Long) =
      (c, v) => {
        // c: (grad, loss, count), v: (label, features)
        val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
        (c._1, c._2 + l, c._3 + 1)
      }

    val reducerFunction: ((BDV[Double], Double, Long), (BDV[Double], Double, Long)) => (BDV[Double], Double, Long) =
      (c1, c2) => {
        // c: (grad, loss, count)
        (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
      }
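To run the dissected functions without Spark or Breeze, the sketch below swaps in stand-ins: Vector[Double] replaces BDV[Double], and a hypothetical least-squares gradient replaces gradient.compute. Unlike the original reducerFunction, which mutates c1._1 in place with +=, this version builds new vectors. The weights and data points are made up for illustration.

```scala
object DissectedFunctions {
  // Stand-ins (hypothetical): Vector[Double] instead of BDV[Double],
  // and a least-squares gradient instead of Spark's gradient.compute.
  type Acc = (Vector[Double], Double, Long) // (gradient sum, loss sum, count)
  type Point = (Double, Vector[Double])     // (label, features)

  val weights: Vector[Double] = Vector(0.5, -1.0)
  val n: Int = weights.length

  val zero: Acc = (Vector.fill(n)(0.0), 0.0, 0L)

  // Plays the role of seqOp: add one point's gradient and loss to the accumulator.
  val combinerFunction: (Acc, Point) => Acc = (c, v) => {
    val (label, features) = v
    val prediction = features.zip(weights).map { case (x, w) => x * w }.sum
    val diff = prediction - label
    val grad = features.map(_ * diff) // gradient of 0.5 * diff^2
    val loss = 0.5 * diff * diff
    (c._1.zip(grad).map { case (a, b) => a + b }, c._2 + loss, c._3 + 1)
  }

  // Plays the role of combOp: merge two accumulators element-wise.
  val reducerFunction: (Acc, Acc) => Acc = (c1, c2) =>
    (c1._1.zip(c2._1).map { case (a, b) => a + b }, c1._2 + c2._2, c1._3 + c2._3)

  def main(args: Array[String]): Unit = {
    val partitionA = Seq((1.0, Vector(1.0, 0.0)))
    val partitionB = Seq((0.0, Vector(0.0, 1.0)))
    // combinerFunction runs inside each "partition"...
    val a = partitionA.foldLeft(zero)(combinerFunction)
    val b = partitionB.foldLeft(zero)(combinerFunction)
    // ...and reducerFunction merges the per-partition results.
    println(reducerFunction(a, b))
  }
}
```

With these two points, the merged result is a gradient sum of (-0.5, -1.0), a loss sum of 0.625 and a count of 2.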

Then we can rewrite the call to treeAggregate in a more digestible form:

    val (gradientSum, lossSum, miniBatchSize) = data
      .sample(false, miniBatchFraction, 42 + i)
      .treeAggregate(zero)(combinerFunction, reducerFunction)

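This form "extracts" the resulting tuple into the named values gradientSum, lossSum and miniBatchSize for further use. Only one tuple is returned: seqOp folds the elements of each partition into a per-partition tuple, combOp merges those tuples, and treeAggregate returns the single tuple produced by the final merge.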
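The sketch below logs every call to show that seqOp and combOp each produce intermediate tuples, while only the tuple from the last merge is returned and destructured. It is plain Scala with hypothetical names; the 1-D "gradient" x and "loss" x * x per element are made up purely for illustration.

```scala
import scala.collection.mutable.ListBuffer

object WhichTupleIsReturned {
  // (gradient sum, loss sum, count), simplified to a 1-D "gradient" (hypothetical).
  type Acc = (Double, Double, Long)

  def aggregateWithTrace(partitions: Seq[Seq[Double]], log: ListBuffer[String]): Acc = {
    val zero: Acc = (0.0, 0.0, 0L)
    // Hypothetical per-element contribution: gradient x, loss x * x.
    val seqOp: (Acc, Double) => Acc = (c, x) => {
      val next = (c._1 + x, c._2 + x * x, c._3 + 1)
      log += s"seqOp: $c, $x -> $next"
      next
    }
    val combOp: (Acc, Acc) => Acc = (c1, c2) => {
      val merged = (c1._1 + c2._1, c1._2 + c2._2, c1._3 + c2._3)
      log += s"combOp: $c1, $c2 -> $merged"
      merged
    }
    // seqOp yields one intermediate tuple per partition; combOp folds them
    // together starting from zero. Only the last merge's tuple is returned.
    partitions.map(_.foldLeft(zero)(seqOp)).foldLeft(zero)(combOp)
  }

  def main(args: Array[String]): Unit = {
    val log = ListBuffer.empty[String]
    val (gradientSum, lossSum, miniBatchSize) =
      aggregateWithTrace(Seq(Seq(1.0, 2.0), Seq(3.0)), log)
    log.foreach(println)
    // prints gradientSum=6.0 lossSum=14.0 miniBatchSize=3
    println(s"gradientSum=$gradientSum lossSum=$lossSum miniBatchSize=$miniBatchSize")
  }
}
```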

Note that treeAggregate takes an additional parameter, depth, declared with the default value depth = 2. Since it isn't provided in this particular call, it takes the default value.
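Default and named arguments can be seen in isolation with a hypothetical method that has the same parameter-list shape as treeAggregate and simply returns the depth it received; it does no aggregation at all.

```scala
object DefaultDepth {
  // Hypothetical stand-in shaped like treeAggregate(zeroValue)(seqOp, combOp, depth = 2).
  // It ignores seqOp/combOp and only reports which depth the compiler passed in.
  def treeAggregateLike[U](zeroValue: U)(
      seqOp: (U, Int) => U,
      combOp: (U, U) => U,
      depth: Int = 2): Int = depth

  def main(args: Array[String]): Unit = {
    // Named arguments as in the original call; depth omitted, so the default 2 is used.
    val d1 = treeAggregateLike(0)(seqOp = (c, x) => c + x, combOp = (a, b) => a + b)
    // Depth supplied explicitly by name.
    val d2 = treeAggregateLike(0)(seqOp = (c, x) => c + x, combOp = (a, b) => a + b, depth = 3)
    println(s"$d1 $d2") // prints 2 3
  }
}
```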


