TensorBoard:模型解析流程

当我们需要查看模型文件如.pb, .meta等的结构的时候,需要先根据model生成event文件。在生成event文件之后,通过命令行执行tensorboard --logdir=[event dir]即可启动TensorBoard服务,其中命令中中括号内表示event文件所在目录。那么我们来看看,event文件是怎么一步一步呈现出来的。
首先来看一下整个event文件解析流程的时序图:
event解析流程
上图是服务器启动之前所做的解析工作,下面我们就跟着代码,看看每一步都是怎么进行的。
一切还得从文章最开始得那条命令开始说起。既然存在一个tensorboard命令,那么就一定可以找到一个以该命令命名的文件,且该文件所在的路径一定在操作系统的环境变量中。直接使用which tensorboard 可以定位到文件所在目录,在我电脑上,它位于/home/chou/SNPE/venv/bin路径下。

(venv) chou@chou-vb:~/$ which tensorboard
/home/chou/SNPE/venv/bin/tensorboard

打开该文件,很简单,只有几行:

from tensorboard.main import run_main

if __name__ == '__main__':
    sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0])
    sys.exit(run_main())

从第一行可以看出。真正的入口是tensorboard/main.py中的run_main()函数。我们找到Github下载下来的源代码,开main.py文件并定位到run_main()函数。

def run_main():
  """Initializes flags and calls main()."""
  program.setup_environment()
  
  # 生成TensorBoard实例
  tensorboard = program.TensorBoard(default.get_plugins(),
                                    program.get_default_assets_zip_provider())
  try:
    from absl import app
    # Import this to check that app.run() will accept the flags_parser argument.
    from absl.flags import argparse_flags
    app.run(tensorboard.main, flags_parser=tensorboard.configure)
    raise AssertionError("absl.app.run() shouldn't return")
  except ImportError:
    pass
  tensorboard.configure(sys.argv)
  
  # 调用TensorBoard的main函数
  sys.exit(tensorboard.main())

可以看到,run_main()的主要工作是进行了一些设置,并最终调用了TensorBoard的main函数。

  def main(self, ignored_argv=('',)):

    ....
    
    try:
      server = self._make_server()
      sys.stderr.write('TensorBoard %s at %s (Press CTRL+C to quit)\n' %
                       (version.VERSION, server.get_url()))
      sys.stderr.flush()
      server.serve_forever()
      return 0
    except TensorBoardServerException as e:
        ....

为了简洁,省略了一些关系不大的代码,使用....表示。从函数名中可以看到,main执行了两个操作:

  1. 创建服务器
  2. 启动创建的服务
    本文我们先看创建服务器部分,因为两部分都是很复杂的内容。跳转进入_make_server()函数:
  def _make_server(self):
    """Constructs the TensorBoard WSGI app and instantiates the server."""
    app = application.standard_tensorboard_wsgi(self.flags,
                                                self.plugin_loaders,
                                                self.assets_zip_provider)
    return self.server_class(app, self.flags)

进来也很简单,从名字能猜个大概,调用了一个启动标准WSGI程序的函数,接着跟着跳吧:

def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider):
    .....
  multiplexer = event_multiplexer.EventMultiplexer(
      size_guidance=DEFAULT_SIZE_GUIDANCE,
      tensor_size_guidance=tensor_size_guidance_from_flags(flags),
      purge_orphaned_data=flags.purge_orphaned_data,
      max_reload_threads=flags.max_reload_threads)
  loading_multiplexer = multiplexer
  ....
 context = base_plugin.TBContext(
      db_module=db_module,
      db_connection_provider=db_connection_provider,
      db_uri=db_uri,
      flags=flags,
      logdir=flags.logdir,
      multiplexer=multiplexer,
      assets_zip_provider=assets_zip_provider,
      plugin_name_to_instance=plugin_name_to_instance,
      window_title=flags.window_title)
  plugins = []
  for loader in plugin_loaders:
    plugin = loader.load(context)
    if plugin is None:
      continue
    plugins.append(plugin)
    plugin_name_to_instance[plugin.plugin_name] = plugin
  return TensorBoardWSGIApp(flags.logdir, plugins, loading_multiplexer,
                            reload_interval, flags.path_prefix,
                            reload_task)

def TensorBoardWSGIApp(logdir, plugins, multiplexer, reload_interval,
                       path_prefix='', reload_task='auto'):
  path_to_run = parse_event_files_spec(logdir)
  if reload_interval >= 0:
    # We either reload the multiplexer once when TensorBoard starts up, or we
    # continuously reload the multiplexer.
    start_reloading_multiplexer(multiplexer, path_to_run, reload_interval,
                                reload_task)
  return TensorBoardWSGI(plugins, path_prefix)
  

