diff --git a/README.md b/README.md index c24668c1..b6221b84 100644 --- a/README.md +++ b/README.md @@ -1,98 +1,82 @@ - +


-# MeshCNN in PyTorch +# MedMeshCNN -### SIGGRAPH 2019 [[Paper]](https://bit.ly/meshcnn) [[Project Page]](https://ranahanocka.github.io/MeshCNN/)
+MedMeshCNN is an extension of [MeshCNN](https://ranahanocka.github.io/MeshCNN/), which was proposed by [Rana Hanocka](https://www.cs.tau.ac.il/~hanocka/) et al. -MeshCNN is a general-purpose deep neural network for 3D triangular meshes, which can be used for tasks such as 3D shape classification or segmentation. This framework includes convolution, pooling and unpooling layers which are applied directly on the mesh edges. -
+MedMeshCNN enables the use of [MeshCNN](https://ranahanocka.github.io/MeshCNN/) for medical surface meshes through an improved memory efficiency that allows it +to keep patient-specific properties and fine-grained patterns during segmentation. Furthermore, a weighted loss function improves the performance of MedMeshCNN on imbalanced datasets, which are often caused by pathological appearances. + +MedMeshCNN may also be used beyond the medical domain for all applications that include imbalanced datasets and require fine-grained segmentation results. + +Improvements of MedMeshCNN include: +* Processing meshes with 170,000 edges (NVIDIA GeForce GTX 1080 Ti GPU with 12 GB RAM) +* IoU metrics +* Weighted loss function to enable better performance on imbalanced class distributions + +Please check out the corresponding [PartSegmentationToolbox](https://github.com/LSnyd/PartSegmentationToolbox) to find further information on how to create a segmentation ground truth and helper scripts to scale segmentation results to different mesh resolutions. -The code was written by [Rana Hanocka](https://www.cs.tau.ac.il/~hanocka/) and [Amir Hertz](http://pxcm.org/) with support from [Noa Fish](http://www.cs.tau.ac.il/~noafish/). # Getting Started + ### Installation - Clone this repo: ```bash -git clone https://github.com/ranahanocka/MeshCNN.git -cd MeshCNN +git clone https://github.com/LSnyd/MedMeshCNN.git +cd MedMeshCNN ``` -- Install dependencies: [PyTorch](https://pytorch.org/) version 1.2. Optional : [tensorboardX](https://github.com/lanpa/tensorboardX) for training plots. +- Install dependencies: [PyTorch](https://pytorch.org/) version 1.4. Optional: [tensorboardX](https://github.com/lanpa/tensorboardX) for training plots. 
- Via new conda environment `conda env create -f environment.yml` (creates an environment called meshcnn) - -### 3D Shape Classification on SHREC + + +### 3D Shape Segmentation on Humans Download the dataset ```bash -bash ./scripts/shrec/get_data.sh +bash ./scripts/human_seg/get_data.sh ``` Run training (if using conda env first activate env e.g. ```source activate meshcnn```) ```bash -bash ./scripts/shrec/train.sh +bash ./scripts/human_seg/train.sh ``` To view the training loss plots, in another terminal run ```tensorboard --logdir runs``` and click [http://localhost:6006](http://localhost:6006). Run test and export the intermediate pooled meshes: ```bash -bash ./scripts/shrec/test.sh +bash ./scripts/human_seg/test.sh ``` Visualize the network-learned edge collapses: ```bash -bash ./scripts/shrec/view.sh +bash ./scripts/human_seg/view.sh ``` -An example of collapses for a mesh: +Some segmentation result examples: - + -Note, you can also get pre-trained weights using bash ```./scripts/shrec/get_pretrained.sh```. +### Hyperparameters -In order to use the pre-trained weights, run ```train.sh``` which will compute and save the mean / standard deviation of the training data. +To alter the values of the hyperparameters, change the bash scripts above accordingly. +This also includes the weight vector for the weighted loss function, which requires one weight per class. -### 3D Shape Segmentation on Humans The same as above, to download the dataset / run train / get pretrained / run test / view -```bash -bash ./scripts/human_seg/get_data.sh -bash ./scripts/human_seg/train.sh -bash ./scripts/human_seg/get_pretrained.sh -bash ./scripts/human_seg/test.sh -bash ./scripts/human_seg/view.sh -``` -Some segmentation result examples: +### More Info - -### Additional Datasets -The same scripts also exist for COSEG segmentation in ```scripts/coseg_seg``` and cubes classification in ```scripts/cubes```. 
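The weight vector has one entry per class and is passed directly to PyTorch's cross-entropy loss. As a sketch of one common way to choose these weights (inverse class frequency with normalization; the function name and the toy label counts below are illustrative, not part of this repository):

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes):
    """One common heuristic for a per-class weight vector:
    weight each class by the inverse of its frequency in the
    training labels, then normalize the weights to sum to 1."""
    counts = Counter(labels)
    raw = [1.0 / counts.get(c, 1) for c in range(num_classes)]
    total = sum(raw)
    return [w / total for w in raw]

# Toy edge-label distribution: class 0 dominates, class 2 is rare
# (e.g. a small pathological structure).
labels = [0] * 80 + [1] * 15 + [2] * 5
print(inverse_frequency_weights(labels, 3))  # rarest class gets the largest weight
```

The resulting values can then be passed as the `--weighted_loss` argument in the training script, one weight per class.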
+Check out the corresponding [PartSegmentationToolbox](https://github.com/LSnyd/PartSegmentationToolbox) and my [Medium article](https://medium.com/@lisa_81193/how-to-perform-a-3d-segmentation-in-blender-2-82-d87300305f3f) to find further information on how to create a segmentation ground truth as illustrated below. You can also find helper scripts that scale segmentation results to different mesh resolutions. -# More Info -Check out the [MeshCNN wiki](https://github.com/ranahanocka/MeshCNN/wiki) for more details. Specifically, see info on [segmentation](https://github.com/ranahanocka/MeshCNN/wiki/Segmentation) and [data processing](https://github.com/ranahanocka/MeshCNN/wiki/Data-Processing). + -# Citation -If you find this code useful, please consider citing our paper -``` -@article{hanocka2019meshcnn, - title={MeshCNN: A Network with an Edge}, - author={Hanocka, Rana and Hertz, Amir and Fish, Noa and Giryes, Raja and Fleishman, Shachar and Cohen-Or, Daniel}, - journal={ACM Transactions on Graphics (TOG)}, - volume={38}, - number={4}, - pages = {90:1--90:12}, - year={2019}, - publisher={ACM} -} -``` +Also check out the [MeshCNN wiki](https://github.com/ranahanocka/MeshCNN/wiki) for more details. Specifically, see info on [segmentation](https://github.com/ranahanocka/MeshCNN/wiki/Segmentation) and [data processing](https://github.com/ranahanocka/MeshCNN/wiki/Data-Processing). # Questions / Issues -If you have questions or issues running this code, please open an issue so we can know to fix it. - -# Acknowledgments -This code design was adopted from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). +If you have questions or issues running this code, please open an issue. 
\ No newline at end of file diff --git a/docs/imgs/C060_seg_fine.png b/docs/imgs/C060_seg_fine.png new file mode 100644 index 00000000..2b8c25bf Binary files /dev/null and b/docs/imgs/C060_seg_fine.png differ diff --git a/docs/imgs/C60.png b/docs/imgs/C60.png new file mode 100644 index 00000000..dec257fd Binary files /dev/null and b/docs/imgs/C60.png differ diff --git a/docs/imgs/T18.png b/docs/imgs/T18.png deleted file mode 100644 index 1f29c8bd..00000000 Binary files a/docs/imgs/T18.png and /dev/null differ diff --git a/docs/imgs/T252.png b/docs/imgs/T252.png deleted file mode 100644 index b42727c1..00000000 Binary files a/docs/imgs/T252.png and /dev/null differ diff --git a/docs/imgs/T76.png b/docs/imgs/T76.png deleted file mode 100644 index a6c3ae19..00000000 Binary files a/docs/imgs/T76.png and /dev/null differ diff --git a/docs/imgs/alien.gif b/docs/imgs/alien.gif deleted file mode 100644 index 4e04c641..00000000 Binary files a/docs/imgs/alien.gif and /dev/null differ diff --git a/docs/imgs/coseg_alien.png b/docs/imgs/coseg_alien.png deleted file mode 100644 index a8ab24a3..00000000 Binary files a/docs/imgs/coseg_alien.png and /dev/null differ diff --git a/docs/imgs/coseg_chair.png b/docs/imgs/coseg_chair.png deleted file mode 100644 index 1560b131..00000000 Binary files a/docs/imgs/coseg_chair.png and /dev/null differ diff --git a/docs/imgs/coseg_vase.png b/docs/imgs/coseg_vase.png deleted file mode 100644 index 64eca95c..00000000 Binary files a/docs/imgs/coseg_vase.png and /dev/null differ diff --git a/docs/imgs/cubes.png b/docs/imgs/cubes.png deleted file mode 100644 index 7b72468f..00000000 Binary files a/docs/imgs/cubes.png and /dev/null differ diff --git a/docs/imgs/cubes2.png b/docs/imgs/cubes2.png deleted file mode 100644 index e323976e..00000000 Binary files a/docs/imgs/cubes2.png and /dev/null differ diff --git a/docs/imgs/input_edge_features.png b/docs/imgs/input_edge_features.png deleted file mode 100644 index 649e58f4..00000000 Binary files 
a/docs/imgs/input_edge_features.png and /dev/null differ diff --git a/docs/imgs/mesh_conv.png b/docs/imgs/mesh_conv.png deleted file mode 100644 index 37c6201d..00000000 Binary files a/docs/imgs/mesh_conv.png and /dev/null differ diff --git a/docs/imgs/mesh_pool_unpool.png b/docs/imgs/mesh_pool_unpool.png deleted file mode 100644 index e9c136d6..00000000 Binary files a/docs/imgs/mesh_pool_unpool.png and /dev/null differ diff --git a/environment.yml b/environment.yml index 50f52cf5..b33426a3 100644 --- a/environment.yml +++ b/environment.yml @@ -3,12 +3,12 @@ channels: - pytorch - defaults dependencies: - - python=3.6.8 + - python=3.8.5 - cython=0.27.3 - - pytorch=1.2.0 - - numpy=1.15.0 - - matplotlib=3.0.3 - - pip + - pytorch=1.4.0 + - numpy=1.19.1 + - matplotlib=3.3.1 + - pip - pip: - git+https://github.com/lanpa/tensorboardX.git - pytest==5.1.1 diff --git a/models/layers/mesh_union.py b/models/layers/mesh_union.py index 325f29bf..106cb7cf 100644 --- a/models/layers/mesh_union.py +++ b/models/layers/mesh_union.py @@ -1,15 +1,33 @@ import torch from torch.nn import ConstantPad2d +import time +from util.util import myindexrowselect +from options.base_options import BaseOptions class MeshUnion: - def __init__(self, n, device=torch.device('cpu')): + def __init__(self, n, device=torch.device('cpu')): + gpu_ids = BaseOptions().get_device() + self.device = torch.device('cuda:{}'.format(gpu_ids[0])) if len(gpu_ids)>0 else torch.device('cpu') + self.__size = n self.rebuild_features = self.rebuild_features_average - self.groups = torch.eye(n, device=device) + self.values = torch.ones(n, dtype= torch.float) + self.groups = torch.sparse_coo_tensor(indices= torch.stack((torch.arange(n), torch.arange(n)),dim=0), values= self.values, + + size=(self.__size, self.__size), device=self.device) + def union(self, source, target): - self.groups[target, :] += self.groups[source, :] + index = torch.tensor([source], dtype=torch.long) + row = myindexrowselect(self.groups, index, 
self.device).to(self.device) + row._indices()[0] = torch.tensor(target) + row = torch.sparse_coo_tensor(indices=row._indices(), values= row._values(), + size=(self.__size, self.__size), device=self.device) + self.groups = self.groups.add(row) + self.groups = self.groups.coalesce() + del index, row + def remove_group(self, index): return @@ -18,16 +36,21 @@ def get_group(self, edge_key): return self.groups[edge_key, :] def get_occurrences(self): - return torch.sum(self.groups, 0) + return torch.sparse.sum(self.groups, 0).values() + def get_groups(self, tensor_mask): - self.groups = torch.clamp(self.groups, 0, 1) - return self.groups[tensor_mask, :] + ## Max comp + mask_index = torch.squeeze((tensor_mask == True).nonzero()).to(self.device) + return myindexrowselect(self.groups, mask_index, self.device) + def rebuild_features_average(self, features, mask, target_edges): self.prepare_groups(features, mask) - fe = torch.matmul(features.squeeze(-1), self.groups) - occurrences = torch.sum(self.groups, 0).expand(fe.shape) + + self.groups = self.groups.to(self.device) + fe = torch.matmul(self.groups.transpose(0,1),features.squeeze(-1).transpose(1,0)).transpose(0,1) + occurrences = torch.sparse.sum(self.groups, 0).to_dense() fe = fe / occurrences padding_b = target_edges - fe.shape[1] if padding_b > 0: @@ -35,10 +58,16 @@ def rebuild_features_average(self, features, mask, target_edges): fe = padding_b(fe) return fe + def prepare_groups(self, features, mask): - tensor_mask = torch.from_numpy(mask) - self.groups = torch.clamp(self.groups[tensor_mask, :], 0, 1).transpose_(1, 0) + mask_index = torch.squeeze((torch.from_numpy(mask) == True).nonzero()) + + self.groups = myindexrowselect(self.groups, mask_index, self.device).transpose(1,0) padding_a = features.shape[1] - self.groups.shape[0] + if padding_a > 0: - padding_a = ConstantPad2d((0, 0, 0, padding_a), 0) - self.groups = padding_a(self.groups) + self.groups = torch.sparse_coo_tensor( + indices=self.groups._indices(), 
values=self.groups._values(), dtype=torch.float32, + size=(features.shape[1], self.groups.shape[1])) + + diff --git a/models/layers/mesh_unpool.py b/models/layers/mesh_unpool.py index 63a13e94..3250fe1b 100644 --- a/models/layers/mesh_unpool.py +++ b/models/layers/mesh_unpool.py @@ -1,12 +1,13 @@ import torch import torch.nn as nn - - +from options.base_options import BaseOptions class MeshUnpool(nn.Module): def __init__(self, unroll_target): super(MeshUnpool, self).__init__() self.unroll_target = unroll_target + gpu_ids = BaseOptions().get_device() + self.device = torch.device('cuda:{}'.format(gpu_ids[0])) if len(gpu_ids)>0 else torch.device('cpu') def __call__(self, features, meshes): return self.forward(features, meshes) @@ -16,8 +17,11 @@ def pad_groups(self, group, unroll_start): padding_rows = unroll_start - start padding_cols = self.unroll_target - end if padding_rows != 0 or padding_cols !=0: - padding = nn.ConstantPad2d((0, padding_cols, 0, padding_rows), 0) - group = padding(group) + size1 = group.shape[0] + padding_rows + size2 = group.shape[1] + padding_cols + group = torch.sparse_coo_tensor( + indices=group._indices(), values=group._values(), dtype=torch.float32, + size=(size1, size2)) return group def pad_occurrences(self, occurrences): @@ -29,13 +33,44 @@ def pad_occurrences(self, occurrences): def forward(self, features, meshes): batch_size, nf, edges = features.shape - groups = [self.pad_groups(mesh.get_groups(), edges) for mesh in meshes] - unroll_mat = torch.cat(groups, dim=0).view(batch_size, edges, -1) + groups = [self.pad_groups(mesh.get_groups(), edges).to(self.device) for mesh in meshes] + unroll_mat = torch.stack(groups) occurrences = [self.pad_occurrences(mesh.get_occurrences()) for mesh in meshes] - occurrences = torch.cat(occurrences, dim=0).view(batch_size, 1, -1) - occurrences = occurrences.expand(unroll_mat.shape) - unroll_mat = unroll_mat / occurrences - unroll_mat = unroll_mat.to(features.device) + occurrences = 
torch.unsqueeze(torch.stack(occurrences, dim=0), dim=1) + + #Sparse division only possible for scalars + #Iterate over dense batches + + imin = 0 + length = 500 + result = [] + while imin <= unroll_mat.size()[2]: + try: + sliceUnroll_mat = unroll_mat.narrow_copy(2, imin, length).to_dense().to(self.device) + sliceOcc = occurrences.narrow_copy(2, imin, length).to(self.device) + sliceResult = (sliceUnroll_mat / sliceOcc).to_sparse() + imin = imin + 500 + + result.append(sliceResult) + + except Exception: + length = unroll_mat.size()[2] - imin + + unroll_mat = torch.cat(result, -1).to(features.device) + for mesh in meshes: mesh.unroll_gemm() - return torch.matmul(features, unroll_mat) + + #Fix Matmul, due to missing strides of sparse representation + result = [] + unroll_mat = unroll_mat.transpose(1,2) + features = features.transpose(1,2) + + #iterate over batches + for batch in range(batch_size): + mat = torch.matmul(unroll_mat[batch], features[batch]) + mat = torch.unsqueeze(mat, dim=0) + result.append(mat) + return torch.cat(result, dim=0).transpose(1,2) + + diff --git a/models/mesh_classifier.py b/models/mesh_classifier.py index 9ce50cb3..a72e694e 100644 --- a/models/mesh_classifier.py +++ b/models/mesh_classifier.py @@ -1,7 +1,7 @@ import torch from . 
import networks from os.path import join -from util.util import seg_accuracy, print_network +from util.util import seg_accuracy, print_network, mean_iou_calc class ClassifierModel: @@ -113,7 +113,8 @@ def test(self): label_class = self.labels self.export_segmentation(pred_class.cpu()) correct = self.get_accuracy(pred_class, label_class) - return correct, len(label_class) + mean_iou, iou =self.get_iou(pred_class, label_class) + return correct, len(label_class), mean_iou, iou def get_accuracy(self, pred, labels): """computes accuracy for classification / segmentation """ @@ -123,6 +124,12 @@ def get_accuracy(self, pred, labels): correct = seg_accuracy(pred, self.soft_label, self.mesh) return correct + def get_iou(self, pred, labels): + """computes IoU for segmentation """ + if self.opt.dataset_mode == 'segmentation': + mean_iou, iou = mean_iou_calc(pred, self.labels, self.nclasses) + return mean_iou, iou + def export_segmentation(self, pred_seg): if self.opt.dataset_mode == 'segmentation': for meshi, mesh in enumerate(self.mesh): diff --git a/models/networks.py b/models/networks.py index c2a13e2e..6b266c72 100644 --- a/models/networks.py +++ b/models/networks.py @@ -114,7 +114,8 @@ def define_loss(opt): if opt.dataset_mode == 'classification': loss = torch.nn.CrossEntropyLoss() elif opt.dataset_mode == 'segmentation': - loss = torch.nn.CrossEntropyLoss(ignore_index=-1) + weights = torch.FloatTensor(opt.weighted_loss) + loss = torch.nn.CrossEntropyLoss(weights, ignore_index=-1) return loss ############################################################################## diff --git a/options/base_options.py b/options/base_options.py index 61f21ce0..ea1febd9 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -2,7 +2,7 @@ import os from util import util import torch - +import sys class BaseOptions: def __init__(self): @@ -35,9 +35,26 @@ def initialize(self): self.parser.add_argument('--seed', type=int, help='if specified, uses seed') # visualization 
params self.parser.add_argument('--export_folder', type=str, default='', help='exports intermediate collapses to this folder') + # loss + self.parser.add_argument('--weighted_loss', nargs='+', default=[0.25, 0.25, 0.25, 0.25], type=float, help='Weights for loss') # self.initialized = True + def get_device(self): + if not self.initialized: + self.initialize() + self.opt, unknown = self.parser.parse_known_args() + + str_ids = self.opt.gpu_ids.split(',') + self.opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + self.opt.gpu_ids.append(id) + return self.opt.gpu_ids + + + def parse(self): if not self.initialized: self.initialize() @@ -84,3 +101,4 @@ def parse(self): opt_file.write('%s: %s\n' % (str(k), str(v))) opt_file.write('-------------- End ----------------\n') return self.opt + diff --git a/scripts/coseg_seg/get_data.sh b/scripts/coseg_seg/get_data.sh deleted file mode 100755 index b7673dbd..00000000 --- a/scripts/coseg_seg/get_data.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env bash - -DATADIR='datasets' #location where data gets downloaded to - -echo "downloading the data and putting it in: " $DATADIR -mkdir -p $DATADIR && cd $DATADIR -wget https://www.dropbox.com/s/34vy4o5fthhz77d/coseg.tar.gz -tar -xzvf coseg.tar.gz && rm coseg.tar.gz \ No newline at end of file diff --git a/scripts/coseg_seg/get_pretrained.sh b/scripts/coseg_seg/get_pretrained.sh deleted file mode 100755 index 562b7630..00000000 --- a/scripts/coseg_seg/get_pretrained.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash - -CHECKPOINT=checkpoints/coseg_aliens -mkdir -p $CHECKPOINT - -#gets the pretrained weights -wget https://www.dropbox.com/s/er7my13k9dwg9ii/coseg_aliens_wts.tar.gz -tar -xzvf coseg_aliens_wts.tar.gz && rm coseg_aliens_wts.tar.gz -mv latest_net.pth $CHECKPOINT -echo "downloaded pretrained weights to" $CHECKPOINT \ No newline at end of file diff --git a/scripts/coseg_seg/test.sh b/scripts/coseg_seg/test.sh deleted file mode 100755 index 
422cf071..00000000 --- a/scripts/coseg_seg/test.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash - -## run the test and export collapses -python test.py \ ---dataroot datasets/coseg_aliens \ ---name coseg_aliens \ ---arch meshunet \ ---dataset_mode segmentation \ ---ncf 32 64 128 256 \ ---ninput_edges 2280 \ ---pool_res 1800 1350 600 \ ---resblocks 3 \ ---batch_size 12 \ ---export_folder meshes \ \ No newline at end of file diff --git a/scripts/coseg_seg/train.sh b/scripts/coseg_seg/train.sh deleted file mode 100755 index 143f18c7..00000000 --- a/scripts/coseg_seg/train.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env bash - -## run the training -python train.py \ ---dataroot datasets/coseg_aliens \ ---name coseg_aliens \ ---arch meshunet \ ---dataset_mode segmentation \ ---ncf 32 64 128 256 \ ---ninput_edges 2280 \ ---pool_res 1800 1350 600 \ ---resblocks 3 \ ---lr 0.001 \ ---batch_size 12 \ ---num_aug 20 \ ---slide_verts 0.2 \ - - -# -# python train.py --dataroot datasets/coseg_vases --name coseg_vases --arch meshunet --dataset_mode -segmentation --ncf 32 64 128 256 --ninput_edges 1500 --pool_res 1050 600 300 --resblocks 3 --lr 0.001 --batch_size 12 --num_aug 20 \ No newline at end of file diff --git a/scripts/coseg_seg/view.sh b/scripts/coseg_seg/view.sh deleted file mode 100755 index 08c80675..00000000 --- a/scripts/coseg_seg/view.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash - -python util/mesh_viewer.py \ ---files \ -checkpoints/coseg_aliens/meshes/142_0.obj \ -checkpoints/coseg_aliens/meshes/142_2.obj \ -checkpoints/coseg_aliens/meshes/142_3.obj \ \ No newline at end of file diff --git a/scripts/cubes/get_data.sh b/scripts/cubes/get_data.sh deleted file mode 100755 index ca594ce9..00000000 --- a/scripts/cubes/get_data.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env bash - -DATADIR='datasets' #location where data gets downloaded to - -# get data -mkdir -p $DATADIR && cd $DATADIR -wget https://www.dropbox.com/s/2bxs5f9g60wa0wr/cubes.tar.gz -tar 
-xzvf cubes.tar.gz && rm cubes.tar.gz -echo "downloaded the data and put it in: " $DATADIR \ No newline at end of file diff --git a/scripts/cubes/get_pretrained.sh b/scripts/cubes/get_pretrained.sh deleted file mode 100755 index 2d3637e0..00000000 --- a/scripts/cubes/get_pretrained.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash - -CHECKPOINT='checkpoints/cubes' - -# get pretrained model -mkdir -p $CHECKPOINT -wget https://www.dropbox.com/s/fg7wum39bmlxr7w/cubes_wts.tar.gz -tar -xzvf cubes_wts.tar.gz && rm cubes_wts.tar.gz -mv latest_net.pth $CHECKPOINT -echo "downloaded pretrained weights to" $CHECKPOINT \ No newline at end of file diff --git a/scripts/cubes/test.sh b/scripts/cubes/test.sh deleted file mode 100755 index c8e31596..00000000 --- a/scripts/cubes/test.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env bash - -## run the test and export collapses -python test.py \ ---dataroot datasets/cubes \ ---name cubes \ ---ncf 64 128 256 256 \ ---pool_res 600 450 300 210 \ ---norm group \ ---resblocks 1 \ ---export_folder meshes \ \ No newline at end of file diff --git a/scripts/cubes/train.sh b/scripts/cubes/train.sh deleted file mode 100755 index 50f91f72..00000000 --- a/scripts/cubes/train.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env bash - -## run the training -python train.py \ ---dataroot datasets/cubes \ ---name cubes \ ---ncf 64 128 256 256 \ ---pool_res 600 450 300 210 \ ---norm group \ ---resblocks 1 \ ---flip_edges 0.2 \ ---slide_verts 0.2 \ ---num_aug 20 \ \ No newline at end of file diff --git a/scripts/cubes/view.sh b/scripts/cubes/view.sh deleted file mode 100755 index 452c3bc6..00000000 --- a/scripts/cubes/view.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash - -python util/mesh_viewer.py \ ---files checkpoints/cubes/meshes/horseshoe_4_0.obj \ -checkpoints/cubes/meshes/horseshoe_4_2.obj \ -checkpoints/cubes/meshes/horseshoe_4_3.obj \ -checkpoints/cubes/meshes/horseshoe_4_4.obj \ No newline at end of file diff --git 
a/scripts/human_seg/test.sh b/scripts/human_seg/test.sh old mode 100755 new mode 100644 index 02100f84..a896645c --- a/scripts/human_seg/test.sh +++ b/scripts/human_seg/test.sh @@ -1,14 +1,16 @@ #!/usr/bin/env bash ## run the test and export collapses -python test.py \ +python3 test.py \ --dataroot datasets/human_seg \ --name human_seg \ --arch meshunet \ --dataset_mode segmentation \ --ncf 32 64 128 256 \ ---ninput_edges 2280 \ ---pool_res 1800 1350 600 \ ---resblocks 3 \ +--ninput_edges 3000 \ +--pool_res 2000 1000 500 \ +--num_threads 0 \ +--resblocks 1 \ --batch_size 12 \ ---export_folder meshes \ \ No newline at end of file +--export_folder meshes \ +--gpu_ids -1 \ diff --git a/scripts/human_seg/train.sh b/scripts/human_seg/train.sh old mode 100755 new mode 100644 index 5e735a37..ac9459f2 --- a/scripts/human_seg/train.sh +++ b/scripts/human_seg/train.sh @@ -1,16 +1,20 @@ #!/usr/bin/env bash ## run the training -python train.py \ +python3 train.py \ --dataroot datasets/human_seg \ --name human_seg \ --arch meshunet \ --dataset_mode segmentation \ --ncf 32 64 128 256 \ ---ninput_edges 2280 \ ---pool_res 1800 1350 600 \ ---resblocks 3 \ ---batch_size 12 \ +--ninput_edges 3000 \ +--pool_res 2000 1000 500 \ +--num_threads 0 \ +--resblocks 1 \ +--batch_size 2 \ --lr 0.001 \ --num_aug 20 \ ---slide_verts 0.2 \ \ No newline at end of file +--slide_verts 0.2 \ +--gpu_ids -1 \ +--verbose_plot \ +--weighted_loss 0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125 \ diff --git a/scripts/shrec/get_data.sh b/scripts/shrec/get_data.sh deleted file mode 100755 index 4763e22a..00000000 --- a/scripts/shrec/get_data.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env bash - -DATADIR='datasets' #location where data gets downloaded to - -# get data -mkdir -p $DATADIR && cd $DATADIR -wget https://www.dropbox.com/s/w16st84r6wc57u7/shrec_16.tar.gz -tar -xzvf shrec_16.tar.gz && rm shrec_16.tar.gz -echo "downloaded the data and putting it in: " $DATADIR diff --git 
a/scripts/shrec/get_pretrained.sh b/scripts/shrec/get_pretrained.sh deleted file mode 100755 index d5a229b9..00000000 --- a/scripts/shrec/get_pretrained.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env bash - -CHECKPOINT='checkpoints/shrec16' - -mkdir -p $CHECKPOINT -wget https://www.dropbox.com/s/wqq1qxj4fjbpfas/shrec16_wts.tar.gz -tar -xzvf shrec16_wts.tar.gz && rm shrec16_wts.tar.gz -mv latest_net.pth $CHECKPOINT -echo "downloaded pretrained weights to" $CHECKPOINT \ No newline at end of file diff --git a/scripts/shrec/test.sh b/scripts/shrec/test.sh deleted file mode 100755 index 0d2c8b90..00000000 --- a/scripts/shrec/test.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env bash - -## run the test and export collapses -python test.py \ ---dataroot datasets/shrec_16 \ ---name shrec16 \ ---ncf 64 128 256 256 \ ---pool_res 600 450 300 180 \ ---norm group \ ---resblocks 1 \ ---export_folder meshes \ \ No newline at end of file diff --git a/scripts/shrec/train.sh b/scripts/shrec/train.sh deleted file mode 100755 index cecb2e1a..00000000 --- a/scripts/shrec/train.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash - -## run the training -python train.py \ ---dataroot datasets/shrec_16 \ ---name shrec16 \ ---ncf 64 128 256 256 \ ---pool_res 600 450 300 180 \ ---norm group \ ---resblocks 1 \ ---flip_edges 0.2 \ ---slide_verts 0.2 \ ---num_aug 20 \ ---niter_decay 100 \ \ No newline at end of file diff --git a/scripts/shrec/view.sh b/scripts/shrec/view.sh deleted file mode 100755 index 54bf19db..00000000 --- a/scripts/shrec/view.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash - -python util/mesh_viewer.py \ ---files \ -checkpoints/shrec16/meshes/T74_0.obj \ -checkpoints/shrec16/meshes/T74_3.obj \ -checkpoints/shrec16/meshes/T74_4.obj \ No newline at end of file diff --git a/test.py b/test.py index 15492f5b..eb5a8710 100644 --- a/test.py +++ b/test.py @@ -12,13 +12,15 @@ def run_test(epoch=-1): model = create_model(opt) writer = Writer(opt) # test - 
writer.reset_counter() + writer.reset_counter(opt) for i, data in enumerate(dataset): + #print(i) model.set_input(data) - ncorrect, nexamples = model.test() - writer.update_counter(ncorrect, nexamples) + ncorrect, nexamples, mean_iou, iou = model.test() + writer.update_counter(ncorrect, nexamples, mean_iou, iou) writer.print_acc(epoch, writer.acc) - return writer.acc + writer.print_iou(epoch, writer.mean_iou, writer.seg_iou) + return writer.acc, writer.mean_iou, writer.iou if __name__ == '__main__': diff --git a/train.py b/train.py index 41b326b7..1823cb55 100644 --- a/train.py +++ b/train.py @@ -5,6 +5,7 @@ from util.writer import Writer from test import run_test + if __name__ == '__main__': opt = TrainOptions().parse() dataset = DataLoader(opt) @@ -40,7 +41,6 @@ (epoch, total_steps)) model.save_network('latest') - iter_data_time = time.time() if epoch % opt.save_epoch_freq == 0: print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) diff --git a/util/util.py b/util/util.py index 562c22f6..2e5c9cc0 100644 --- a/util/util.py +++ b/util/util.py @@ -4,23 +4,29 @@ import os +# from torch_scatter import scatter_add + def mkdir(path): if not os.path.exists(path): os.makedirs(path) + MESH_EXTENSIONS = [ '.obj', ] + def is_mesh_file(filename): return any(filename.endswith(extension) for extension in MESH_EXTENSIONS) + def pad(input_arr, target_length, val=0, dim=1): shp = input_arr.shape npad = [(0, 0) for _ in range(len(shp))] npad[dim] = (0, target_length - shp[dim]) return np.pad(input_arr, pad_width=npad, mode='constant', constant_values=val) + def seg_accuracy(predicted, ssegs, meshes): correct = 0 ssegs = ssegs.squeeze(-1) @@ -31,6 +37,36 @@ def seg_accuracy(predicted, ssegs, meshes): correct += (correct_vec.float() * edge_areas).sum() return correct + +def intersection_over_union(preds, target, num_classes): + preds, target = torch.nn.functional.one_hot(preds, num_classes), torch.nn.functional.one_hot(target, num_classes) + iou = 
torch.zeros(num_classes, dtype=torch.float32) + for idx, pred in enumerate(preds): + i = (pred & target[idx]).sum(dim=0) + u = (pred | target[idx]).sum(dim=0) + iou = iou.add(i.cpu().to(torch.float) / u.cpu().to(torch.float)) + return iou + + +def mean_iou_calc(pred, target, num_classes): + #Removal of padded labels marked with -1 + slimpred = [] + slimtarget = [] + + for batch in range(pred.shape[0]): + slimLabels = target[batch][target[batch]!=-1] + slimtarget.append(slimLabels) + slimpred.append(pred[batch][:slimLabels.size()[0]]) + + pred = torch.stack(slimpred,0) + target = torch.stack(slimtarget, 0) + + iou = intersection_over_union(pred, target, num_classes) + mean_iou = iou.mean(dim=-1) + return mean_iou, iou + + def print_network(net): """Print the total number of parameters in the network Parameters: @@ -43,11 +79,12 @@ print('[Network] Total number of parameters : %.3f M' % (num_params / 1e6)) print('-----------------------------------------------') + def get_heatmap_color(value, minimum=0, maximum=1): minimum, maximum = float(minimum), float(maximum) - ratio = 2 * (value-minimum) / (maximum - minimum) - b = int(max(0, 255*(1 - ratio))) - r = int(max(0, 255*(ratio - 1))) + ratio = 2 * (value - minimum) / (maximum - minimum) + b = int(max(0, 255 * (1 - ratio))) + r = int(max(0, 255 * (ratio - 1))) g = 255 - b - r return r, g, b @@ -66,3 +103,38 @@ entropy -= a * np.log(a) entropy /= np.log(np_array.shape[0]) return entropy + + +def pad_with(vector, pad_width, iaxis, kwargs): + pad_value = kwargs.get('padder', 10) + vector[:pad_width[0]] = pad_value + vector[-pad_width[1]:] = pad_value + + + + + +def myindexrowselect(groups, mask_index, device): + + sparseIndices = groups._indices() + newIndices = [] + + for i, value in enumerate(mask_index): + #Get index from relevant indices + index = (sparseIndices[0] == value).nonzero() + + #Get rows by index + sparseRow = 
[sparseIndices[:, value] for value in index] + sparseRow = torch.cat(sparseRow,1)[1] + singleRowIndices = torch.squeeze(torch.full((1,len(sparseRow)),i, dtype=torch.long),0).to(sparseRow.device) + indices = torch.stack((singleRowIndices,sparseRow)) + newIndices.append(indices) + + allNewIndices = torch.cat(newIndices,1) + + #Create new tensor + groups = torch.sparse_coo_tensor(indices=allNewIndices, + values=torch.ones(allNewIndices.shape[1], dtype=torch.float), + size=(len(mask_index), groups.shape[1])) + + return groups diff --git a/util/writer.py b/util/writer.py index 6c0b4759..74df714f 100644 --- a/util/writer.py +++ b/util/writer.py @@ -1,5 +1,7 @@ import os import time +import torch +import numpy as np try: from tensorboardX import SummaryWriter @@ -7,6 +9,7 @@ print('tensorboard X not installed, visualizing wont be available') SummaryWriter = None + class Writer: def __init__(self, opt): self.name = opt.name @@ -17,7 +20,10 @@ def __init__(self, opt): self.start_logs() self.nexamples = 0 self.ncorrect = 0 - # + self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + self.iou = torch.zeros([opt.nclasses]).to(self.device) + self.avg_iou = 0 + if opt.is_train and not opt.no_vis and SummaryWriter is not None: self.display = SummaryWriter(comment=opt.name) else: @@ -47,6 +53,10 @@ def plot_loss(self, loss, epoch, i, n): if self.display: self.display.add_scalar('data/train_loss', loss, iters) + def plot_lr(self, lr, epoch): + if self.display: + self.display.add_scalar('data/lr', lr, epoch) + def plot_model_wts(self, model, epoch): if self.opt.is_train and self.display: for name, param in model.net.named_parameters(): @@ -54,31 +64,51 @@ def plot_model_wts(self, model, epoch): def print_acc(self, epoch, acc): """ prints test accuracy to terminal / file """ - message = 'epoch: {}, TEST ACC: [{:.5} %]\n' \ + message = 'epoch: {}, TEST ACC: [{:.5} %]' \ .format(epoch, acc * 100) print(message) with open(self.testacc_log, "a") as log_file: 
log_file.write('%s\n' % message) + def print_iou(self, epoch, avg_iou, iou): + """ prints test IoU to terminal / file """ + message = 'epoch: {}, TEST MEAN_IOU: [{:.5} %] \nepoch: {}, IOU: {}\n' \ + .format(epoch, avg_iou * 100, epoch, iou.cpu().numpy() * 100) + print(message) + with open(self.testacc_log, "a") as log_file: + log_file.write('%s' % message) + def plot_acc(self, acc, epoch): if self.display: - self.display.add_scalar('data/test_acc', acc, epoch) + self.display.add_scalar('data/test_acc', acc[0], epoch) - def reset_counter(self): + def reset_counter(self, opt): """ counts # of correct examples """ self.ncorrect = 0 self.nexamples = 0 + self.iou = torch.zeros([opt.nclasses]) + self.avg_iou = 0 - def update_counter(self, ncorrect, nexamples): + def update_counter(self, ncorrect, nexamples, avg_iou, iou): self.ncorrect += ncorrect self.nexamples += nexamples + self.iou = torch.stack([self.iou.to(self.device), iou.to(self.device)]).sum(dim=0) + self.avg_iou += avg_iou @property def acc(self): return float(self.ncorrect) / self.nexamples + @property + def mean_iou(self): + return (self.avg_iou / self.nexamples) + + @property + def seg_iou(self): + return (self.iou / self.nexamples) + def close(self): if self.display is not None: self.display.close()
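For reference, the class-wise IoU added in `util/util.py` follows the standard intersection-over-union definition on one-hot encoded labels. A minimal sketch of the same computation without PyTorch, on a flat list of edge labels (the function name and example values are illustrative):

```python
def per_class_iou(pred, target, num_classes):
    """Class-wise intersection over union for flat label lists."""
    ious = []
    for c in range(num_classes):
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        # A class absent from both pred and target contributes 0,
        # avoiding division by zero.
        ious.append(inter / union if union else 0.0)
    return ious

pred   = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, 0]
ious = per_class_iou(pred, target, 3)      # [1/3, 2/3, 1/2]
mean_iou = sum(ious) / len(ious)           # 0.5
```

The mean over classes corresponds to the `mean_iou` reported per epoch, while the per-class values correspond to the `IOU` vector written to the test log.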