torch.cat 與 torch.view 差異

pytorch 官方說明:
https://pytorch.org/docs/stable/generated/torch.Tensor.view.html

官方範例:

>>> x = torch.randn(4, 4)
>>> x.size()
torch.Size([4, 4])
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])

>>> a = torch.randn(1, 2, 3, 4)
>>> a.size()
torch.Size([1, 2, 3, 4])
>>> b = a.transpose(1, 2)  # Swaps 2nd and 3rd dimension
>>> b.size()
torch.Size([1, 3, 2, 4])
>>> c = a.view(1, 3, 2, 4)  # Does not change tensor layout in memory
>>> c.size()
torch.Size([1, 3, 2, 4])
>>> torch.equal(b, c)
False


The user is asking for a comparison between torch.cat and torch.view in PyTorch. I need to explain:

  1. torch.cat:
    • Functionality: Concatenation of tensors along a dimension.
    • Input: A sequence of tensors and a dimension (dim).
    • Output: A new tensor formed by joining the input tensors.
    • Key point: Changes the shape by increasing the size along the specified dimension.
  2. torch.view:
    • Functionality: Reshaping a tensor without changing its data.
    • Input: A tensor and a new shape.
    • Output: A new tensor with the same data but a different shape (view).
    • Key point: Data remains the same, only the interpretation of the shape changes. Number of elements must be preserved.
  3. Differences:
    • Purpose: torch.cat joins tensors, torch.view reshapes a single tensor.
    • Data: torch.cat creates new data by combining input tensors. torch.view does not create new data, only rearranges existing data.
    • Shape Change: torch.cat changes the shape by adding elements in a dimension. torch.view changes the shape without changing the total number of elements.
    • Input Type: torch.cat takes a sequence of tensors. torch.view takes a single tensor.
  4. Use Cases:
    • torch.cat: Combining outputs from different branches, skip connections in UNet (as in the previous question!), assembling batches from smaller chunks, etc.
    • torch.view: Preparing data for different layers (e.g., flattening for fully connected layers), rearranging dimensions for operations, etc.
  5. Analogy:
    • torch.cat: Like joining multiple pieces of rope end-to-end to make a longer rope (increasing length, new rope from pieces).
    • torch.view: Like re-organizing bricks from a stack into a row (same bricks, different arrangement).

I should explain these points clearly with code examples in Traditional Chinese to address the user’s query comprehensively. I also need to explicitly point out when each function should be used and when they cannot be used interchangeably, and highlight common errors/misconceptions. Including a table summarizing the key differences will also be helpful for quick understanding.


torch.cattorch.view 都是 PyTorch 中常用的張量操作函式,但它們的功能和用途截然不同。理解它們之間的差異對於有效地操作張量和構建神經網路至關重要。

1. torch.cat (Concatenate – 串聯)

  • 功能: torch.cat 的主要功能是將 多個張量 沿著指定的 維度 串聯 (concatenate) 在一起。你可以把它想像成將幾條繩子 頭尾相接 串成長繩。
  • 輸入:
    • tensors (tuple 或 list of Tensors): 這是必要的第一個參數,它是一個包含多個要串聯的張量的序列 (可以是 tuple 或 list)。重點是,所有要串聯的張量,除了指定串聯的維度之外,其他維度的形狀 (size) 必須完全相同。
    • dim (int): 這是必要的第二個參數,指定了要沿著哪個維度進行串聯。例如:
      • dim=0:沿著 第 0 維度 (行維度) 串聯。對於形狀為 (N, C, H, W) 的張量來說,dim=0 通常是批次 (batch) 維度。沿 dim=0 串聯會增加批次大小。
      • dim=1:沿著 第 1 維度 (通道維度) 串聯。對於形狀為 (N, C, H, W) 的張量來說,dim=1 是通道 (channel) 維度。沿 dim=1 串聯會增加通道數量。
      • dim=2dim=3:分別沿著高度維度和寬度維度串聯。
  • 輸出: torch.cat 會返回一個 新的張量,這個新張量包含了所有輸入張量的資料,並且形狀會根據串聯的維度而改變。串聯的維度大小會增加,而其他維度大小保持不變。
  • 使用情境:
    • 合併來自不同分支的特徵: 例如,在 UNet 架構中,跳躍連接會將編碼器路徑和解碼器路徑的特徵圖串聯在一起。 (正如我們在之前的問題中討論的)
    • 組合批次資料: 例如,當資料被分成小塊處理後,可以使用 torch.cat 將處理結果重新組合成完整的批次。
    • 序列資料處理: 在 RNN 或 Transformer 等模型中,處理序列資料時,可能會需要沿著序列長度維度串聯張量。

