tf.constant_initializer

初始化器,它生成具有常量值的张量。得到的张量填充了类型为dtype的值,由参数值指定,参数值遵循新张量的期望形状。参数值可以是常量值,也可以是类型为dtype的值列表。如果value是一个列表,那么列表的长度必须小于或等于由张量的期望形状所暗示的元素的数量。如果值中的元素总数小于张量形状所需的元素数,则值中的最后一个元素将用于填充剩余的元素。如果值中元素的总数大于张量形状所需元素的总数,初始化器将产生一个ValueError。

参数:

  • value: Python标量、值列表或元组,或n维numpy数组。初始化变量的所有元素将在value参数中设置为对应的值。
  • dtype: 数据类型。
  • verify_shape:布尔值,用于验证值的形状。如果为真,如果值的形状与初始化张量的形状不兼容,初始化器将抛出错误。

Raises:

  • TypeError: If the input value is not one of the expected types.

示例:下面的示例可以使用numpy重写。ndarray代替了值列表,甚至重新构造了值列表,如值列表初始化下面的两行注释所示。

import numpy as np
import tensorflow as tf
value = [0, 1, 2, 3, 4, 5, 6, 7]
# value = np.array(value)
# value = value.reshape([2, 4])
init = tf.constant_initializer(value)

print('fitting shape:')
with tf.Session():
    x = tf.get_variable('x', shape=[2, 4], initializer=init)
    x.initializer.run()
    print(x.eval())
-------------------
fitting shape:
[[ 0.  1.  2.  3.]
[ 4.  5.  6.  7.]]
-------------------




print('larger shape:')
with tf.Session():
   x = tf.get_variable('x', shape=[3, 4], initializer=init)
   x.initializer.run()
   print(x.eval())
-------------------
larger shape:
[[ 0.  1.  2.  3.]
[ 4.  5.  6.  7.]
[ 7.  7.  7.  7.]]
-------------------




print('smaller shape:')
with tf.Session():
    x = tf.get_variable('x', shape=[2, 3], initializer=init)
-------------------------------------------------------------------------------------
* <b>`ValueError`</b>: Too many elements provided. Needed at most 6, but received 8
-------------------------------------------------------------------------------------




print('shape verification:')
init_verify = tf.constant_initializer(value, verify_shape=True)
with tf.Session():
    x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)
----------------------------------------------------------------
* <b>`TypeError`</b>: Expected Tensor's shape: (3, 4), got (8,).
----------------------------------------------------------------

Method:

__init__

__init__(
    value=0,
    dtype=tf.float32,
    verify_shape=False
)

__call__

__call__(
    shape,
    dtype=None,
    partition_info=None,
    verify_shape=None
)

from_config

from_config(
    cls,
    config
)

从配置字典实例化初始化器。例子:

initializer = RandomUniform(-1, 1)
config = initializer.get_config()
initializer = RandomUniform.from_config(config)

参数

  • config: 一个Python字典。它通常是get_config的输出。

返回:

  • 一个初始化后的实例。

get_config

get_config()

参考网址: https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/constant_initializer?hl=en

猜你喜欢

转载自blog.csdn.net/weixin_36670529/article/details/91415801
今日推荐