Show the source volume name on eval slice images.

* examples.py: forward each example's volume name to EvalTracker.add_patch().
* tracker.py: slice_image() and add_patch() take an optional `volume_name`
  that is drawn under the patch coordinates; the remaining hunks only
  reformat existing code.
* tracker_test.py: new smoke tests for rendering with and without a volume
  name.

Review notes (changes relative to the patch as submitted):
* slice_image(): the submitted hunk had a `try:` whose body was only a
  comment (a syntax error, and `font` was unbound on the success path). It
  now loads PIL's default font unconditionally; hunk line counts are
  unchanged.
* add_patch(): the `volume_name` annotation is widened to match
  slice_image(), which it forwards to.
* The post-image blob id on the tracker.py `index` line predates these two
  edits and is therefore stale.
* Line structure and indentation were restored from a whitespace-collapsed
  copy; context-line whitespace should be checked against the tree before
  applying (or apply with --ignore-whitespace).

diff --git a/ffn/training/examples.py b/ffn/training/examples.py
index f10e615..13e6c2a 100644
--- a/ffn/training/examples.py
+++ b/ffn/training/examples.py
@@ -70,7 +70,8 @@ def get_example(load_example, eval_tracker: tracker.EvalTracker,
       assert predicted.base is seed
       yield predicted, patches, labels, weights
 
-    eval_tracker.add_patch(full_labels, seed, loss_weights, coord)
+    eval_tracker.add_patch(full_labels, seed, loss_weights, coord,
+                           volume_name=volname)
 
 
 ExampleGenerator = Iterable[tuple[np.ndarray, np.ndarray, np.ndarray,
diff --git a/ffn/training/tracker.py b/ffn/training/tracker.py
index 2c1377c..480a3c9 100644
--- a/ffn/training/tracker.py
+++ b/ffn/training/tracker.py
@@ -17,16 +17,17 @@
 import collections
 import enum
 import io
-from typing import Optional, Sequence
+from typing import Any, Sequence
 
+from absl import logging
 import numpy as np
-
 import PIL
 import PIL.Image
 import PIL.ImageDraw
+import PIL.ImageFont
 from scipy import special
-
 import tensorflow.compat.v1 as tf
+
 from . import mask
 from . import variables
 
@@ -62,20 +63,26 @@ class FovStat(enum.IntEnum):
 class EvalTracker:
   """Tracks eval results over multiple training steps."""
 
-  def __init__(self,
-               eval_shape: list[int],
-               shifts: Sequence[tuple[int, int, int]]):
+  def __init__(
+      self, eval_shape: list[int], shifts: Sequence[tuple[int, int, int]]
+  ):
     # TODO(mjanusz): Remove this TFv1 code once no longer used.
     if not tf.executing_eagerly():
       self.eval_labels = tf.compat.v1.placeholder(
-          tf.float32, [1] + eval_shape + [1], name='eval_labels')
+          tf.float32, [1] + eval_shape + [1], name='eval_labels'
+      )
       self.eval_preds = tf.compat.v1.placeholder(
-          tf.float32, [1] + eval_shape + [1], name='eval_preds')
+          tf.float32, [1] + eval_shape + [1], name='eval_preds'
+      )
       self.eval_weights = tf.compat.v1.placeholder(
-          tf.float32, [1] + eval_shape + [1], name='eval_weights')
+          tf.float32, [1] + eval_shape + [1], name='eval_weights'
+      )
       self.eval_loss = tf.reduce_mean(
-          self.eval_weights * tf.nn.sigmoid_cross_entropy_with_logits(
-              logits=self.eval_preds, labels=self.eval_labels))
+          self.eval_weights
+          * tf.nn.sigmoid_cross_entropy_with_logits(
+              logits=self.eval_preds, labels=self.eval_labels
+          )
+      )
       self.sess = None
     self.eval_threshold = special.logit(0.9)
     self._eval_shape = eval_shape  # zyx