同样,我们省略了一些代码。在以上函数的最后,返回来一个TensorBoardWSGI的实例,事情到这里,创建服务器的部分貌似已经结束了,但是,我们并没有看到解析event文件的代码,说明还遗漏了什么。仔细查看代码,发现有个start_reloading_multiplexer函数。并且,该函数的一个参数特别眼熟,logdir,没错,就是我们命令行传递进来的参数,也就event是文件所在目录。最终,logdir转换到了path_to_run。那好,跟进去看看。

def start_reloading_multiplexer(multiplexer, path_to_run, load_interval,
                                reload_task):
    .....
  def _reload():
    while True:
      start = time.time()
      tf.logging.info('TensorBoard reload process beginning')
      for path, name in six.iteritems(path_to_run):
        multiplexer.AddRunsFromDirectory(path, name)
      tf.logging.info('TensorBoard reload process: Reload the whole Multiplexer')
      multiplexer.Reload()
      duration = time.time() - start
      tf.logging.info('TensorBoard done reloading. Load took %0.3f secs', duration)
      if load_interval == 0:
        # Only load the multiplexer once. Do not continuously reload.
        break
      time.sleep(load_interval)

  if reload_task == 'process':
    tf.logging.info('Launching reload in a child process')
    import multiprocessing
    process = multiprocessing.Process(target=_reload, name='Reloader')
    # Best-effort cleanup; on exit, the main TB parent process will attempt to
    # kill all its daemonic children.
    process.daemon = True
    process.start()
  elif reload_task in ('thread', 'auto'):
    tf.logging.info('Launching reload in a daemon thread')
    thread = threading.Thread(target=_reload, name='Reloader')
    # Make this a daemon thread, which won't block TB from exiting.
    thread.daemon = True
    thread.start()
  elif reload_task == 'blocking':
    if load_interval != 0:
      raise ValueError('blocking reload only allowed with load_interval=0')
    _reload()
  else:
    raise ValueError('unrecognized reload_task: %s' % reload_task)

可以看到,其中定义了一个内函数,根据实际参数,会有创建新的进程,创建线程,或者直接调用集中方式调用,但不管怎么调用,函数还是那个函数,所以我们目前对调用方式不用太关系,知道调用了它就好。以上比较重要的两行:

multiplexer.AddRunsFromDirectory(path, name)
multiplexer.Reload()

其中第一行遍历了我们传进来的目录,找了所有该目录下的event文件,之后就是使用Relaod()进行加载解析。

  def Reload(self):
    """Call `Reload` on every `EventAccumulator`."""
    tf.logging.info('Beginning EventMultiplexer.Reload()')
    self._reload_called = True
    # Build a list so we're safe even if the list of accumulators is modified
    # even while we're reloading.
    with self._accumulators_mutex:
      items = list(self._accumulators.items())

    names_to_delete = set()
    for name, accumulator in items:
      try:
        accumulator.Reload()
      except (OSError, IOError) as e:
        tf.logging.error("Unable to reload accumulator '%s': %s", name, e)
      except directory_watcher.DirectoryDeletedError:
        names_to_delete.add(name)

    with self._accumulators_mutex:
      for name in names_to_delete:
        tf.logging.warning("Deleting accumulator '%s'", name)
        del self._accumulators[name]
    tf.logging.info('Finished with EventMultiplexer.Reload()')
    return self
    
def Reload(self):

    with self._generator_mutex:
      for event in self._generator.Load():
        self._ProcessEvent(event)
    return self

辗转调用到了_ProcessEvent(event),而正是该函数,进行真正的解析工作,他讲根据protobuf协议对event文件进行解析,并将解析的内容保存在accumulator中。
至此,我们已经找到了根源。但是现在还有个疑问,解析后的accumulator存放在哪里呢?答案在items = list(self._accumulators.items())这一句。在TensorBoard中,每个图的显示是根据run进行分组的,也就是说有多少个run就有多少个图。所谓的run,简单点理解可以理解为logdir的子文件夹。在上面提到的AddRunsFromDirectory函数中,它将每个run(子文件夹)的名字与一个accumulator映射起来。这个很重要,因为后续服务器启动后,就是根据每个run的名字找到对应的accumulator,进而获取相对应的图的内容来显示。
我们知道了EventMultiplexer中持有了一组·run,accumulator的映射,那么EventMultiplexer,有保存哪里呢?答案在上面提到的函数standard_tensorboard_wsgi中,在该函数中,EventMultiplexer先保存在一个TBContext中,然后传给了相对应的plugin,这样,最终每个plugin变得到了他能够处理的数据。
剩下的,就是启动服务,并等待收到浏览器发过来的请求啦。


本文首发于个人公众号TensorBoy。如果你觉得内容还不错,欢迎分享并关注我的公众号TensorBoy,扫描下方二维码获取更多精彩原创内容!

公众号二维码

发布了45 篇原创文章 · 获赞 4 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/ZM_Yang/article/details/86241823