使用 keras-vis 可视化卷积层滤波器

下载 keras-vis 源码：
https://github.com/raghakot/keras-vis
然后在源码目录下执行 `python setup.py install` 进行安装。不要用 pip 安装，因为只有从源码安装才能得到最新版。


# ################ Visualize convolution filters (keras-vis) ################
# For each filter of one layer, synthesize the input image that maximally
# activates it (keras-vis `visualize_activation`) and tile the results into
# a num x num grid, then display the grid with matplotlib.
# Requires `model` (a Keras model) to be defined/loaded before this snippet.
import numpy as np
import matplotlib.pyplot as plt
import cv2
from keras import activations  # kept from the original snippet; unused here
from vis.utils import utils
from vis.visualization import visualize_activation
from vis.input_modifiers import Jitter

model.summary()

layer_name = 'stage2_unit4_conv2'  # layer whose filters are visualized
size = 64     # side length (pixels) of each tile in the grid
margin = 8    # black gap (pixels) between neighbouring tiles
num = 8       # grid is num x num, i.e. the first num * num filters

# Look the layer up via `layer_name` rather than repeating the literal, so
# switching layers only requires editing one line.
layer_idx = utils.find_layer_idx(model=model, layer_name=layer_name)

# Conv layers expose their filter count as `.filters`; for any other layer
# type fall back to assuming the grid can be filled completely.
n_filters = getattr(model.layers[layer_idx], 'filters', num * num)

grid_side = num * size + (num - 1) * margin
results = np.zeros((grid_side, grid_side, 3))

for i in range(num):          # i selects the grid row
    for j in range(num):      # j selects the grid column
        # Column-major layout: filter k lands at row k % num, column k // num.
        filter_index = i + j * num
        if filter_index >= n_filters:
            # Layer has fewer than num * num filters: leave this tile black.
            continue
        filter_img = visualize_activation(model=model, layer_idx=layer_idx,
                                          filter_indices=filter_index,
                                          verbose=False,
                                          input_modifiers=[Jitter(16)])
        # Scale to floats in [0, 1] for plt.imshow (assumes 0-255 output).
        filter_img = filter_img / 255.0
        filter_img_r = cv2.resize(filter_img, (size, size),
                                  interpolation=cv2.INTER_NEAREST)
        if filter_img_r.ndim == 2:
            # cv2.resize drops a singleton channel axis (grayscale input);
            # restore it so the tile broadcasts onto the 3-channel canvas.
            filter_img_r = filter_img_r[..., np.newaxis]
        row_start = i * (size + margin)
        col_start = j * (size + margin)
        results[row_start:row_start + size,
                col_start:col_start + size, :] = filter_img_r

plt.imshow(results)
plt.show()

猜你喜欢

转载自 https://www.cnblogs.com/love6tao/p/12739538.html
今日推荐