diff --git a/tensorflow_recommenders/layers/factorized_top_k.py b/tensorflow_recommenders/layers/factorized_top_k.py index b6a7f451..0316c3cf 100644 --- a/tensorflow_recommenders/layers/factorized_top_k.py +++ b/tensorflow_recommenders/layers/factorized_top_k.py @@ -466,7 +466,7 @@ def top_k(state: Tuple[tf.Tensor, tf.Tensor], def enumerate_rows(batch: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: """Enumerates rows in each batch using a total element counter.""" - starting_counter = self._counter.read_value() + starting_counter = self._counter.value end_counter = self._counter.assign_add(tf.shape(batch)[0]) return tf.range(starting_counter, end_counter), batch diff --git a/tensorflow_recommenders/metrics/factorized_top_k.py b/tensorflow_recommenders/metrics/factorized_top_k.py index b9bfd7bd..8d025a9e 100644 --- a/tensorflow_recommenders/metrics/factorized_top_k.py +++ b/tensorflow_recommenders/metrics/factorized_top_k.py @@ -177,7 +177,8 @@ def update_state( tf.reduce_sum(ids_match[:, :k], axis=1, keepdims=True), 0.0, 1.0 ) - update_ops.append(metric.update_state(match_found, sample_weight)) + metric.update_state(match_found, sample_weight) + update_ops.append(metric.result()) else: # Score-based evaluation. y_pred = tf.concat([positive_scores, top_k_predictions], axis=1) @@ -189,7 +190,8 @@ def update_state( predictions=y_pred, k=k ) - update_ops.append(metric.update_state(top_k_accuracy, sample_weight)) + metric.update_state(top_k_accuracy, sample_weight) + update_ops.append(metric.result()) return tf.group(update_ops) \ No newline at end of file