【易踩坑】torch.split函数
发布网友
发布时间:2023-07-08 18:51
我来回答
共1个回答
热心网友
时间:2023-10-24 06:45
torch.split(tensor, split_size, dim=)
tensor是要切割的张量,dim表示在哪个维度上面进行切割
a = torch.LongTensor([[1,2,3,4],[2,3,4,5]])
b = torch.cat(torch.split(a, 4, dim=1), dim=0)
print(b)
输出:tensor([[1, 2, 3, 4],
[2, 3, 4, 5]])