程式碼範例:torch.cat (沿通道維度 dim=1 串聯)

import torch

# 建立兩個範例張量,形狀為 (1, 3, 256, 256)
tensor1 = torch.randn(1, 3, 256, 256)
tensor2 = torch.randn(1, 3, 256, 256)

print("tensor1 shape:", tensor1.shape)
print("tensor2 shape:", tensor2.shape)

# 沿著通道維度 (dim=1) 串聯 tensor1 和 tensor2
concatenated_tensor = torch.cat((tensor1, tensor2), dim=1)

print("concatenated_tensor shape:", concatenated_tensor.shape)

輸出:

tensor1 shape: torch.Size([1, 3, 256, 256])
tensor2 shape: torch.Size([1, 3, 256, 256])
concatenated_tensor shape: torch.Size([1, 6, 256, 256])

說明: torch.cattensor1tensor2 沿著通道維度 (dim=1) 串聯,結果張量的通道數變成了 3 + 3 = 6,而其他維度 (batch size, 高度, 寬度) 維持不變。

2. torch.view (Reshape – 重新塑形)

  • 功能: torch.view 的主要功能是 改變張量的形狀 (shape),但 不會改變張量內部的資料。 你可以把它想像成將一堆積木 重新排列 成不同的形狀,但積木的總數和種類並沒有改變。
  • 輸入:
    • tensor (Tensor): 這是必要的第一個參數,即要重新塑形的張量。
    • shape (tuple 或 torch.Size): 這是必要的第二個參數,指定了新的形狀。 新的形狀必須與原始張量包含的元素總數相同。 否則 torch.view 會報錯。 可以使用 -1 來自動推斷某個維度的大小。
  • 輸出: torch.view 會返回一個 新的張量 (實際上是原始張量的 視圖 – view),這個新張量與原始張量 共享相同的資料儲存空間,只是對資料的 形狀解釋 不同。 重要的是,torch.view 不會複製資料,因此效率很高。 修改 view 會影響原始張量,反之亦然,除非原始張量使用 contiguous() 建立連續記憶體副本。
  • 使用情境:
    • 調整張量形狀以符合網路層的需求: 例如,在卷積層之後,可能需要將特徵圖 展平 (flatten) 成向量,才能輸入到全連接層。 這時就可以使用 torch.view 來改變張量的形狀。
    • 重新組織維度順序: 例如,將形狀為 (H, W, C) 的張量轉換成 (C, H, W),以符合 PyTorch 預期的通道優先順序。
    • 改變視角,方便操作: 例如,將一個 2D 張量 view 成 1D 張量,方便進行某些特定的運算。

程式碼範例:torch.view (改變形狀)

import torch

# 建立一個範例張量,形狀為 (4, 4)
original_tensor = torch.arange(16).reshape(4, 4) # 使用 arange 建立 0-15 的數字序列,並 reshape 成 4x4 矩陣
print("original_tensor:\n", original_tensor)
print("original_tensor shape:", original_tensor.shape)

# 使用 view 將 original_tensor 重新塑形為 (2, 8)
reshaped_tensor = original_tensor.view(2, 8)
print("\nreshaped_tensor:\n", reshaped_tensor)
print("reshaped_tensor shape:", reshaped_tensor.shape)

