内容简介:文章目录这节我们介绍利用一个预训练模型清除图像中雾霾,使图像更清晰。
文章目录
这节我们介绍利用一个预训练模型清除图像中雾霾,使图像更清晰。
26.1 导入需要的模块
import torch import torch.nn as nn import torchvision import torch.backends.cudnn as cudnn import torch.optim import os import numpy as np from torchvision import transforms from PIL import Image import glob
26.2 查看原来的图像
import matplotlib.pyplot as plt
from matplotlib.image import imread
%matplotlib inline
img=imread('./clean_photo/test_images/shanghai01.jpg')
plt.imshow(img)
plt.show
26.3 定义一个神经网络
这个神经网络主要由卷积层构成,该网络将构建在预训练模型之上。
#定义一个神经网络 class model(nn.Module): def __init__(self): super(model, self).__init__() self.relu = nn.ReLU(inplace=True) self.e_conv1 = nn.Conv2d(3,3,1,1,0,bias=True) self.e_conv2 = nn.Conv2d(3,3,3,1,1,bias=True) self.e_conv3 = nn.Conv2d(6,3,5,1,2,bias=True) self.e_conv4 = nn.Conv2d(6,3,7,1,3,bias=True) self.e_conv5 = nn.Conv2d(12,3,3,1,1,bias=True) def forward(self, x): source = [] source.append(x) x1 = self.relu(self.e_conv1(x)) x2 = self.relu(self.e_conv2(x1)) concat1 = torch.cat((x1,x2), 1) x3 = self.relu(self.e_conv3(concat1)) concat2 = torch.cat((x2, x3), 1) x4 = self.relu(self.e_conv4(concat2)) concat3 = torch.cat((x1,x2,x3,x4),1) x5 = self.relu(self.e_conv5(concat3)) clean_image = self.relu((x5 * x) - x5 + 1) return clean_image
26.4 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = model().to(device)
def cl_image(image_path):
data = Image.open(image_path)
data = (np.asarray(data)/255.0)
data = torch.from_numpy(data).float()
data = data.permute(2,0,1)
data = data.to(device).unsqueeze(0)
#装载预训练模型
net.load_state_dict(torch.load('clean_photo/dehazer.pth'))
clean_image = net.forward(data)
torchvision.utils.save_image(torch.cat((data, clean_image),0), "clean_photo/results/" + image_path.split("/")[-1])
if __name__ == '__main__':
test_list = glob.glob("clean_photo/test_images/*")
for image in test_list:
cl_image(image)
print(image, "done!")
clean_photo/test_images/shanghai02.jpg done!
26.5 查看处理后的图像
处理后的图像与原图像拼接在一起,保存在clean_photo /results目录下。
import matplotlib.pyplot as plt
from matplotlib.image import imread
%matplotlib inline
img=imread('clean_photo/results/shanghai01.jpg')
plt.imshow(img)
plt.show
虽非十分理想,但效果还是比较明显的!
更多内容可参考:
https://github.com/TheFairBear/PyTorch-Image-Dehazing
以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网
猜你喜欢:- 零基础小白快速打造图像识别模型
- 如何使用注意力模型生成图像描述?
- 如何优化你的图像分类模型效果?
- 【图像分割模型】用BRNN做分割—ReSeg
- 调用 TensorFlow 已训练好的模型做图像识别
- 图像翻译——pix2pix模型
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
Imperfect C++中文版
威尔逊 / 荣耀、刘未鹏 / 人民邮电出版社 / 2006-1 / 75.0
汇集实用的C++编程解决方案,C++虽然是一门非凡的语言,但并不完美。Matthew Wilson使用C++十年有余,其间发现C++存在一些固有的限制,需要一些颇具技术性的工作进行弥补。本书不仅指出了C++的缺失,更为你编写健壮、灵活、高效、可维护的代码提供了实用的技术和工具。Wilson向你展示了如何克服C++的复杂性,穿越C++庞大的范式阵列。夺回对代码的控制权,从而获得更理想的结果。一起来看看 《Imperfect C++中文版》 这本书的介绍吧!
HTML 编码/解码
HTML 编码/解码
XML 在线格式化
在线 XML 格式化压缩工具