@@ -138,12 +145,15 @@ def track_weights(self, weights: np.ndarray):
     self.fov_stats.value[FovStat.MASKED_VOXELS] += np.sum(weights == 0.0)
     self.fov_stats.value[FovStat.WEIGHTS_SUM] += np.sum(weights)
 
-  def record_move(self, wanted: bool, executed: bool,
-                  offset_xyz: Sequence[int]):
+  def record_move(
+      self, wanted: bool, executed: bool, offset_xyz: Sequence[int]
+  ):
     """Records an FFN FOV move."""
     r = int(np.linalg.norm(offset_xyz))
-    assert r in self.moves_by_r, ('%d not in %r' %
-                                  (r, list(self.moves_by_r.keys())))
+    assert r in self.moves_by_r, '%d not in %r' % (
+        r,
+        list(self.moves_by_r.keys()),
+    )
 
     if wanted:
       if executed:
@@ -156,9 +166,15 @@ def record_move(self, wanted: bool, executed: bool,
       self.moves.value[MoveType.SPURIOUS] += 1
       self.moves_by_r[r].value[MoveType.SPURIOUS] += 1
 
-  def slice_image(self, coord: np.ndarray, labels: np.ndarray,
-                  predicted: np.ndarray, weights: np.ndarray,
-                  slice_axis: int) -> tf.Summary.Value:
+  def slice_image(
+      self,
+      coord: np.ndarray,
+      labels: np.ndarray,
+      predicted: np.ndarray,
+      weights: np.ndarray,
+      slice_axis: int,
+      volume_name: str | bytes | Sequence[Any] | np.ndarray | None = None,
+  ) -> tf.Summary.Value:
     """Builds a tf.Summary showing a slice of an object mask.
 
     The object mask slice is shown side by side with the corresponding
@@ -172,6 +188,7 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
       slice_axis: axis in the middle of which to place the cutting plane
           for which the summary image will be generated, valid values are
           2 ('x'), 1 ('y'), and 0 ('z').
+      volume_name: name of the volume to be displayed on the image.
 
     Returns:
       tf.Summary.Value object with the image.
@@ -191,14 +208,37 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
 
     im = PIL.Image.fromarray(
         np.repeat(
-            np.concatenate([labels, predicted, weights], axis=1)[...,
-                                                                 np.newaxis],
+            np.concatenate([labels, predicted, weights], axis=1)[
+                ..., np.newaxis
+            ],
             3,
-            axis=2), 'RGB')
+            axis=2,
+        ),
+        'RGB',
+    )
     draw = PIL.ImageDraw.Draw(im)
 
     x, y, z = coord.squeeze()
-    draw.text((1, 1), '%d %d %d' % (x, y, z), fill='rgb(255,64,64)')
+    text = f'{x},{y},{z}'
+    if volume_name is not None:
+      if (
+          isinstance(volume_name, (list, tuple, np.ndarray))
+          and len(volume_name) == 1
+      ):
+        volume_name = volume_name[0]
+
+      if isinstance(volume_name, bytes):
+        volume_name = volume_name.decode('utf-8')
+
+      text += f'\n{volume_name}'
+
+    # Always use PIL's bundled default font: it ships with the library, so
+    # drawing the overlay does not depend on font files installed on the
+    # host. NOTE(review): if a TrueType font is wanted for larger text, load
+    # it here and fall back to load_default() on OSError.
+    font = PIL.ImageFont.load_default()
+
+    draw.text((1, 1), text, fill='rgb(255,64,64)', font=font)
     del draw
 
     im.save(buf, 'PNG')
@@ -212,14 +252,19 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
             height=h,
             width=w * 3,
             colorspace=3,  # RGB
-            encoded_image_string=buf.getvalue()))
-
-  def add_patch(self,
-                labels: np.ndarray,
-                predicted: np.ndarray,
-                weights: np.ndarray,
-                coord: Optional[np.ndarray] = None,
-                image_summaries: bool = True):
+            encoded_image_string=buf.getvalue(),
+        ),
+    )
+
+  def add_patch(
+      self,
+      labels: np.ndarray,
+      predicted: np.ndarray,
+      weights: np.ndarray,
+      coord: np.ndarray | None = None,
+      image_summaries: bool = True,
+      volume_name: str | bytes | Sequence[Any] | np.ndarray | None = None,
+  ):
     """Evaluates single-object segmentation quality."""
 
     predicted = mask.crop_and_pad(predicted, (0, 0, 0), self._eval_shape)
@@ -228,15 +273,21 @@ def add_patch(self,
 
     if not tf.executing_eagerly():
       assert self.sess is not None
-      loss, = self.sess.run(
-          [self.eval_loss], {
+      (loss,) = self.sess.run(
+          [self.eval_loss],
+          {
               self.eval_labels: labels,
               self.eval_preds: predicted,
-              self.eval_weights: weights
-          })
+              self.eval_weights: weights,
+          },
+      )
     else:
-      loss = tf.reduce_mean(weights * tf.nn.sigmoid_cross_entropy_with_logits(
-          logits=predicted, labels=labels))
+      loss = tf.reduce_mean(
+          weights
+          * tf.nn.sigmoid_cross_entropy_with_logits(
+              logits=predicted, labels=labels
+          )
+      )
 
     self.loss.value[:] += loss
     self.num_voxels.value[VoxelType.TOTAL] += labels.size
@@ -247,23 +298,29 @@ def add_patch(self,
     pred_bg = np.logical_not(pred_mask)
     true_bg = np.logical_not(true_mask)
 
-    self.prediction_counts.value[PredictionType.TP] += np.sum(pred_mask
-                                                              & true_mask)
+    self.prediction_counts.value[PredictionType.TP] += np.sum(
+        pred_mask & true_mask
+    )
     self.prediction_counts.value[PredictionType.TN] += np.sum(pred_bg & true_bg)
-    self.prediction_counts.value[PredictionType.FP] += np.sum(pred_mask
-                                                              & true_bg)
+    self.prediction_counts.value[PredictionType.FP] += np.sum(
+        pred_mask & true_bg
+    )
-    self.prediction_counts.value[PredictionType.FN] += np.sum(pred_bg
-                                                              & true_mask)
+    self.prediction_counts.value[PredictionType.FN] += np.sum(
+        pred_bg & true_mask
+    )
     self.num_patches.value[:] += 1
 
     if image_summaries:
       predicted = special.expit(predicted)
       self.images_xy.append(
-          self.slice_image(coord, labels, predicted, weights, 0))
+          self.slice_image(coord, labels, predicted, weights, 0, volume_name)
+      )
       self.images_xz.append(
-          self.slice_image(coord, labels, predicted, weights, 1))
+          self.slice_image(coord, labels, predicted, weights, 1, volume_name)
+      )
       self.images_yz.append(
-          self.slice_image(coord, labels, predicted, weights, 2))
+          self.slice_image(coord, labels, predicted, weights, 2, volume_name)
+      )
 
   def _compute_classification_metrics(self, prediction_counts, prefix):
     """Computes standard classification metrics."""
@@ -276,19 +333,21 @@ def _compute_classification_metrics(self, prediction_counts, prefix):
     recall = tp / max(tp + fn, 1)
 
     if precision > 0 or recall > 0:
-      f1 = (2.0 * precision * recall / (precision + recall))
+      f1 = 2.0 * precision * recall / (precision + recall)
     else:
       f1 = 0.0
 
     return [
         tf.Summary.Value(
             tag='%s/accuracy' % prefix,
-            simple_value=(tp + tn) / max(tp + tn + fp + fn, 1)),
+            simple_value=(tp + tn) / max(tp + tn + fp + fn, 1),
+        ),
         tf.Summary.Value(tag='%s/precision' % prefix, simple_value=precision),
         tf.Summary.Value(tag='%s/recall' % prefix, simple_value=recall),
         tf.Summary.Value(
-            tag='%s/specificity' % prefix, simple_value=tn / max(tn + fp, 1)),
-        tf.Summary.Value(tag='%s/f1' % prefix, simple_value=f1)
+            tag='%s/specificity' % prefix, simple_value=tn / max(tn + fp, 1)
+        ),
+        tf.Summary.Value(tag='%s/f1' % prefix, simple_value=f1),
     ]
 
   def get_summaries(self) -> list[tf.Summary.Value]:
@@ -308,49 +367,74 @@ def get_summaries(self) -> list[tf.Summary.Value]:
       move_summaries.append(
           tf.Summary.Value(
               tag='moves/all/%s' % mt.name.lower(),
-              simple_value=self.moves.tf_value[mt] / total_moves))
-
-    summaries = [
-        tf.Summary.Value(
-            tag='fov/masked_voxel_fraction',
-            simple_value=(self.fov_stats.tf_value[FovStat.MASKED_VOXELS] /
-                          self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])),
-        tf.Summary.Value(
-            tag='fov/average_weight',
-            simple_value=(self.fov_stats.tf_value[FovStat.WEIGHTS_SUM] /
-                          self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])),
-        tf.Summary.Value(
-            tag='masked_voxel_fraction',
-            simple_value=(self.num_voxels.tf_value[VoxelType.MASKED] /
-                          self.num_voxels.tf_value[VoxelType.TOTAL])),
-        tf.Summary.Value(
-            tag='eval/patch_loss',
-            simple_value=self.loss.tf_value[0] / self.num_patches.tf_value[0]),
-        tf.Summary.Value(
-            tag='eval/patches', simple_value=self.num_patches.tf_value[0]),
-        tf.Summary.Value(tag='moves/total', simple_value=total_moves)
-    ] + move_summaries + (
-        list(self.meshes) + list(self.images_xy) + list(self.images_xz) +
-        list(self.images_yz))
+              simple_value=self.moves.tf_value[mt] / total_moves,
+          )
+      )
+
+    summaries = (
+        [
+            tf.Summary.Value(
+                tag='fov/masked_voxel_fraction',
+                simple_value=(
+                    self.fov_stats.tf_value[FovStat.MASKED_VOXELS]
+                    / self.fov_stats.tf_value[FovStat.TOTAL_VOXELS]
+                ),
+            ),
+            tf.Summary.Value(
+                tag='fov/average_weight',
+                simple_value=(
+                    self.fov_stats.tf_value[FovStat.WEIGHTS_SUM]
+                    / self.fov_stats.tf_value[FovStat.TOTAL_VOXELS]
+                ),
+            ),
+            tf.Summary.Value(
+                tag='masked_voxel_fraction',
+                simple_value=(
+                    self.num_voxels.tf_value[VoxelType.MASKED]
+                    / self.num_voxels.tf_value[VoxelType.TOTAL]
+                ),
+            ),
+            tf.Summary.Value(
+                tag='eval/patch_loss',
+                simple_value=self.loss.tf_value[0]
+                / self.num_patches.tf_value[0],
+            ),
+            tf.Summary.Value(
+                tag='eval/patches', simple_value=self.num_patches.tf_value[0]
+            ),
+            tf.Summary.Value(tag='moves/total', simple_value=total_moves),
+        ]
+        + move_summaries
+        + (
+            list(self.meshes)
+            + list(self.images_xy)
+            + list(self.images_xz)
+            + list(self.images_yz)
+        )
+    )
 
     summaries.extend(
-        self._compute_classification_metrics(self.prediction_counts,
-                                             'eval/all'))
+        self._compute_classification_metrics(self.prediction_counts, 'eval/all')
+    )
 
     for r, r_moves in self.moves_by_r.items():
       total_moves = sum(r_moves.tf_value)
       summaries.extend([
           tf.Summary.Value(
               tag='moves/r=%d/correct' % r,
-              simple_value=r_moves.tf_value[MoveType.CORRECT] / total_moves),
+              simple_value=r_moves.tf_value[MoveType.CORRECT] / total_moves,
+          ),
           tf.Summary.Value(
               tag='moves/r=%d/spurious' % r,
-              simple_value=r_moves.tf_value[MoveType.SPURIOUS] / total_moves),
+              simple_value=r_moves.tf_value[MoveType.SPURIOUS] / total_moves,
+          ),
           tf.Summary.Value(
               tag='moves/r=%d/missed' % r,
-              simple_value=r_moves.tf_value[MoveType.MISSED] / total_moves),
+              simple_value=r_moves.tf_value[MoveType.MISSED] / total_moves,
+          ),
           tf.Summary.Value(
-              tag='moves/r=%d/total' % r, simple_value=total_moves)
+              tag='moves/r=%d/total' % r, simple_value=total_moves
+          ),
       ])
 
     return summaries
diff --git a/ffn/training/tracker_test.py b/ffn/training/tracker_test.py
new file mode 100644
index 0000000..a279e3d
--- /dev/null
+++ b/ffn/training/tracker_test.py
@@ -0,0 +1,88 @@
+# Copyright 2024 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from absl.testing import absltest
+from ffn.training import tracker
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+tf.disable_eager_execution()
+
+
+class TrackerTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    tf.reset_default_graph()
+
+  def test_tracker_rendering(self):
+    eval_shape = [32, 32, 32]
+    shifts = [(0, 0, 0)]
+    eval_tracker = tracker.EvalTracker(eval_shape, shifts)
+    eval_tracker.sess = tf.Session()
+    eval_tracker.sess.run(tf.global_variables_initializer())
+
+    labels = np.zeros([1] + eval_shape + [1], dtype=np.float32)
+    predicted = np.zeros([1] + eval_shape + [1], dtype=np.float32)
+    weights = np.zeros([1] + eval_shape + [1], dtype=np.float32)
+
+    # Check rendering with volume name
+    eval_tracker.add_patch(
+        labels,
+        predicted,
+        weights,
+        coord=np.array([0, 0, 0]),
+        volume_name='test_volume',
+    )
+
+    eval_tracker.to_tf()
+    summaries = eval_tracker.get_summaries()
+
+    # Verify that we got image summaries
+    image_summaries = [s for s in summaries if s.HasField('image')]
+    self.assertNotEmpty(image_summaries)
+
+    # Check specifically for the tags we expect
+    tags = [s.tag for s in image_summaries]
+    self.assertIn('final_xy/0', tags)
+    self.assertIn('final_xz/0', tags)
+    self.assertIn('final_yz/0', tags)
+
+  def test_tracker_rendering_no_volume_name(self):
+    eval_shape = [32, 32, 32]
+    shifts = [(0, 0, 0)]
+    eval_tracker = tracker.EvalTracker(eval_shape, shifts)
+    eval_tracker.sess = tf.Session()
+    eval_tracker.sess.run(tf.global_variables_initializer())
+
+    labels = np.zeros([1] + eval_shape + [1], dtype=np.float32)
+    predicted = np.zeros([1] + eval_shape + [1], dtype=np.float32)
+    weights = np.zeros([1] + eval_shape + [1], dtype=np.float32)
+
+    # Check rendering without volume name
+    eval_tracker.add_patch(
+        labels, predicted, weights, coord=np.array([0, 0, 0])
+    )
+
+    eval_tracker.to_tf()
+    summaries = eval_tracker.get_summaries()
+
+    # Verify that we got image summaries
+    image_summaries = [s for s in summaries if s.HasField('image')]
+    self.assertNotEmpty(image_summaries)
+
+
+if __name__ == '__main__':
+  absltest.main()