0%

Tensor中的索引和筛选

Tensor中的索引和筛选

  • 索引:根据下标选出元素,Tensor类型为int

  • 筛选:根据True/False筛选元素,Tensor类型为Bool

    索引保持原维度不变,筛选可能会使维度变小

筛选

1
2
3
4
5
6
import torch
a = torch.randn((3,5))
b = torch.randint(0,3,(5,)).bool()
print(a)
print(b)
a[:,b]
1
2
3
4
5
6
7
tensor([[ 1.2588,  0.1620, -0.3095,  2.0669, -0.3158],
[-1.6065, -0.7930, -0.5658, 1.8601, -0.2592],
[ 1.3226, 0.3040, -1.0924, 0.7583, -1.5444]])
tensor([ True, False, False, True, True])
tensor([[ 1.2588, 2.0669, -0.3158],
[-1.6065, 1.8601, -0.2592],
[ 1.3226, 0.7583, -1.5444]])

可以看出这是在shape[1]上进行了筛选,False的元素会原地删除

索引

1
2
3
4
5
6
import torch
a = torch.randn((3,5))
b = torch.randint(0,3,(5,))
print(a)
print(b)
a[:,b]
1
2
3
4
5
6
7
tensor([[-1.0592,  0.9445, -0.6265,  0.2607, -0.2166],
[ 1.8283, 0.9678, 0.6175, 2.0904, 0.2356],
[ 0.1864, 1.0110, 0.0425, -1.3611, 1.1043]])
tensor([1, 1, 1, 2, 1])
tensor([[ 0.9445, 0.9445, 0.9445, -0.6265, 0.9445],
[ 0.9678, 0.9678, 0.9678, 0.6175, 0.9678],
[ 1.0110, 1.0110, 1.0110, 0.0425, 1.0110]])

可以看出这是在shape[1]上进行了索引,原维度不变,根据下标取出对应的元素