Caffe实战之Python接口系列(八)Siamese Network Tutoria

引言

记录学习官网的例程中的一些重要语句,遇到的问题等,内容分散,建议顺序查看。
主要是调用Caffe的Python接口
源文件就在{caffe_root}/examples中(目录下面的中文标题也附有链接),安装sudo pip install jupyter打开即可运行,初学者最好是放在它指定的目录,如,否则要改很多路径。
注:eaxmples是用jupyter notebook写的,部分Cell中出现了一些特殊的用法:
1. 感叹号‘!’:用于执行系统命令,如 !pwd
2. 百分号‘%’:用法太多,如 %matplotlib inline 显示绘图窗口 详见Jupyter Notebook Viewer

目录

孪生神经网络

1. 配置

  • 导入需要的模块

    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    
    # Make sure that caffe is on the python path:
    
    caffe_root = '../../'  # this file is expected to be in {caffe_root}/examples/siamese
    import sys
    sys.path.insert(0, caffe_root + 'python')
    
    import caffe

2. 加载训练网络

  • 加载模型定义和权重,并设置为CPU模式

    MODEL_FILE = 'mnist_siamese.prototxt'
    
    # decrease if you want to preview during training
    
    PRETRAINED_FILE = 'mnist_siamese_iter_50000.caffemodel' 
    caffe.set_mode_cpu()
    net = caffe.Net(MODEL_FILE, PRETRAINED_FILE, caffe.TEST)
  • 注1:这个模型与LeNet模型几乎完全相同,唯一的区别是用一个产生二维向量的线性”feature”层替换了产生10个数字类别概率的顶层。
  • 注2:预训练模型要通过训练产生,详情请看{caffe_root}/examples/siamese/readme.md

3. 加载一些MNIST的测试数据

TEST_DATA_FILE = '../../data/mnist/t10k-images-idx3-ubyte'
TEST_LABEL_FILE = '../../data/mnist/t10k-labels-idx1-ubyte'
n = 10000

with open(TEST_DATA_FILE, 'rb') as f:
    f.read(16) # skip the header
    raw_data = np.fromstring(f.read(n * 28*28), dtype=np.uint8)

with open(TEST_LABEL_FILE, 'rb') as f:
    f.read(8) # skip the header
    labels = np.fromstring(f.read(n), dtype=np.uint8)

生成孪生特征

# reshape and preprocess
caffe_in = raw_data.reshape(n, 1, 28, 28) * 0.00390625 # manually scale data instead of using `caffe.io.Transformer`
out = net.forward_all(data=caffe_in)

可视化学习到的Siamese embedding

  • 注:自己也是第一次了解孪生网络对’embedding’这个词不知道如何翻译,附上我查到的一篇博客,能让你对孪生网络有个大概映像。Siamese network 孪生神经网络–一个简单神奇的结构 - 简书

    feat = out['feat']
    f = plt.figure(figsize=(16,9))
    c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff', 
         '#ff00ff', '#990000', '#999900', '#009900', '#009999']
    for i in range(10):
        plt.plot(feat[labels==i,0].flatten(), feat[labels==i,1].flatten(), '.', c=c[i])
    plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
    plt.grid()
    plt.show()
  • 上述代码主要是将孪生网络顶层输出的两个值画成坐标点的形式,对输入的10000个数据按标签显示不同的颜色。从图中可知聚类效果越好,说明网络判断相似性的效果越好。

上一篇:Caffe实战之Python接口系列(七)R-CNN detection

下一篇:Caffe实战之Python接口系列 终章总结

猜你喜欢

转载自blog.csdn.net/qq_38156052/article/details/80992540