5分鐘玩轉(zhuǎn)PyTorch | 詳解張量的分割與合并
AI因你而升溫,記得加星標(biāo)哦!
↑?關(guān)注 + 星標(biāo)?,每天學(xué)Python新技能
后臺回復(fù)【大禮包】送你Python自學(xué)大禮包
在使用PyTorch時,對張量的分割與合并是不可避免的操作,本節(jié)就帶大家深刻理解張量的分割與合并。
在開始之前,我們先對張量的維度進(jìn)行深入理解:
t2?=?torch.zeros((3,?4))
#?tensor([[0.,?0.,?0.,?0.],
#?????????[0.,?0.,?0.,?0.],
#?????????[0.,?0.,?0.,?0.]])
????????
t2.shape
#?torch.Size([3,?4])
重點(diǎn)理解
我們可以把shape的返回結(jié)果看成一個序列,代表著各張量維度的信息,第一個數(shù)字3代表行,即向量數(shù),第二個數(shù)字4代表列,即每個向量中的標(biāo)量數(shù)。
深入理解:t2是由3個一維張量組成,并且每個一維張量都包含四個元素。
張量的分割
chunk(tensor, chunks, dim)
chunk函數(shù)能夠按照某個維度(dim)對張量進(jìn)行均勻切分(chunks),并且返回結(jié)果是原張量的視圖。
#?創(chuàng)建一個4×3的矩陣
t2?=?torch.arange(12).reshape(4,?3)
t2
#?tensor([[?0,??1,??2],
#????????[?3,??4,??5],
#????????[?6,??7,??8],
#????????[?9,?10,?11]])
張量可均分時
在第0個維度(shape的第一個數(shù)字,代表向量維度)上將t2進(jìn)行4等分:
#?在矩陣中,第一個維度是行,理解為shape的第一個數(shù)
tc?=?torch.chunk(t2,?chunks?=?4,?dim?=?0)
tc
#?(tensor([[0,?1,?2]]),
#??tensor([[3,?4,?5]]),
#??tensor([[6,?7,?8]]),
#??tensor([[?9,?10,?11]]))
根據(jù)結(jié)果可見:
返回結(jié)果是一個元組,不可變
tc[0]?=?torch.tensor([[1,?1,?1]])
#?TypeError:?'tuple'?object?does?not?support?item?assignment
元組中的每個值依然是一個二維張量
tc[0]
#?tensor([[0,?1,?2]])
返回的張量 tc的一個視圖,不是新成了一個對象
#?我們將原張量t2中的數(shù)值進(jìn)行更改
t2[0]?=?torch.tensor([6,?6,?6])
#?再打印分塊后tc的結(jié)果
tc
#?(tensor([[6,?6,?6]]),
#??tensor([[3,?4,?5]]),
#??tensor([[6,?7,?8]]),
#??tensor([[?9,?10,?11]]))
若還不懂視圖概念,點(diǎn)擊這里進(jìn)行學(xué)習(xí)
張量不可均分時
若原張量不能均分時,chunk不會報錯,會返回次一級均分結(jié)果。
#?創(chuàng)建一個4×3的矩陣
t2?=?torch.arange(12).reshape(4,?3)
t2
#?tensor([[?0,??1,??2],
#????????[?3,??4,??5],
#????????[?6,??7,??8],
#????????[?9,?10,?11]])
將4行分為3等份,不可分,就會返回2等分的結(jié)果:
tc?=?torch.chunk(t2,?chunks?=?3,?dim?=?0)
tc
#?(tensor([[0,?2,?2],
#??????????[3,?4,?5]]),?
#??tensor([[?6,??7,??8],
#??????????[?9,?10,?11]]))
將4行分為5等份,不可分,就會返回4等分的結(jié)果:
tc?=?torch.chunk(t2,?chunks?=?5,?dim?=?0)
#?(tensor([[0,?2,?2]]),
#??tensor([[3,?4,?5]]),
#??tensor([[6,?7,?8]]),
#??tensor([[?9,?10,?11]]))
split函數(shù)
split既能進(jìn)行均分,也能進(jìn)行自定義切分。需要注意的是split的返回結(jié)果也是視圖。
#?第二個參數(shù)只輸入一個數(shù)值時表示均分
#?第三個參數(shù)表示切分的維度
torch.split(t2,?2,?dim?=?0)
#?(tensor([[0,?1,?2],
#??????????[3,?4,?5]]),?
#??tensor([[?6,??7,??8],
#??????????[?9,?10,?11]]))
與chunk函數(shù)不同的是,split第二個參數(shù)可以輸入一個序列,表示按照序列數(shù)值等分:
torch.split(t2,?[1,3],?dim?=?0)
#?(tensor([[0,?1,?2]]),?
#??tensor([[?3,??4,??5],
#??????????[?6,??7,??8],
#??????????[?9,?10,?11]]))
當(dāng)?shù)诙€參數(shù)輸入一個序列時,序列的各數(shù)值的和必須等于對應(yīng)維度下形狀分量的取值,即shape對應(yīng)的維度。
例如上述代碼中,是按照第一個維度進(jìn)行切分,而t2總共有4行,因此序列的求和必須等于4,也就是1+3=4,而序列中每個分量的取值,則代表切塊大小。
torch.split(t2,?[1,?1,?2],?0)
#?(tensor([[0,?1,?2]]),?
#??tensor([[3,?4,?5]]),?
#??tensor([[?6,??7,??8],
#?????????[?9,?10,?11]]))
將張量第一個維度(行維度)分為1:1:2。
張量的合并
張量的合并操作類似與列表的追加元素,可以進(jìn)行拼接、也可以堆疊。
這里一定要將dim參數(shù)與shape返回的結(jié)果相對應(yīng)理解。
cat拼接函數(shù)
a?=?torch.zeros(2,?3)
a
#?tensor([[0.,?0.,?0.],
#?????????[0.,?0.,?0.]])
b?=?torch.ones(2,?3)
b
#?tensor([[1.,?1.,?1.],
#?????????[1.,?1.,?1.]])
因為在張量a與b中,shape的第一個位置是代表向量維度,所以當(dāng)dim取0時,就是將向量進(jìn)行合并,向量中的標(biāo)量數(shù)不變:
torch.cat([a,?b],?dim?=?0)
#?tensor([[0.,?0.,?0.],
#?????????[0.,?0.,?0.],
#?????????[1.,?1.,?1.],
#?????????[1.,?1.,?1.]])
當(dāng)dim取1時,shape的第二個位置是代表列,即標(biāo)量數(shù),就是在列上(標(biāo)量維度)進(jìn)行拼接,行數(shù)(向量數(shù))不變:
torch.cat([a,?b],?dim?=?1)
#?tensor([[0.,?0.,?0.,?1.,?1.,?1.],
##????????[0.,?0.,?0.,?1.,?1.,?1.]])
將dim與shape結(jié)合理解,是不是清晰明了了?
維度有疑惑的同學(xué),點(diǎn)擊這里進(jìn)行學(xué)習(xí)
stack堆疊函數(shù)
和拼接不同,堆疊不是將元素拆分重裝,而是將各參與堆疊的對象分裝到一個更高維度的張量里。
a?=?torch.zeros(2,?3)
a
#?tensor([[0.,?0.,?0.],
#?????????[0.,?0.,?0.]])
b?=?torch.ones(2,?3)
b
#?tensor([[1.,?1.,?1.],
#?????????[1.,?1.,?1.]])
堆疊之后,生成一個三維張量:
torch.stack([a,?b],?dim?=?0)
#?tensor([[[0.,?0.,?0.],
#??????????[0.,?0.,?0.]],
#?????????[[1.,?1.,?1.],
#??????????[1.,?1.,?1.]]])
torch.stack([a,?b],?dim?=?0).shape
#?torch.Size([2,?2,?3])
此例中,就是將兩個維度為1×2×3的張量堆疊為一個2×2×3的張量。
與cat的區(qū)別
拼接之后維度不變,堆疊之后維度升高。拼接是把一個個元素單獨(dú)提取出來之后再放到二維張量里,而堆疊則是直接將兩個二維向量封裝到一個三維張量中。因此,堆疊的要求更高,參與堆疊的張量必須形狀完全相同。
與python對比
a?=?[1,?2]
b?=?[3,?4]
cat拼接操作與list的extend相似,不會改變維度,只會在已有框架內(nèi)增加元素:
a.extend(b)
a
#?[1,?2,?3,?4]
stack堆疊操作與list的append相似,會改變維度:
a?=?[1,?2]
b?=?[3,?4]
a.append(b)
a
#?[1,?2,?[3,?4]]
推薦閱讀
推薦閱讀
