最近在做实验,设计了一个中间层,有两个输入,一个是前面编码器的输出encoder_output,形状为(None,z_dim)的tensor,另外一个是gumbel_softmax的温度temp,形状为(None,1)的tensor。
def make_gumbel_layer(n_class):
encoder_output = tf.keras.Input(shape=(z_dim,))
temp = tf.keras.Input(shape=(1,))
x = tf.keras.layers.Dense(n_class)(encoder_output)
gumbel_label = gumbel_softmax(x, temp)
model = tf.keras.Model(inputs=[encoder_output, temp], outputs=gumbel_label)
return model
调用的代码为
with tf.GradientTape() as ae_tape:
encoder_output = encoder(batch_x, training=True)
gumbel_label = gumbel_layer([encoder_output, temp])
decoder_output = decoder([encoder_output, gumbel_label], training=True)
结果报warning。
尝试在调用的代码中输出了temp的数据类型
with tf.GradientTape() as ae_tape:
encoder_output = encoder(batch_x, training=True)
print(type(temp))
gumbel_label = gumbel_layer([encoder_output, temp])
decoder_output = decoder([encoder_output, gumbel_label], training=True)
输出结果表明temp是一个float object,不是tensor。可以想到,temp被自动转换为形如()的tensor,即tensor的标量。而模型定义输入是(None, 1)。所以报warning。
所以只需要把temp 转换为 形如(1,1)的tensor即可。将调用的代码改为
with tf.GradientTape() as ae_tape:
encoder_output = encoder(batch_x, training=True)
gumbel_label = gumbel_layer([encoder_output, tf.reshape(temp,[1,1])])
decoder_output = decoder([encoder_output, gumbel_label], training=True)
文章评论