diff --git a/configs/simclr_cifar.yaml b/configs/simclr_cifar.yaml index 9f7a491..3d33c3c 100644 --- a/configs/simclr_cifar.yaml +++ b/configs/simclr_cifar.yaml @@ -16,9 +16,9 @@ train: warmup_lr: 0 base_lr: 0.3 final_lr: 0 - num_epochs: 800 # this parameter influence the lr decay + num_epochs: 200 # this parameter influence the lr decay stop_at_epoch: 100 # has to be smaller than num_epochs - batch_size: 256 + batch_size: 512 knn_monitor: False # knn monitor will take more time knn_interval: 1 knn_k: 200 diff --git a/configs/simsiam_cifar.yaml b/configs/simsiam_cifar.yaml index 61f6454..d0792e7 100644 --- a/configs/simsiam_cifar.yaml +++ b/configs/simsiam_cifar.yaml @@ -43,8 +43,4 @@ logger: seed: null # None type for yaml file # two things might lead to stochastic behavior other than seed: # worker_init_fn from dataloader and torch.nn.functional.interpolate -# (keep this in mind if you want to achieve 100% deterministic) - - - - +# (keep this in mind if you want to achieve 100% deterministic) \ No newline at end of file diff --git a/configs/simsiam_cifar_eval_sgd.yaml b/configs/simsiam_cifar_eval_sgd.yaml index 648a873..a443a3b 100644 --- a/configs/simsiam_cifar_eval_sgd.yaml +++ b/configs/simsiam_cifar_eval_sgd.yaml @@ -31,7 +31,3 @@ seed: null # None type for yaml file # two things might lead to stochastic behavior other than seed: # worker_init_fn from dataloader and torch.nn.functional.interpolate # (keep this in mind if you want to achieve 100% deterministic) - - - - diff --git a/configs/simsiam_image100_eval.yaml b/configs/simsiam_image100_eval.yaml new file mode 100644 index 0000000..a1a16b7 --- /dev/null +++ b/configs/simsiam_image100_eval.yaml @@ -0,0 +1,37 @@ +name: simsiam-imagenet100-experiment-resnet50 +dataset: + name: imagenet100 + image_size: 224 + num_workers: 4 + +model: + name: simsiam + backbone: resnet50 + proj_layers: 2 + +train: null + +eval: # linear evaluation, False will turn off automatic evaluation after training + optimizer: + name: sgd + weight_decay: 0 + momentum: 0.9 + warmup_lr: 0 + warmup_epochs: 0 + base_lr: 10 #30 + final_lr: 0 + batch_size: 128 + num_epochs: 60 + +logger: + tensorboard: False + matplotlib: False + +seed: null # None type for yaml file +# two things might lead to stochastic behavior other than seed: +# worker_init_fn from dataloader and torch.nn.functional.interpolate +# (keep this in mind if you want to achieve 100% deterministic) + + + + diff --git a/configs/simsiam_imagenet.yaml b/configs/simsiam_imagenet.yaml new file mode 100644 index 0000000..d41052d --- /dev/null +++ b/configs/simsiam_imagenet.yaml @@ -0,0 +1,46 @@ +name: simsiam-imagenet-experiment-resnet50 +dataset: + name: imagenet + image_size: 224 + num_workers: 16 + +model: + name: simsiam + backbone: resnet50 + proj_layers: 2 + +train: + optimizer: + name: sgd + weight_decay: 0.0002 + momentum: 0.9 + warmup_epochs: 0 + warmup_lr: 0 + base_lr: 0.05 + final_lr: 0 + num_epochs: 200 # this parameter influence the lr decay + stop_at_epoch: 200 # has to be smaller than num_epochs + batch_size: 512 + knn_monitor: False # knn monitor will take more time + knn_interval: 50 + knn_k: 200 +eval: # linear evaluation, False will turn off automatic evaluation after training + optimizer: + name: sgd + weight_decay: 0 + momentum: 0.9 + warmup_lr: 0 + warmup_epochs: 0 + base_lr: 30 + final_lr: 0 + batch_size: 256 + num_epochs: 100 + +logger: + tensorboard: True + matplotlib: True + +seed: null # None type for yaml file +# two things might lead to stochastic behavior other than seed: +# worker_init_fn from dataloader and torch.nn.functional.interpolate +# (keep this in mind if you want to achieve 100% deterministic) diff --git a/configs/simsiam_imagenet100.yaml b/configs/simsiam_imagenet100.yaml new file mode 100644 index 0000000..778f542 --- /dev/null +++ b/configs/simsiam_imagenet100.yaml @@ -0,0 +1,50 @@ +name: simsiam-imagenet100-experiment-resnet50 +dataset: + name: imagenet100 + image_size: 224 + num_workers: 8 + +model: + name: simsiam + backbone: resnet50 + proj_layers: 2 + +train: + optimizer: + name: sgd + weight_decay: 0.0001 + momentum: 0.9 + warmup_epochs: 10 + warmup_lr: 0 + base_lr: 0.05 + final_lr: 0 + num_epochs: 200 # this parameter influence the lr decay + stop_at_epoch: 200 # has to be smaller than num_epochs + batch_size: 512 + knn_monitor: True # knn monitor will take more time + knn_interval: 40 + knn_k: 200 +eval: # linear evaluation, False will turn off automatic evaluation after training + optimizer: + name: sgd + weight_decay: 0 + momentum: 0.9 + warmup_lr: 0 + warmup_epochs: 0 + base_lr: 30 + final_lr: 0 + batch_size: 256 + num_epochs: 100 + +logger: + tensorboard: True + matplotlib: True + +seed: null # None type for yaml file +# two things might lead to stochastic behavior other than seed: +# worker_init_fn from dataloader and torch.nn.functional.interpolate +# (keep this in mind if you want to achieve 100% deterministic) + + + + diff --git a/configs/simsiam_imagenet_eval.yaml b/configs/simsiam_imagenet_eval.yaml new file mode 100644 index 0000000..9d84e5a --- /dev/null +++ b/configs/simsiam_imagenet_eval.yaml @@ -0,0 +1,37 @@ +name: simsiam-imagenet100-experiment-resnet50 +dataset: + name: imagenet + image_size: 224 + num_workers: 8 + +model: + name: simsiam + backbone: resnet50 + proj_layers: 2 + +train: null + +eval: # linear evaluation, False will turn off automatic evaluation after training + optimizer: + name: sgd + weight_decay: 0 + momentum: 0.9 + warmup_lr: 0 + warmup_epochs: 0 + base_lr: 30 + final_lr: 0 + batch_size: 256 + num_epochs: 100 + +logger: + tensorboard: False + matplotlib: False + +seed: null # None type for yaml file +# two things might lead to stochastic behavior other than seed: +# worker_init_fn from dataloader and torch.nn.functional.interpolate +# (keep this in mind if you want to achieve 100% deterministic) + + + + diff --git a/datasets/__init__.py b/datasets/__init__.py index f6ab42b..e4e7754 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -3,17 +3,23 @@ from .random_dataset import RandomDataset -def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None): +def get_dataset(dataset, data_dir, transform, train=True, download=True, debug_subset_size=None): if dataset == 'mnist': dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download) elif dataset == 'stl10': dataset = torchvision.datasets.STL10(data_dir, split='train+unlabeled' if train else 'test', transform=transform, download=download) elif dataset == 'cifar10': - dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download) + dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=True) elif dataset == 'cifar100': dataset = torchvision.datasets.CIFAR100(data_dir, train=train, transform=transform, download=download) - elif dataset == 'imagenet': - dataset = torchvision.datasets.ImageNet(data_dir, split='train' if train == True else 'val', transform=transform, download=download) + elif dataset == 'imagenet' and train == True: + dataset = torchvision.datasets.ImageFolder(data_dir+'train', transform=transform) + elif dataset == 'imagenet' and train == False: + dataset = torchvision.datasets.ImageFolder(data_dir+'val', transform=transform) + elif dataset == 'imagenet100' and train == True: + dataset = torchvision.datasets.ImageFolder(data_dir+'train', transform=transform) + elif dataset == 'imagenet100' and train == False: + dataset = torchvision.datasets.ImageFolder(data_dir+'val', transform=transform) elif dataset == 'random': dataset = RandomDataset() else: diff --git a/linear_eval.py b/linear_eval.py index 6b63095..4c609ad 100644 --- a/linear_eval.py +++ b/linear_eval.py @@ -11,50 +11,55 @@ from datasets import get_dataset from optimizers import get_optimizer, LR_Scheduler -def main(args): - - train_loader = torch.utils.data.DataLoader( - dataset=get_dataset( +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +def main(gpu, args): + rank = args.nr * args.gpus + gpu + dist.init_process_group("nccl", rank=rank, world_size=args.world_size) + torch.manual_seed(0) + torch.cuda.set_device(gpu) + train_dataset = get_dataset( transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs), train=True, **args.dataset_kwargs - ), - batch_size=args.eval.batch_size, - shuffle=True, + ) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank) + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=(args.eval.batch_size//args.gpus), + shuffle=False, + sampler = train_sampler, **args.dataloader_kwargs ) - test_loader = torch.utils.data.DataLoader( - dataset=get_dataset( + test_dataset = get_dataset( transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), train=False, **args.dataset_kwargs - ), - batch_size=args.eval.batch_size, + ) + test_loader = torch.utils.data.DataLoader( + dataset=test_dataset, + batch_size=(args.eval.batch_size//args.gpus), shuffle=False, **args.dataloader_kwargs ) - - - model = get_backbone(args.model.backbone) - classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True).to(args.device) - + model = get_backbone(args.model.backbone) + classifier = nn.Linear(in_features=model.output_dim, out_features=100, bias=True).to(args.device) assert args.eval_from is not None save_dict = torch.load(args.eval_from, map_location='cpu') msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) - - # print(msg) - model = model.to(args.device) - model = torch.nn.DataParallel(model) - # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) - classifier = torch.nn.DataParallel(classifier) + model = model.to(args.device) + model = DDP(model, device_ids=[gpu], find_unused_parameters=True) + classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) + classifier = DDP(classifier, device_ids=[gpu], find_unused_parameters=True) # define optimizer optimizer = get_optimizer( args.eval.optimizer.name, classifier, lr=args.eval.base_lr*args.eval.batch_size/256, momentum=args.eval.optimizer.momentum, weight_decay=args.eval.optimizer.weight_decay) - # define lr scheduler lr_scheduler = LR_Scheduler( optimizer, @@ -62,10 +67,8 @@ def main(args): args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256, len(train_loader), ) - loss_meter = AverageMeter(name='Loss') acc_meter = AverageMeter(name='Accuracy') - # Start training global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating') for epoch in global_progress: @@ -73,52 +76,36 @@ def main(args): model.eval() classifier.train() local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True) - for idx, (images, labels) in enumerate(local_progress): - classifier.zero_grad() with torch.no_grad(): feature = model(images.to(args.device)) - preds = classifier(feature) - loss = F.cross_entropy(preds, labels.to(args.device)) - loss.backward() optimizer.step() loss_meter.update(loss.item()) lr = lr_scheduler.step() local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg}) - - classifier.eval() - correct, total = 0, 0 - acc_meter.reset() - for idx, (images, labels) in enumerate(test_loader): - with torch.no_grad(): - feature = model(images.to(args.device)) - preds = classifier(feature).argmax(dim=1) - correct = (preds == labels.to(args.device)).sum().item() - acc_meter.update(correct/preds.shape[0]) - print(f'Accuracy = {acc_meter.avg*100:.2f}') - - - - + if gpu==0 and (epoch+1) == (args.eval.num_epochs-1): + print('epoch:',epoch+1) + classifier.eval() + correct, total = 0, 0 + acc_meter.reset() + if gpu == 0: + for idx, (images, labels) in enumerate(test_loader): + with torch.no_grad(): + feature = model(images.to(args.device)) + preds = classifier(feature).argmax(dim=1) + correct = (preds == labels.to(args.device)).sum().item() + acc_meter.update(correct/preds.shape[0]) + print(f'Accuracy = {acc_meter.avg*100:.2f}') + break + + dist.destroy_process_group() if __name__ == "__main__": - main(args=get_args()) - - - - - - - - - - - - - - - - + args = get_args() + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "3367" + args.world_size = args.gpus * args.nodes + mp.spawn(main, args=(args,), nprocs=args.gpus, join=True) diff --git a/main.py b/main.py index 4dd5380..58f387f 100644 --- a/main.py +++ b/main.py @@ -14,39 +14,56 @@ from linear_eval import main as linear_eval from datetime import datetime -def main(device, args): +# distributed training +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +def cleanup(): + dist.destroy_process_group() + +def main(gpu, args): + rank = args.nr * args.gpus + gpu + dist.init_process_group("nccl", rank=rank, world_size=args.world_size) + + torch.manual_seed(0) + torch.cuda.set_device(gpu) + + train_dataset = get_dataset(transform=get_aug(train=True, **args.aug_kwargs), train=True, **args.dataset_kwargs) + + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank) train_loader = torch.utils.data.DataLoader( - dataset=get_dataset( - transform=get_aug(train=True, **args.aug_kwargs), - train=True, - **args.dataset_kwargs), - shuffle=True, - batch_size=args.train.batch_size, + dataset=train_dataset, + shuffle=False, + batch_size=(args.train.batch_size // args.gpus), + sampler = train_sampler, **args.dataloader_kwargs ) + + memory_dataset = get_dataset(transform=get_aug(train=False,train_classifier=False, **args.aug_kwargs), train=True, **args.dataset_kwargs) + + memory_loader = torch.utils.data.DataLoader( - dataset=get_dataset( - transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), - train=True, - **args.dataset_kwargs), + dataset=memory_dataset, shuffle=False, - batch_size=args.train.batch_size, + batch_size=(args.train.batch_size // args.gpus), **args.dataloader_kwargs ) + + test_datset = get_dataset( transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), train=False,**args.dataset_kwargs) + + test_loader = torch.utils.data.DataLoader( - dataset=get_dataset( - transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), - train=False, - **args.dataset_kwargs), + dataset= test_datset, shuffle=False, - batch_size=args.train.batch_size, + batch_size=(args.train.batch_size // args.gpus), **args.dataloader_kwargs ) - # define model - model = get_model(args.model).to(device) - model = torch.nn.DataParallel(model) + model = get_model(args.model).cuda(gpu) + + model = DDP(model, device_ids=[gpu], find_unused_parameters=True) # define optimizer optimizer = get_optimizer( @@ -62,58 +79,63 @@ def main(device, args): len(train_loader), constant_predictor_lr=True # see the end of section 4.2 predictor ) - - logger = Logger(tensorboard=args.logger.tensorboard, matplotlib=args.logger.matplotlib, log_dir=args.log_dir) + if gpu ==0: + logger = Logger(tensorboard=args.logger.tensorboard, matplotlib=args.logger.matplotlib, log_dir=args.log_dir) accuracy = 0 # Start training global_progress = tqdm(range(0, args.train.stop_at_epoch), desc=f'Training') for epoch in global_progress: model.train() - + local_progress=tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress) for idx, ((images1, images2), labels) in enumerate(local_progress): model.zero_grad() - data_dict = model.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True)) - loss = data_dict['loss'].mean() # ddp + data_dict = model.forward(images1.cuda(non_blocking=True), images2.cuda(non_blocking=True)) + loss = data_dict['loss'] # ddp loss.backward() optimizer.step() lr_scheduler.step() data_dict.update({'lr':lr_scheduler.get_lr()}) - local_progress.set_postfix(data_dict) - logger.update_scalers(data_dict) + if gpu ==0: + logger.update_scalers(data_dict) - if args.train.knn_monitor and epoch % args.train.knn_interval == 0: - accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device, k=min(args.train.knn_k, len(memory_loader.dataset)), hide_progress=args.hide_progress) + if args.train.knn_monitor and epoch % args.train.knn_interval == 0 and gpu==0: + accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, gpu, k=min(args.train.knn_k, len(memory_loader.dataset)), hide_progress=args.hide_progress) epoch_dict = {"epoch":epoch, "accuracy":accuracy} global_progress.set_postfix(epoch_dict) - logger.update_scalers(epoch_dict) - - # Save checkpoint - model_path = os.path.join(args.ckpt_dir, f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth") # datetime.now().strftime('%Y%m%d_%H%M%S') - torch.save({ - 'epoch': epoch+1, - 'state_dict':model.module.state_dict() - }, model_path) - print(f"Model saved to {model_path}") - with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f: - f.write(f'{model_path}') - if args.eval is not False: - args.eval_from = model_path - linear_eval(args) + if gpu == 0: + logger.update_scalers(epoch_dict) + + + if gpu == 0 : + model_path = os.path.join(args.ckpt_dir, f"{args.name}_final.pth") # datetime.now().strftime('%Y%m%d_%H%M%S') + torch.save({ + 'epoch': args.train.stop_at_epoch, + 'state_dict':model.module.state_dict() + }, model_path) + print(f"Final Model saved to {model_path}") + with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f: + f.write(f'{model_path}') + if __name__ == "__main__": args = get_args() - main(device=args.device, args=args) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "8100" + args.world_size = args.gpus * args.nodes - completed_log_dir = args.log_dir.replace('in-progress', 'debug' if args.debug else 'completed') + # Initialize the process and join up with the other processes. + # This is “blocking,” meaning that no process will continue until all processes have joined. + mp.spawn(main, args=(args,), nprocs=args.gpus, join=True) + completed_log_dir = args.log_dir.replace('in-progress', 'debug' if args.debug else 'completed') os.rename(args.log_dir, completed_log_dir) print(f'Log file has been saved to {completed_log_dir}')