tf.gather_nd 实例

网友投稿 721 2022-09-07

tf.gather_nd 实例

tf.gather_nd 实例

import tensorflow as tfvalue = [[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]]]init = tf.constant_initializer(value)input = tf.get_variable('input', shape=[2,2,3], initializer=init)value = [[[0,1],[1,0]]]init = tf.constant_initializer(value)index = tf.get_variable('index', shape=[2,1,2], initializer=init, dtype=tf.int32)value = [[0,1],[1,0]]init = tf.constant_initializer(value)index2 = tf.get_variable('index2', shape=[2,2], initializer=init, dtype=tf.int32)value = [0,1]init = tf.constant_initializer(value)index3 = tf.get_variable('index3', shape=[2], initializer=init, dtype=tf.int32)value = [0,1,1]init = tf.constant_initializer(value)index4 = tf.get_variable('index4', shape=[3], initializer=init, dtype=tf.int32)result = tf.gather_nd(input,index)result2 = tf.gather_nd(input,index2)result3 = tf.gather_nd(input,index3)result4 = tf.gather_nd(input,index4)sess = tf.Session()sess.run(tf.global_variables_initializer())print(sess.run(result))print()print(sess.run(result2))print()print(sess.run(result3))print()print(sess.run(result4))

打印结果:

[[[11. 22. 33.]] [[ 4. 5. 6.]]]

[[11. 22. 33.] [ 4. 5. 6.]]

[11. 22. 33.]

22.0

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:Python入门教程 超详细1小时学会Python(python菜鸟教程)
下一篇:一种输入[batch, seq_len1, hidden_dim]输出[batch, seq_len2, hidden_dim]的self-attention的pytorch实现
相关文章

 发表评论

暂时没有评论,来抢沙发吧~