From e63ff1cafaab3f11a4aa87a72e53e646f03b2ae2 Mon Sep 17 00:00:00 2001 From: tanglizhe1105 Date: Sat, 14 May 2016 17:16:26 +0800 Subject: [PATCH 1/2] fix bug for issue #13 --- .../java/com/intel/distml/util/KeyHash.java | 4 ++++ .../java/com/intel/distml/util/KeyRange.java | 22 +++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/main/java/com/intel/distml/util/KeyHash.java b/src/main/java/com/intel/distml/util/KeyHash.java index 79f4bc1..48a75f9 100644 --- a/src/main/java/com/intel/distml/util/KeyHash.java +++ b/src/main/java/com/intel/distml/util/KeyHash.java @@ -26,6 +26,10 @@ public class KeyHash extends KeyCollection { public KeyHash(int hashQuato, int hashIndex, long minKey, long maxKey) { super(KeyCollection.TYPE_HASH); + if (minKey > maxKey) { + throw new IllegalStateException("unexpected key range: (" + minKey + ", " + maxKey + ")"); + } + this.hashQuato = hashQuato; this.hashIndex = hashIndex; this.minKey = minKey; diff --git a/src/main/java/com/intel/distml/util/KeyRange.java b/src/main/java/com/intel/distml/util/KeyRange.java index d67e705..7dbb6ae 100644 --- a/src/main/java/com/intel/distml/util/KeyRange.java +++ b/src/main/java/com/intel/distml/util/KeyRange.java @@ -20,6 +20,11 @@ public class KeyRange extends KeyCollection { public KeyRange(long f, long l) { super(KeyCollection.TYPE_RANGE); + + if (f > l) { + throw new IllegalStateException("unexpected key range: (" + f + ", " + l + ")"); + } + firstKey = f; lastKey = l; } @@ -66,21 +71,24 @@ public void read(AbstractDataReader in, DataDesc format) throws Exception { } public KeyCollection[] linearSplit(int hostNum) { - KeyCollection[] sets = new KeyRange[hostNum]; + KeyCollection[] sets = new KeyCollection[hostNum]; - long start = firstKey; - long step = (lastKey - firstKey + hostNum) / hostNum; + long keySize = size(); for (int i = 0; i < hostNum; i++) { - long end = Math.min(start + step - 1, lastKey); - sets[i] = new KeyRange(start, end); - start += step; + long start = keySize * i / hostNum; + long end = keySize * (i + 1) / hostNum; + if (end == start) { + sets[i] = EMPTY; + } else { + sets[i] = new KeyRange(firstKey + start, firstKey + end - 1); + } } return sets; } public KeyCollection[] hashSplit(int hostNum) { - KeyCollection[] sets = new KeyHash[hostNum]; + KeyCollection[] sets = new KeyCollection[hostNum]; for (int i = 0; i < hostNum; i++) { sets[i] = new KeyHash(hostNum, i, firstKey, lastKey); From 723e1fc22057499a18f8e113d2949a5be87103c8 Mon Sep 17 00:00:00 2001 From: tanglizhe1105 Date: Sat, 14 May 2016 20:31:44 +0800 Subject: [PATCH 2/2] changed the index iteration's way of collection variable --- .../intel/distml/clustering/LightLDA.scala | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/src/main/scala/com/intel/distml/clustering/LightLDA.scala b/src/main/scala/com/intel/distml/clustering/LightLDA.scala index ac7fc3e..26dc53b 100644 --- a/src/main/scala/com/intel/distml/clustering/LightLDA.scala +++ b/src/main/scala/com/intel/distml/clustering/LightLDA.scala @@ -39,8 +39,8 @@ object LightLDA { println("dataset size: " + trainDataSize) dm.setTrainSetSize(trainDataSize) - data.mapPartitionsWithIndex(init(m, monitorPath, batchSize)).persist(StorageLevel.MEMORY_AND_DISK).count() - for (iter <- 0 to p.maxIterations - 1) { + data.mapPartitionsWithIndex(init(m, monitorPath, batchSize)).persist(StorageLevel.MEMORY_AND_DISK).count() //persist necessary? + for (iter <- 0 until p.maxIterations) { println("================= iteration: " + iter + " =====================") //data.mapPartitionsWithIndex(verify(p, m, monitorPath)).count() @@ -109,7 +109,7 @@ object LightLDA { var wts = wt.get(wId) if (wts == null) { wts = new Array[Integer](m.K) - for (j <- 0 to m.K - 1) { + for (j <- 0 until m.K) { wts(j) = 0 } wts(topic) = 1 @@ -177,7 +177,7 @@ object LightLDA { val dt = dtm.fetch(KeyCollection.ALL, session) val dt_old = new util.HashMap[Int, Int] - for (i <- 0 to m.K - 1) { + for (i <- 0 until m.K) { dt_old.put(i, dt.get(i)) } @@ -263,22 +263,22 @@ object LightLDA { val wtm = m.getMatrix("word-topics").asInstanceOf[IntMatrixWithIntKey] val dt = dtm.fetch(KeyCollection.ALL, session) - for (i <- 0 to m.K -1 ) { + for (i <- 0 until m.K) { println("dt(" + i + ") = " + dt(i)) } val wt : java.util.HashMap[Integer, Array[Integer]] = wtm.fetch(KeyCollection.ALL, session) - for (w <- 0 to m.V - 1) { - for (i <- 0 to m.K - 1) { + for (w <- 0 until m.V) { + for (i <- 0 until m.K) { println("wt(" + w + ")(" + i + ") = " + wt(w)(i)) } } val t_dt = new Array[Int](m.K) val t_wt = new Array[Array[Int]](m.V) - for (i <- 0 to m.V -1) + for (i <- 0 until m.V) t_wt(i) = new Array[Int](m.K) - while(it.hasNext) { + while (it.hasNext) { val (ndk, words) = it.next() val t_ndk = new Array[Int](m.K) @@ -292,19 +292,19 @@ object LightLDA { t_wt(w)(t) += 1 } - for (i <- 0 to m.K - 1) { + for (i <- 0 until m.K) { if (ndk(i) != t_ndk(i)) throw new IllegalStateException("verify failed, ndk(" + i + "): " + ndk(i) + ", " + t_ndk(i)) } } - for (i <- 0 to m.K -1 ) { + for (i <- 0 until m.K) { if (dt(i) != t_dt(i)) throw new IllegalStateException("verify failed, dt(" + i + "): " + dt(i) + ", " + t_dt(i)) } - for (w <- 0 to m.V - 1) { - for (i <- 0 to m.K - 1) { + for (w <- 0 until m.V) { + for (i <- 0 until m.K) { if (wt(w)(i) != t_wt(w)(i)) throw new IllegalStateException("verify failed, wt(" + w + ")(" + i + "): " + wt(i) + ", " + t_wt(i)) } @@ -351,22 +351,22 @@ object LightLDA { val prob = new Array[Double](m.K) for (iter <- 0 to 50) { - for (i <- 0 to samples.size() -1 ) { + for (i <- samples.indices) { val (ndk, words) = samples(i) for (n <- words.indices) { val w: Int = words(n)._1 val old: Int = words(n)._2 - for (k <- 0 to m.K-1) { + for (k <- 0 until m.K) { val theta: Double = (ndk(k) + m.alpha) / (words.length + m.alpha_sum) val phi: Double = (wt(w)(k) + m.beta) / (dt(k) + m.beta_sum) prob(k) = theta * phi } - for (k <- 1 to m.K - 1) { - prob(k) += prob(k-1) + for (k <- 1 until m.K) { + prob(k) += prob(k - 1) } val u = Math.random() * prob(m.K - 1) @@ -431,7 +431,7 @@ object LightLDA { var l: Double = 0 val w: Int = words(n)._1 - for (k <- 0 to m.K-1) { + for (k <- 0 until m.K) { val theta: Double = (ndk(k) + m.alpha) / (words.length + m.alpha_sum) val phi: Double = (nwk(w)(k) + m.beta) / (nk(k) + m.beta_sum) l += theta * phi @@ -451,8 +451,8 @@ object LightLDA { q_w_proportion_ : Array[(Int, Float)], tables : util.HashMap[Int, AliasTable]): Unit = { var sum : Float = 0.0f - for (k <- 0 to K - 1) { - q_w_proportion_(k) = (k, ((nwk(w)(k) + beta)/nk(k) + beta * V).toFloat) + for (k <- 0 until K) { + q_w_proportion_(k) = (k, ((nwk(w)(k) + beta) / nk(k) + beta * V).toFloat) sum += q_w_proportion_(k)._2 } @@ -475,7 +475,7 @@ object LightLDA { for (i <- nwk.keySet()) { print("w[" + i + "]") val ts = nwk.get(i) - for (j <- 0 to ts.length -1) { + for (j <- ts.indices) { print(" " + nwk(i)(j)) } println @@ -486,11 +486,11 @@ object LightLDA { println("========showing local nwk===========") val nwk = new Array[Array[Int]](m.V) - for (i <- 0 to m.V -1 ) { + for (i <- 0 until m.V) { nwk(i) = new Array[Int](m.K) } - for (i <- 0 to samples.length - 1) { + for (i <- samples.indices) { val doc = samples(i) val words = doc._2 @@ -502,9 +502,9 @@ object LightLDA { } } - for (i <- 0 to m.V -1 ) { + for (i <- 0 until m.V) { print("w[" + i + "]") - for (j <- 0 to m.K -1) { + for (j <- 0 until m.K) { print(" " + nwk(i)(j)) } println @@ -655,7 +655,7 @@ object LightLDA { val nominator = n_td_alpha * n_tw_beta * n_s_beta_sum * proposal_s val denominator = n_sd_alpha * n_sw_beta * n_t_beta_sum * proposal_t - val pi = nominator / denominator; + val pi = nominator / denominator if (rejection < pi) s = t