Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Install python package apex for distributed training
```
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir ./
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```

## Supported model weights
Expand All @@ -48,10 +48,29 @@ Move the downloaded weights to weights/

## Preparation of dataset - Cityscapes

Please download the dataset from the officit site - [Download](https://www.cityscapes-dataset.com/)
Please download the dataset from the official site.

This dataset requires you to download the source data manually:
You have to download files from - [Download](https://www.cityscapes-dataset.com/) (this dataset requires registration). The config file is written for leftImg8bit_sequence_trainvaltest.zip and the fine annotation files from gtFine_trainvaltest.zip. Using other datasets requires additional configs, which are not included.

Alternatively, to download via the command line, as shared by [cemsaz](https://github.com/cemsaz/city-scapes-script):

Use the command below, substituting your own username and password:

```
wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=myusername&password=mypassword&submit=Login' https://www.cityscapes-dataset.com/login/
```

and provide the packageID of the required zip file.

```
wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1
```

Hint: You can find the packageID in the download link of the file you need. In our case, it is packageID=14 for leftImg8bit_sequence_trainvaltest.zip and packageID=1 for gtFine_trainvaltest.zip.

```
data_path = './data/leftImg8bit_sequence_trainvaltest_2K/'
data_path = './data/cityscapes'
```

Modify the data_path in config/cityscapes.py
Expand All @@ -60,7 +79,7 @@ Modify the data_path in config/cityscapes.py

```
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port 1111 \
--nproc_per_node 2 python main.py --segnet <segnet_name> --dataset <dataset_name> \
--nproc_per_node 2 main.py --segnet <segnet_name> --dataset <dataset_name> \
--optical-flow-network <of_name> --checkname <SAVE_DIR>
```

Expand Down
6 changes: 3 additions & 3 deletions config/camvid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class camvid_config(object):
bsnet_weight_path = './weights/cityscapes-bisenet-R18.pth'


#resume_path = './weights/gsvnet_bisenet_r18.tar'
resume_path = './weights/gsvnet_swnet_r18.tar'
bisenet_resume_path = './weights/gsvnet_bisenet_r18.tar'
swnet_resume_path = './weights/gsvnet_swnet_r18.tar'

optical_flow_network_path = './weights/flownet.pth.tar'
data_path = '' #put your data path here
data_path = './data/camvid' #put your data path here
6 changes: 3 additions & 3 deletions config/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class cityscapes_config(object):
] # 19 classes + 1 void class
swnet_weight_path = './weights/cityscapes-swnet-R18.pt'
bsnet_weight_path = './weights/cityscapes-bisenet-R18.pth'
#resume_path = './weights/gsvnet_bisenet_r18.tar'
resume_path = './weights/gsvnet_swnet_r18.tar'
bisenet_resume_path = './weights/gsvnet_bisenet_r18.tar'
swnet_resume_path = './weights/gsvnet_swnet_r18.tar'
optical_flow_network_path = './weights/flownet.pth.tar'
data_path = '' # put your data path here
data_path = './data/cityscapes' # put your data path here
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ numpy
Pillow
scikit-learn
tensorboardX
torch==1.1.0
torchvision==0.2.1
torch>=1.1.0
torchvision>=0.2.1
tqdm
opencv-python
7 changes: 4 additions & 3 deletions trainer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __init__(self, args, name):

if args.dataset.find("cityscapes") >= 0:
self.dataset_config = cityscapes_config()
elif args.dataset_name == "camvid":
elif args.dataset == "camvid":
self.dataset_config = camvid_config()
else:
raise NotImplementedError("Trainer dataset %s is not registered into the system" % args.dataset)
Expand All @@ -173,7 +173,8 @@ def __init__(self, args, name):
self.num_classes = self.dataset_config.num_classes
self.swnet_weight_path = self.dataset_config.swnet_weight_path
self.bsnet_weight_path = self.dataset_config.bsnet_weight_path
self.resume_path = self.dataset_config.resume_path
self.swnet_resume_path = self.dataset_config.swnet_resume_path
self.bisenet_resume_path = self.dataset_config.bisenet_resume_path
self.optical_flow_network_path = self.dataset_config.optical_flow_network_path
self.data_path = self.dataset_config.data_path

Expand Down Expand Up @@ -526,4 +527,4 @@ def get_images_from_lines(self, path, lines, return_dict):
for frame in range(gt_frame_num-4,gt_frame_num+1):
frame_name = line[:-22] + ( "%06d" % (frame) ) + line[-16:]
frame_path = os.path.join(path, frame_name)
return_dict[frame_path] = image_loader(frame_path)
return_dict[frame_path] = image_loader(frame_path)
4 changes: 3 additions & 1 deletion trainer/gsvnet_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ def _img_input_transform(self, args):

def _val_img_input_transform(self, args):
if self.args.segnet == 'bisenet':
self.resume_path = self.bisenet_resume_path
mean_std = ([0.406, 0.456, 0.485], [0.225, 0.224, 0.229])
img_transform = standard_transforms.Compose([
FlipChannels(),
standard_transforms.ToTensor(),
standard_transforms.Normalize(*mean_std)
])
elif self.args.segnet == 'swiftnet':
self.resume_path = self.swnet_resume_path
mean_std = ([72.3, 82.90, 73.15],[47.73, 48.49, 47.67])
img_transform = standard_transforms.Compose([
FlipChannels(),
Expand Down Expand Up @@ -478,4 +480,4 @@ def load_state(self, args):
self.test_best_pred = checkpoint['test_best_pred']

print("=> loaded checkpoint '{}' (epoch {})"
.format(self.resume_path, checkpoint['epoch']))
.format(self.resume_path, checkpoint['epoch']))