tf.gather_nd详解

网友投稿 618 2022-11-16

tf.gather_nd详解

tf.gather_nd详解

其实就是取出对应位置的元素,直接看代码更直观简单

# coding=utf-8# tf 2.0import tensorflow as tfa = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])index_a1 = tf.constant([[0, 2], [0, 4], [2, 2]]) # 随便选几个index_a2 = tf.constant([0, 1]) # 0行1列的元素——2index_a3 = tf.constant([[0], [1]]) # [第0行,第1行]index_a4 = tf.constant([0]) # 第0行print(tf.gather_nd(a, index_a1))print(tf.gather_nd(a, index_a2))print(tf.gather_nd(a, index_a3))print(tf.gather_nd(a, index_a4))

输出

tf.Tensor([ 3 5 13], shape=(3,), dtype=int32)tf.Tensor(2, shape=(), dtype=int32)tf.Tensor([[ 1 2 3 4 5] [ 6 7 8 9 10]], shape=(2, 5), dtype=int32)tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

附录

关于 TF1.0 和 TF2.0 的区别

Tensorflow 2.x 默认支持 Eager Execution,因此不支持 Session。

TF1.x 的 Hello World

import tensorflow as tfmsg = tf.constant('Hello, TensorFlow!')sess = tf.Session()print(sess.run(msg))

TF2.x 的 Hello World

import tensorflow as tfmsg = tf.constant('Hello, TensorFlow!')tf.print(msg)

或者

import tensorflow as tfwith tf.compat.v1.Session() as sess: hello = tf.constant('hello world') print(sess.run(hello))

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:使用FileReader采用的默认编码
下一篇:SQL查询笔记
相关文章

 发表评论

暂时没有评论,来抢沙发吧~