Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/main/java/com/intel/distml/util/KeyHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ public class KeyHash extends KeyCollection {
public KeyHash(int hashQuato, int hashIndex, long minKey, long maxKey) {
super(KeyCollection.TYPE_HASH);

if (minKey > maxKey) {
throw new IllegalStateException("unexpected key range: (" + minKey + ", " + maxKey + ")");
}

this.hashQuato = hashQuato;
this.hashIndex = hashIndex;
this.minKey = minKey;
Expand Down
22 changes: 15 additions & 7 deletions src/main/java/com/intel/distml/util/KeyRange.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ public class KeyRange extends KeyCollection {

public KeyRange(long f, long l) {
super(KeyCollection.TYPE_RANGE);

if (f > l) {
throw new IllegalStateException("unexpected key range: (" + f + ", " + l + ")");
}

firstKey = f;
lastKey = l;
}
Expand Down Expand Up @@ -66,21 +71,24 @@ public void read(AbstractDataReader in, DataDesc format) throws Exception {
}

public KeyCollection[] linearSplit(int hostNum) {
KeyCollection[] sets = new KeyRange[hostNum];
KeyCollection[] sets = new KeyCollection[hostNum];

long start = firstKey;
long step = (lastKey - firstKey + hostNum) / hostNum;
long keySize = size();
for (int i = 0; i < hostNum; i++) {
long end = Math.min(start + step - 1, lastKey);
sets[i] = new KeyRange(start, end);
start += step;
long start = keySize * i / hostNum;
long end = keySize * (i + 1) / hostNum;
if (end == start) {
sets[i] = EMPTY;
} else {
sets[i] = new KeyRange(firstKey + start, firstKey + end - 1);
}
}

return sets;
}

public KeyCollection[] hashSplit(int hostNum) {
KeyCollection[] sets = new KeyHash[hostNum];
KeyCollection[] sets = new KeyCollection[hostNum];

for (int i = 0; i < hostNum; i++) {
sets[i] = new KeyHash(hostNum, i, firstKey, lastKey);
Expand Down
52 changes: 26 additions & 26 deletions src/main/scala/com/intel/distml/clustering/LightLDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ object LightLDA {
println("dataset size: " + trainDataSize)
dm.setTrainSetSize(trainDataSize)

data.mapPartitionsWithIndex(init(m, monitorPath, batchSize)).persist(StorageLevel.MEMORY_AND_DISK).count()
for (iter <- 0 to p.maxIterations - 1) {
data.mapPartitionsWithIndex(init(m, monitorPath, batchSize)).persist(StorageLevel.MEMORY_AND_DISK).count() //persist necessary?
for (iter <- 0 until p.maxIterations) {
println("================= iteration: " + iter + " =====================")
//data.mapPartitionsWithIndex(verify(p, m, monitorPath)).count()

Expand Down Expand Up @@ -109,7 +109,7 @@ object LightLDA {
var wts = wt.get(wId)
if (wts == null) {
wts = new Array[Integer](m.K)
for (j <- 0 to m.K - 1) {
for (j <- 0 until m.K) {
wts(j) = 0
}
wts(topic) = 1
Expand Down Expand Up @@ -177,7 +177,7 @@ object LightLDA {

val dt = dtm.fetch(KeyCollection.ALL, session)
val dt_old = new util.HashMap[Int, Int]
for (i <- 0 to m.K - 1) {
for (i <- 0 until m.K) {
dt_old.put(i, dt.get(i))
}

Expand Down Expand Up @@ -263,22 +263,22 @@ object LightLDA {
val wtm = m.getMatrix("word-topics").asInstanceOf[IntMatrixWithIntKey]

val dt = dtm.fetch(KeyCollection.ALL, session)
for (i <- 0 to m.K -1 ) {
for (i <- 0 until m.K) {
println("dt(" + i + ") = " + dt(i))
}
val wt : java.util.HashMap[Integer, Array[Integer]] = wtm.fetch(KeyCollection.ALL, session)
for (w <- 0 to m.V - 1) {
for (i <- 0 to m.K - 1) {
for (w <- 0 until m.V) {
for (i <- 0 until m.K) {
println("wt(" + w + ")(" + i + ") = " + wt(w)(i))
}
}

val t_dt = new Array[Int](m.K)
val t_wt = new Array[Array[Int]](m.V)
for (i <- 0 to m.V -1)
for (i <- 0 until m.V)
t_wt(i) = new Array[Int](m.K)

while(it.hasNext) {
while (it.hasNext) {
val (ndk, words) = it.next()

val t_ndk = new Array[Int](m.K)
Expand All @@ -292,19 +292,19 @@ object LightLDA {
t_wt(w)(t) += 1
}

for (i <- 0 to m.K - 1) {
for (i <- 0 until m.K) {
if (ndk(i) != t_ndk(i))
throw new IllegalStateException("verify failed, ndk(" + i + "): " + ndk(i) + ", " + t_ndk(i))
}
}

for (i <- 0 to m.K -1 ) {
for (i <- 0 until m.K) {
if (dt(i) != t_dt(i))
throw new IllegalStateException("verify failed, dt(" + i + "): " + dt(i) + ", " + t_dt(i))
}

for (w <- 0 to m.V - 1) {
for (i <- 0 to m.K - 1) {
for (w <- 0 until m.V) {
for (i <- 0 until m.K) {
if (wt(w)(i) != t_wt(w)(i))
throw new IllegalStateException("verify failed, wt(" + w + ")(" + i + "): " + wt(i) + ", " + t_wt(i))
}
Expand Down Expand Up @@ -351,22 +351,22 @@ object LightLDA {

val prob = new Array[Double](m.K)
for (iter <- 0 to 50) {
for (i <- 0 to samples.size() -1 ) {
for (i <- samples.indices) {
val (ndk, words) = samples(i)

for (n <- words.indices) {
val w: Int = words(n)._1
val old: Int = words(n)._2

for (k <- 0 to m.K-1) {
for (k <- 0 until m.K) {
val theta: Double = (ndk(k) + m.alpha) / (words.length + m.alpha_sum)
val phi: Double = (wt(w)(k) + m.beta) / (dt(k) + m.beta_sum)

prob(k) = theta * phi
}

for (k <- 1 to m.K - 1) {
prob(k) += prob(k-1)
for (k <- 1 until m.K) {
prob(k) += prob(k - 1)
}

val u = Math.random() * prob(m.K - 1)
Expand Down Expand Up @@ -431,7 +431,7 @@ object LightLDA {
var l: Double = 0
val w: Int = words(n)._1

for (k <- 0 to m.K-1) {
for (k <- 0 until m.K) {
val theta: Double = (ndk(k) + m.alpha) / (words.length + m.alpha_sum)
val phi: Double = (nwk(w)(k) + m.beta) / (nk(k) + m.beta_sum)
l += theta * phi
Expand All @@ -451,8 +451,8 @@ object LightLDA {
q_w_proportion_ : Array[(Int, Float)], tables : util.HashMap[Int, AliasTable]): Unit = {

var sum : Float = 0.0f
for (k <- 0 to K - 1) {
q_w_proportion_(k) = (k, ((nwk(w)(k) + beta)/nk(k) + beta * V).toFloat)
for (k <- 0 until K) {
q_w_proportion_(k) = (k, ((nwk(w)(k) + beta) / nk(k) + beta * V).toFloat)
sum += q_w_proportion_(k)._2
}

Expand All @@ -475,7 +475,7 @@ object LightLDA {
for (i <- nwk.keySet()) {
print("w[" + i + "]")
val ts = nwk.get(i)
for (j <- 0 to ts.length -1) {
for (j <- ts.indices) {
print(" " + nwk(i)(j))
}
println
Expand All @@ -486,11 +486,11 @@ object LightLDA {
println("========showing local nwk===========")

val nwk = new Array[Array[Int]](m.V)
for (i <- 0 to m.V -1 ) {
for (i <- 0 until m.V) {
nwk(i) = new Array[Int](m.K)
}

for (i <- 0 to samples.length - 1) {
for (i <- samples.indices) {
val doc = samples(i)

val words = doc._2
Expand All @@ -502,9 +502,9 @@ object LightLDA {
}
}

for (i <- 0 to m.V -1 ) {
for (i <- 0 until m.V) {
print("w[" + i + "]")
for (j <- 0 to m.K -1) {
for (j <- 0 until m.K) {
print(" " + nwk(i)(j))
}
println
Expand Down Expand Up @@ -655,7 +655,7 @@ object LightLDA {
val nominator = n_td_alpha * n_tw_beta * n_s_beta_sum * proposal_s
val denominator = n_sd_alpha * n_sw_beta * n_t_beta_sum * proposal_t

val pi = nominator / denominator;
val pi = nominator / denominator

if (rejection < pi)
s = t
Expand Down