use_pyramid
原始文件为 .py 代码,本文是转换后的 Markdown 文件。
```python
'''
用直方图匹配和 Steerable 金字塔进行纹理合成,参考文章:https://zhuanlan.zhihu.com/p/67776577
参考代码:https://nbviewer.org/github/yourwanghao/CMUComputationalPhotography/blob/master/class18/Notebook18.ipynb
使用了更清晰的代码结构,使用更方便的 pyrtools 库替代了上述代码用的 pyTorchSteerablePyramid 库
'''
import cv2
import numpy as np
# 直方图匹配,把 src 的直方图匹配到 template 的直方图
def matchHist(src, template):
    """Remap the values of ``src`` so that its histogram matches ``template``.

    Every distinct value of ``src`` is replaced by the value of ``template``
    found at the same cumulative-distribution quantile (linearly interpolated
    between the distinct template values).

    Args:
        src: Single-channel 2-D array (or array-like) whose values are remapped.
            Colour images have to be processed one channel at a time.
        template: Single-channel 2-D array (or array-like) providing the target
            value distribution. Its shape may differ from that of ``src``.

    Returns:
        A float64 array with the shape of ``src``.

    Raises:
        ValueError: If either input is not 2-D, or if ``template`` is empty
            while ``src`` is not.
    """
    src = np.asarray(src)
    template = np.asarray(template)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if src.ndim != 2 or template.ndim != 2:
        raise ValueError(
            "matchHist expects single-channel 2-D inputs, got shapes "
            f"{src.shape} and {template.shape}"
        )
    if src.size == 0:
        # Nothing to remap; also avoids indexing an empty cumulative sum below.
        return np.empty(src.shape, dtype=np.float64)
    if template.size == 0:
        raise ValueError("template must contain at least one value")

    # bin_idx: for every pixel of src, the index of its value in the sorted
    #          list of distinct src values (e.g. first pixel is 3 and the
    #          distinct values are [2, 3, 4, 5] -> index 1).
    # s_counts: how often each distinct src value occurs.
    _s_values, bin_idx, s_counts = np.unique(
        src.ravel(), return_inverse=True, return_counts=True
    )
    t_values, t_counts = np.unique(template.ravel(), return_counts=True)

    # Normalised cumulative counts, i.e. the empirical CDF at each distinct value.
    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]

    # For each distinct src value, look up the template value at the same quantile.
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
    return interp_t_values[bin_idx].reshape(src.shape)
# 把 noise 的直方图匹配到 texture 的直方图
def matchTexture(noise, texture, layernum=6, iternum=5):
    """Synthesize a texture by histogram-matching steerable-pyramid bands.

    Starting from ``noise`` histogram-matched to ``texture``, each iteration
    builds a steerable pyramid of the current image, matches every band
    (oriented bands plus the two residuals) to the corresponding band of the
    texture pyramid, reconstructs the image and matches its histogram to
    ``texture`` once more.

    Args:
        noise: Single-channel 2-D image used as the starting point
            (``matchHist`` requires 2-D inputs).
        texture: Single-channel 2-D example texture.
        layernum: Pyramid height passed to pyrtools.
        iternum: Number of refinement iterations.

    Returns:
        A list of ``iternum + 1`` images: the initial histogram-matched noise
        followed by the result of every iteration.
    """
    import pyrtools

    # Number of orientation bands. Per the original author's note, pyrtools
    # builds ``order + 1`` orientations, hence ``order - 1`` is passed below;
    # four orientations is the usual choice for a steerable pyramid.
    order = 4

    def build_pyramid(image):
        return pyrtools.pyramids.SteerablePyramidSpace(
            image, height=layernum, order=order - 1
        )

    # Pyramid of the example texture; built once and reused as the target.
    texture_pyr = build_pyramid(texture)

    def match_band(current_pyr, key, *label):
        # Match one band of the working pyramid to the same band of the texture.
        target_band = texture_pyr.pyr_coeffs[key]
        working_band = current_pyr.pyr_coeffs[key]
        print('layer', *label, 'shape', working_band.shape)
        current_pyr.pyr_coeffs[key] = matchHist(working_band, target_band)

    oriented_keys = [
        (level, band) for level in range(0, layernum) for band in range(order)
    ]
    # The high-pass detail image and the final low-resolution image must be
    # matched as well.
    residual_keys = ['residual_lowpass', 'residual_highpass']

    # Current synthesis result; starts as noise with the texture's histogram.
    current = matchHist(noise, texture)
    snapshots = [current]

    for _ in range(iternum):
        working_pyr = build_pyramid(current)

        for level, band in oriented_keys:
            match_band(working_pyr, (level, band), level, 'band', band)

        for name in residual_keys:
            match_band(working_pyr, name, name)

        current = matchHist(working_pyr.recon_pyr(), texture)
        snapshots.append(current)

    return snapshots
# 读入纹理片段
texture = cv2.imread('./data/texture.png')
if texture is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, which would otherwise surface as a TypeError on the slice.
    raise FileNotFoundError("cannot read texture image './data/texture.png'")
# Keep at most the top-left 680x680 region of the texture.
texture = texture[0:680, 0:680]
# Process in YCrCb instead of BGR (original author's note: done in another
# colour space for the sake of colour blending); each channel is synthesized
# independently below.
texture = cv2.cvtColor(texture, cv2.COLOR_BGR2YCrCb)
rows, cols = texture.shape[0:2]

import os

# Create the output directory BEFORE the first write: cv2.imwrite does not
# create missing directories and would silently skip saving noise.png.
os.makedirs("./result/use_pyramid", exist_ok=True)

# Build the noise image (uniform random values per pixel and channel).
noise = np.random.randint(0, 256, size=(rows, cols, 3)).astype(np.uint8)
cv2.imwrite('./result/use_pyramid/noise.png', noise)
noise = cv2.cvtColor(noise, cv2.COLOR_BGR2YCrCb)

# results[c] is the list of per-iteration images for channel c.
results = [[], [], []]
for i in range(3):
    results[i] = matchTexture(noise[:, :, i], texture[:, :, i], layernum=6, iternum=5)

# Re-assemble the three channels of every iteration and save them as BGR.
for i, channels in enumerate(zip(*results)):
    nowimg = np.dstack(channels).astype(np.uint8)
    nowimg = cv2.cvtColor(nowimg, cv2.COLOR_YCrCb2BGR)
    cv2.imwrite(f"./result/use_pyramid/result{i}.jpg", nowimg)
# Bring the working images back to BGR; nowimg (the last saved result) is
# already BGR at this point.
noise = cv2.cvtColor(noise, cv2.COLOR_YCrCb2BGR)
texture = cv2.cvtColor(texture, cv2.COLOR_YCrCb2BGR)

# 256-bin histograms of channel index 0 only (the first BGR channel).
hist_noise, hist_texture, hist_output = (
    cv2.calcHist([image], [0], None, [256], [0, 256])
    for image in (noise, texture, nowimg)
)

import matplotlib.pyplot as plt

plt.figure(figsize=(15, 10))
columns = (
    ("noise", hist_noise, noise),
    ("texture", hist_texture, texture),
    ("output", hist_output, nowimg),
)

# Top row: histograms.
for position, (label, hist, _image) in enumerate(columns, start=1):
    plt.subplot(2, 3, position)
    plt.title(label)
    plt.plot(hist)

# Bottom row: the images themselves, channel order reversed (BGR -> RGB)
# for matplotlib.
for position, (label, _hist, image) in enumerate(columns, start=4):
    plt.subplot(2, 3, position)
    plt.title(label)
    plt.imshow(image[..., ::-1])
plt.show()
```