keras使用plot_model绘制网络模型图

from keras.utils import plot_model
plot_model(model, './model.bmp', show_shapes=True)

使用keras中的plot_model模块可以绘制网络模型图,但是可能报pydot缺失的错。

pip安装完又报另一个错误`pydot` failed to call GraphViz.

根据提示到相关网站下载对应系统的安装包吧,这里下的是window版本的msi安装包

安装完后发现报错还是没有解决,仔细检查报错,会发现这里是dot;

此时应该去site-packages路径下的pydot.py中将self.prog = 'dot'修改为self.prog = 'dot.exe'(大概1710行数)

另外还得为刚刚安装的GraphViz添加环境变量,可以在系统设置,也可以代码中添加

import os
os.environ["PATH"] += ";D:/Program/Graphviz2.38/bin/"

再次执行如下完整代码

from __future__ import absolute_import, division, print_function
import os
import tensorflow as tf
from tensorflow import keras
from keras.utils import plot_model

print('tf version: {}'.format(tf.__version__))


# Returns a short sequential model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation=tf.nn.softmax)
    ])

    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])

    return model


# Create a basic model instance
model = create_model()
os.environ["PATH"] += ";D:/Program/Graphviz2.38/bin/"
plot_model(model, './model.bmp', show_shapes=True)

就可以得到网络的模型图

原创文章 26 获赞 20 访问量 3万+

猜你喜欢

转载自blog.csdn.net/HJXINKKL/article/details/89483033