关于tf.distributions的那些事儿

引子

       在学习各类Machine Learning方法时,免不了要与“分布”打交道。我们有时候需要计算某个分布的熵,有时候需要计算两个分布之间的交叉熵或KL散度。当然,这可以通过使用Numpy中的numpy.random.normal之类的函数来实现,但是我们更希望能够按照TensorFlow计算图的形式来实现,这样的话,可以更好地利用TensorFlow的一些优势(如一次性计算,共享计算结果等)。

简介

       tf.distributions是TensorFlow提供的核心组件之一,用于实现一些常见的概率分布,并给出了一系列的辅助计算函数。首先,该组件中有Distribution基类、RegisterKL类、ReparameterizationType类。其中RegisterKL类是一个注册KL散度实现的装饰器,也即可以为某个分布添加KL散度的计算功能。此外,该组件还实现了以下分布:

  • Bernoulli Distribution;
  • Beta Distribution;
  • Categorical Distribution;
  • Dirichlet Distribution;
  • Dirichlet-Multinomial Distribution;
  • Exponential Distribution;
  • Gamma Distribution;
  • Laplace Distribution;
  • Multinomial Distribution;
  • Normal Distribution;
  • StudentT Distribution;
  • Uniform Distribution.

       下面我们以Normal Distribution为例来进行介绍。

tf.distributions.Normal

       Normal类型定义在tensorflow/python/ops/distributions/normal.py文件中。其__init__函数定义如下:

__init__(
    loc,
    scale,
    validate_args=False,
    allow_nan_stats=True,
    name='Normal'
)

       其中loc为高斯分布的均值 μ \mu ,scale为标准差 σ \sigma
       在Normal类中,有如下properties:

  • allow_nan_stats
  • batch_shape
  • dtype
  • event_shape
  • loc
  • name
  • parameters
  • reparameterization_type
  • scale
  • validate_args
           关于这些性质的解释就不赘述了。下面列出Normal类中给出的一些方法(列出来只是为了能够一目了然):
  • batch_shape_tensor (name=‘batch_shape_tensor’)
  • cdf (value, name=‘cdf’)
  • copy (**override_parameters_kwargs)
  • covariance (name=‘covariance’)
  • cross_entropy (other, name=‘cross_entropy’)
  • entropy (name=‘entropy’)
  • event_shape_tensor (name=‘event_shape_tensor’)
  • is_scalar_batch (name=‘is_scalar_batch’)
  • is_scalar_event (name=‘is_scalar_event’)
  • kl_divergence (other, name=‘kl_divergence’)
  • log_cdf (value, name=‘log_cdf’)
  • log_prob (value, name=‘log_prob’)
  • log_survival_function (value, name=‘log_survival_function’)
  • mean (name=‘mean’)
  • mode (name=‘mode’)
  • param_shapes (cls, sample_shape, name=‘DistributionParamShapes’)
  • param_static_shapes (cls, sample_shape)
  • prob (value, name=‘prob’)
  • quantile (value, name=‘quantile’)
  • sample (sample_shape=(), seed=None, name=‘sample’)
  • stddev (name=‘stddev’)
  • survival_function (value, name=‘survival_function’)
  • variance (name=‘variance’)
           其中有计算熵的entropy方法,计算交叉熵的cross_entropy方法,计算KL散度(相对熵)的kl_divergence方法,这些方法为我们提供了极大的便利。

尾声

       本文对tf.distributions进行了极简的介绍,大家如果对此有兴趣的话可以直接在TensorFlow官网查看,具体见:tf.distributions
       大家周五快乐~

猜你喜欢

转载自blog.csdn.net/u013745804/article/details/83375515