Dense Label Generation Based on Weakly Supervised Learning


Table of Contents

Introduction
CODE
References

Introduction

To address the problems described above, weakly supervised learning can be used to generate densely annotated samples. Labels produced this way follow the actual edges of the bare-land regions much more closely, and they greatly reduce the time spent on manual annotation. A sample is shown in the figure below, where (c) was generated automatically by weakly supervised learning.

CODE
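Before the script itself, a note on its inputs. The script reads scribble masks from a directory parallel to the images (it replaces 'images' with 'scribbles' in each image path), treating the value 255 as "unlabeled" and every other value as a class id; the similarity loss is applied on the unlabeled pixels and the scribble loss on the annotated ones. Below is a minimal sketch of that mask convention; the stroke coordinates and file names are illustrative assumptions, not part of the original post.

# Minimal sketch of the scribble-mask convention the training script expects.
# Assumption (inferred from how the script reads its masks): for each image in
# .../train_images there is a same-named single-channel mask in
# .../train_scribbles, where 255 = unlabeled pixel and small integers
# (0, 1, ...) = class ids drawn by the annotator.
import numpy as np
import cv2

h, w = 512, 512
mask = np.full((h, w), 255, dtype=np.uint8)  # start fully unlabeled

# Hypothetical strokes, e.g. exported from an annotation tool:
cv2.line(mask, (50, 60), (200, 80), color=0, thickness=5)     # class 0 (e.g. background)
cv2.line(mask, (300, 300), (450, 320), color=1, thickness=5)  # class 1 (e.g. bare land)

cv2.imwrite('bareland2/train_scribbles/example.tif', mask)

The full training script from the post follows.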

import argparse
import glob
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gdalTools  # local helper module; only gdalTools.mkdir is used here

use_cuda = torch.cuda.is_available()

parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
parser.add_argument('--scribble', action='store_true', default=True, help='use scribbles')
parser.add_argument('--nChannel', metavar='N', default=100, type=int, help='number of channels')
parser.add_argument('--maxIter', metavar='T', default=150, type=int, help='number of maximum iterations')
parser.add_argument('--minLabels', metavar='minL', default=3, type=int, help='minimum number of labels')
parser.add_argument('--lr', metavar='LR', default=0.1, type=float, help='learning rate')
parser.add_argument('--nConv', metavar='M', default=2, type=int, help='number of convolutional layers')
parser.add_argument('--visualize', metavar='1 or 0', default=1, type=int, help='visualization flag')
parser.add_argument('--input', metavar='FILENAME', default='bareland2/train_images', help='input image file root path')
parser.add_argument('--stepsize_sim', metavar='SIM', default=1, type=float, help='step size for similarity loss')
parser.add_argument('--stepsize_con', metavar='CON', default=1, type=float, help='step size for continuity loss')
parser.add_argument('--stepsize_scr', metavar='SCR', default=0.5, type=float, help='step size for scribble loss')
args = parser.parse_args()


# CNN model: nConv 3x3 conv blocks followed by a 1x1 layer that scores nChannel classes per pixel
class MyNet(nn.Module):
    def __init__(self, input_dim):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, args.nChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(args.nChannel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for i in range(args.nConv - 1):
            self.conv2.append(nn.Conv2d(args.nChannel, args.nChannel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(args.nChannel))
        self.conv3 = nn.Conv2d(args.nChannel, args.nChannel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(args.nChannel)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn1(x)
        for i in range(args.nConv - 1):
            x = self.conv2[i](x)
            x = F.relu(x)
            x = self.bn2[i](x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x


if __name__ == '__main__':
    imglist = glob.glob(f'{args.input}/*.tif')
    outRoot = args.input.replace("images", "labels")
    outRGBRoot = args.input.replace("images", "labels_rgb")
    gdalTools.mkdir(outRoot)
    gdalTools.mkdir(outRGBRoot)

    for imgPath in imglist:
        baseName = os.path.basename(imgPath)

        # load image as a 1xCxHxW float tensor scaled to [0, 1]
        im = cv2.imread(imgPath)
        data = torch.from_numpy(np.array([im.transpose((2, 0, 1)).astype('float32') / 255.]))
        if use_cuda:
            data = data.cuda()

        # load scribble: 255 marks unlabeled pixels, other values are class ids
        if args.scribble:
            scribblePath = imgPath.replace('images', 'scribbles')
            assert os.path.exists(scribblePath), f'please check your scribblePath: {scribblePath}'
            mask = cv2.imread(scribblePath, -1)
            mask = mask.reshape(-1)
            mask_inds = np.unique(mask)
            mask_inds = np.delete(mask_inds, np.argwhere(mask_inds == 255))
            inds_sim = torch.from_numpy(np.where(mask == 255)[0])   # unlabeled pixels
            inds_scr = torch.from_numpy(np.where(mask != 255)[0])   # scribbled pixels
            target_scr = torch.from_numpy(mask.astype(np.int64))    # np.int is removed in recent NumPy
            if use_cuda:
                inds_sim = inds_sim.cuda()
                inds_scr = inds_scr.cuda()
                target_scr = target_scr.cuda()
            # set minLabels to the number of scribbled classes
            args.minLabels = len(mask_inds)

        # train
        model = MyNet(data.size(1))
        if use_cuda:
            model.cuda()
        model.train()

        # similarity loss definition
        loss_fn = torch.nn.CrossEntropyLoss()
        # scribble loss definition
        loss_fn_scr = torch.nn.CrossEntropyLoss()
        # continuity loss definition (size_average is deprecated; reduction='mean' is equivalent)
        loss_hpy = torch.nn.L1Loss(reduction='mean')
        loss_hpz = torch.nn.L1Loss(reduction='mean')

        HPy_target = torch.zeros(im.shape[0] - 1, im.shape[1], args.nChannel)
        HPz_target = torch.zeros(im.shape[0], im.shape[1] - 1, args.nChannel)
        if use_cuda:
            HPy_target = HPy_target.cuda()
            HPz_target = HPz_target.cuda()

        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
        # size follows nChannel so the colour lookup below cannot go out of range
        label_colours = np.random.randint(255, size=(args.nChannel, 3))

        for batch_idx in range(args.maxIter):
            # forwarding
            optimizer.zero_grad()
            output = model(data)[0]
            output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel)

            # continuity loss: penalize differences between vertically / horizontally adjacent pixels
            outputHP = output.reshape((im.shape[0], im.shape[1], args.nChannel))
            HPy = outputHP[1:, :, :] - outputHP[0:-1, :, :]
            HPz = outputHP[:, 1:, :] - outputHP[:, 0:-1, :]
            lhpy = loss_hpy(HPy, HPy_target)
            lhpz = loss_hpz(HPz, HPz_target)

            ignore, target = torch.max(output, 1)
            im_target = target.data.cpu().numpy()
            nLabels = len(np.unique(im_target))

            if args.visualize:
                im_target_rgb = np.array([label_colours[c % args.nChannel] for c in im_target])
                im_target_rgb = im_target_rgb.reshape(im.shape).astype(np.uint8)
                cv2.imshow("output", im_target_rgb)
                cv2.waitKey(10)

            # loss: self-supervised similarity on unlabeled pixels, cross-entropy against
            # the scribbled class ids on annotated pixels, plus the continuity term
            if args.scribble:
                loss = args.stepsize_sim * loss_fn(output[inds_sim], target[inds_sim].long()) \
                     + args.stepsize_scr * loss_fn_scr(output[inds_scr], target_scr[inds_scr].long()) \
                     + args.stepsize_con * (lhpy + lhpz)
            else:
                loss = args.stepsize_sim * loss_fn(output, target) + args.stepsize_con * (lhpy + lhpz)

            loss.backward()
            optimizer.step()

            print(batch_idx, '/', args.maxIter, '|', ' label num :', nLabels, ' | loss :', loss.item())

            if nLabels <= args.minLabels:
                print("nLabels", nLabels, "reached minLabels", args.minLabels, ".")
                break

        # save output image (recompute it if visualization was off during training)
        if not args.visualize:
            output = model(data)[0]
            output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel)
            ignore, target = torch.max(output, 1)
            im_target = target.data.cpu().numpy()
            im_target_rgb = np.array([label_colours[c % args.nChannel] for c in im_target])
            im_target_rgb = im_target_rgb.reshape(im.shape).astype(np.uint8)
        cv2.imwrite(f'{outRGBRoot}/{baseName}', im_target_rgb)
        # reshape to the image size instead of a hard-coded 512x512, and cast so cv2 can write it
        cv2.imwrite(f'{outRoot}/{baseName}', im_target.reshape(im.shape[0], im.shape[1]).astype(np.uint8))
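The script also imports gdalTools, a local helper module that is not included in the post; only its mkdir function is called. A minimal stand-in, assuming mkdir does nothing more than create a missing directory, could be:

# gdalTools.py -- hypothetical minimal stand-in for the local helper module.
# Assumption: the script only needs gdalTools.mkdir to create the output
# directories if they do not already exist.
import os

def mkdir(path):
    # create the directory (and any parents) if missing; do nothing otherwise
    os.makedirs(path, exist_ok=True)

With the images under bareland2/train_images and the matching scribble masks under bareland2/train_scribbles, the script can then be run directly, e.g. python train_scribble_seg.py --input bareland2/train_images (the script file name here is a placeholder). The generated label maps are written to bareland2/train_labels and colour previews to bareland2/train_labels_rgb.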

References

Kim, W., Kanezaki, A., and Tanaka, M. "Unsupervised Learning of Image Segmentation Based on Differentiable Feature Clustering." IEEE Transactions on Image Processing, 2020. https://arxiv.org/abs/2007.09990

