torch.flatten
>>> import torch
>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
https://pytorch.org/docs/stable/generated/torch.flatten.html
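The shape rule behind the example can be sketched without PyTorch. Dimensions start_dim through end_dim (inclusive; negative values count from the end) are merged into a single dimension whose size is their product, and the elements are read in row-major order. The helpers flattened_shape and flatten_nested below are hypothetical illustrations, not part of the PyTorch API. The sketch also assumes at least a 1-d shape, so 0-d inputs are not handled.

```python
from math import prod


def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: the shape torch.flatten(t, start_dim, end_dim)
    would produce for a tensor of the given shape (0-d shapes not handled)."""
    ndim = len(shape)
    # Normalise negative indices the same way Python sequence indexing does
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError(
            f"invalid start_dim={start_dim}, end_dim={end_dim} for a {ndim}-d shape"
        )
    # The merged dimension's size is the product of the flattened sizes
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


def flatten_nested(x):
    """Hypothetical helper: yield the scalars of a nested list in row-major order."""
    if isinstance(x, list):
        for item in x:
            yield from flatten_nested(item)
    else:
        yield x


if __name__ == "__main__":
    nested = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    print(flattened_shape((2, 2, 2)))               # (8,)
    print(flattened_shape((2, 2, 2), start_dim=1))  # (2, 4)
    print(list(flatten_nested(nested)))             # [1, 2, 3, 4, 5, 6, 7, 8]
```

For the 2×2×2 tensor t above, flattened_shape((2, 2, 2), start_dim=1) gives (2, 4), which matches the shape of torch.flatten(t, start_dim=1). The row-major walk over the nested list reproduces the element order of torch.flatten(t).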