# 使用 view 將 original_tensor 展平成 1D 張量
flattened_tensor = original_tensor.view(-1) # 使用 -1 自動推斷維度大小
print("\nflattened_tensor:\n", flattened_tensor)
print("flattened_tensor shape:", flattened_tensor.shape)

輸出:

original_tensor:
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
original_tensor shape: torch.Size([4, 4])

reshaped_tensor:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15]])
reshaped_tensor shape: torch.Size([2, 8])

flattened_tensor:
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
flattened_tensor shape: torch.Size([16])

說明:

  • original_tensor 是一個 4×4 的矩陣,包含 16 個元素 (4 * 4 = 16)。
  • reshaped_tensor = original_tensor.view(2, 8)original_tensor 重新塑形為 2×8 的矩陣,仍然包含 16 個元素 (2 * 8 = 16),但形狀改變了。 資料內容 (0 到 15 的數字) 沒有改變,只是排列方式不同了。
  • flattened_tensor = original_tensor.view(-1) 使用 -1 自動推斷維度大小,將 original_tensor 展平成一個 1 維向量,包含 16 個元素,形狀變成了 (16,)

3. torch.cattorch.view 的主要差異總結:

特性torch.cattorch.view
功能串聯 (Concatenate) 多個張量重新塑形 (Reshape) 單個張量
輸入多個張量 (tuple/list)單個張量
資料創建新資料,組合輸入張量的資料不改變資料,僅改變形狀的解釋
形狀改變改變串聯維度的大小,增加元素改變形狀,但 元素總數必須保持不變
記憶體創建新張量,可能分配新的記憶體創建視圖 (view),通常不分配新記憶體 (效率高)
使用情境合併不同來源的特徵,組裝批次,序列處理調整形狀以符合網路層需求,展平,維度重排

何時使用 torch.cat? 何時使用 torch.view

  • 使用 torch.cat 的時機:
    • 當你需要 合併 來自不同來源的張量時,例如:
      • 跳躍連接 (UNet)
      • 多分支網路 (Multi-branch Network)
      • 需要組裝資料批次時
    • 當你需要 增加 張量在 某個維度上的大小 時,例如通道數、批次大小、序列長度等。
    • 重點: 你需要將 不同的張量 合併成一個更大的張量。
  • 使用 torch.view 的時機:
    • 當你只需要 改變 一個張量的 形狀,而 資料本身不需要改變 時,例如:
      • 展平特徵圖 (Flatten Feature Map)
      • 調整維度順序 (Transpose Dimensions)
      • 將資料 Reshape 成模型需要的輸入格式
    • 當你想要 更方便地操作 張量,例如將一個多維張量轉換成一維向量進行處理。
    • 重點: 你只需要 改變單個張量的形狀,不需要合併多個張量。

常見的錯誤區別:

  • 誤用 torch.view 代替 torch.cat 有些人可能會錯誤地嘗試使用 torch.view 來合併多個張量。 這是 錯誤的,因為 torch.view 只能重新塑形 單個 張量,不能將多個張量合併。 如果要合併多個張量,必須使用 torch.cat 或其他相關的組合操作 (例如 torch.stack 等)。
  • torch.view 的形狀不相容: 使用 torch.view 時,新的形狀必須與原始張量包含的元素總數相同,這是非常重要的限制。 如果新的形狀導致元素數量不一致,torch.view 會拋出錯誤。 在重新塑形之前,務必確保新形狀的元素數量等於原始形狀的元素數量。

總結:

torch.cattorch.view 是 PyTorch 中兩個功能完全不同的張量操作函式。 torch.cat 用於 串聯多個張量,增加指定維度的大小,而 torch.view 用於 重新塑形單個張量,改變其形狀,但不改變資料本身。 理解它們的差異,並根據實際需求選擇正確的函式,是寫出高效且正確 PyTorch 程式碼的關鍵。 在 UnetSkipConnectionBlockforward 函數中使用 torch.cat 是為了實現跳躍連接的特徵融合,這是 UNet 架構的核心設計思想。

About the author

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *