很多童鞋看到y[:, 3]
或者tf.gather(y, [1,3], axis=1)
这样的部分取值操作的时候都会一脸懵逼,不知道这是在干什么,下面来简单介绍一下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# y是一个三维的张量 y=tf.constant( [[[ 1, 2], [ 3, 4], [ 5, 6], [ 7, 8]], [[ 9, 10], [11, 12], [13, 14], [15, 16]], [[17, 18], [19, 20], [21, 22], [23, 24]]]) # y的shape=(3, 4, 2) # ---------------------------------------------------------- sess.run(y[:, 3]) # 单独一个冒号表示这一维取全部,[:,3]表示第一维取全部,第二维取下标为3的元素,(第三维取全部) ''' 结果为 [[ 7, 8], [15, 16], [23, 24]] ''' # 注意区分 y[:, 3] 和 y[:, 3:4],前者因为取单个值会降一个维度 # ---------------------------------------------------------- sess.run(y[:, 3:4]) ''' 结果为 [[[ 7, 8]], [[15, 16]], [[23, 24]]] ''' # ---------------------------------------------------------- sess.run(y[:, 3, 0]) # 第一维取全部,第二维取下标为3的元素,第三维取下标为0的元素 ''' 结果为 [ 7, 15, 23] ''' # ---------------------------------------------------------- tf.gather(y, [1,3], axis=1) # 表示在第2维上(axis默认为0,表示第1维)取下标为1和3的元素,取出后的shape=(3, 2, 2) ''' 结果为 [[[ 3, 4], [ 7, 8]], [[11, 12], [15, 16]], [[19, 20], [23, 24]]] ''' |