繁體
  • 简体中文
  • 繁體中文

熱門資訊> 正文

斯坦福華人天團意外爆冷!AI用純CUDA-C編內核,竟干翻PyTorch?

2025-05-31 12:07

就在剛剛,斯坦福HAI華人大神團隊又出驚人神作了。

他們用純CUDA-C語言編寫的快速AI生成內核,竟然超越了PyTorch!

在這個過程中,完全不用藉助CUTLASS和Triton等庫和領域特定語言(DSL),就能讓性能表現接近PyTorch內置的、經過專家優化的標準生產級內核,甚至在某些情況下還更勝一籌。

作者團隊都是我們熟悉的名字——Anne Ouyang、Azalia Mirhoseini和Percy Liang,有趣的是,他們甚至直言,這個結果其實本不想拿出來發佈。

一經發布,這個發現就引爆了技術圈,現在已經登頂Hacker News總榜第二。

説起來,這個發現還有很多意外的成分。

本來,他們的目標是生成合成數據,來訓練更好的內核生成模型,合成數據生成的設計也十分簡單。

然而,意想不到的事情發生了,僅用於測試的合成數據生成本身,竟開始生成非常優秀的內核,甚至超越了人類專家優化的PyTorch基線,而且還利用了高級優化和硬件特性。

而在此前,這是一項很艱難的挑戰。

由此,研究者們決定提前撰寫博文,把自己的發現分享出來。

總結來説,研究的亮點成果如下:

  • 矩陣乘法(Matmul, FP32):性能達到PyTorch FP32 torch.matmul的101.3%

  • 二維卷積(Conv2D, FP32):性能達到PyTorch FP32 torch.nn.Conv2D的179.9%

  • Softmax(FP32):性能達到PyTorch FP32 torch.softmax的111.8%

  • 層歸一化(LayerNorm, FP32):性能達到PyTorch FP32 torch.nn.LayerNorm的484.4%

  • 二維卷積 + ReLU + 最大池化(Conv2D + ReLU + MaxPool, FP32):性能達到PyTorch FP32參考實現的 290.1%,達到PyTorch FP32 torch.compile()參考實現的189.0%

以上結果在英偉達L40S GPU上進行了基準測試,性能百分比定義為參考時間除以生成的內核時間。

網友:強制LLM推理,實在太有趣了

在Hacker News上,網友們也對此展開了熱烈討論。

比如為什麼使用FP32內核會比PyTorch更容易實現性能提升,理由就相當有趣。

如果AI真的能以更低成本,實現更優化的內核,的確潛力巨大。

最令人震撼的就是,無論是最近谷歌的AlphaEvolve,還是o3在Linux內核中發現了零日漏洞,都在提醒我們——

Gemini Pro 2.5和o3已經達到了一個全新的能力水平,那些曾經在其他模型上嘗試失敗的想法,現在突然奏效了。

可以説,我們已經到達了一個節點,LLM能比用人類快得多的速度進行迭代和測試,信息組合、進步和智能應用的蠻力,似乎正在成功!

接下來,我們來看看斯坦福研究者們博客中的具體內容。

博客全文

在博客中,研究者分享了具體方法、五個優化后的內核(包括4個基礎機器學習算子和1個AlexNet模塊的融合內核)、一個優化過程的實例,以及一些思考,關於這些發現對高性能內核生成可能意味着什麼。

可以説,這些內容將是他們后續探索的第一步。

方法

研究者們採用了KernelBench的任務設置(這是他們在2024年12月發佈的一款基於AI的內核生成基準測試)。

具體來説,給定一段torch代碼,LLM會編寫自定義內核來替換原有的torch算子,目標是實現加速。

依照KernelBench最初的設計,參考代碼默認使用FP32精度;在給定的容差閾值(1e-02)下,採用較低精度的解決方案也是被允許的。

此外,由於存在大量針對特定規模的優化手段,KernelBench中的每個問題都設定了具體的輸入大小。

因此,該基準測試旨在找出針對特定問題規模的最快內核,而非一個適用於任意問題規模的高速內核。

而且,研究者會同時運行torch參考代碼和生成的代碼,並通過在多種隨機輸入下比較兩者輸出的數值是否一致,來檢驗其正確性。

當前,在優化內核這個問題上,業界擴展測試時計算資源最常用的方法是順序修訂(sequential revision)。

這是一種多輪迭代的循環:模型首先對內核進行增量式修改,接着檢查其正確性和性能,然后根據結果再次嘗試。

也就是説,要麼修復有問題的內核,要麼進一步提升現有內核的性能。

這個循環過程非常直觀,也容易實現。模型會修復失效的內核,微調可用的內核,一步步優化出性能更佳的版本。

這種方法的主要侷限,在於優化思路缺乏多樣性。

順序循環往往容易陷入局部最優的困境,比如反覆嘗試同類型的轉換,或是在缺乏潛力的優化路徑上無休止地調整。

其結果便是測試時計算資源的低效利用,並且難以促使模型產生具有根本性創新的優化思路。

為解決這一問題,研究者引入了兩項關鍵改變:

  • 運用自然語言對優化思路進行推理

他們不再於每一步直接生成新的內核,而是以先前嘗試過的思路為條件,用自然語言生成優化思路,隨后將這些思路具化為新的代碼變體。

  • 在每個優化步驟進行分支擴展

他們不是每步只改進一個候選方案,而是進行分支擴展,讓每個思路都能派生出多種實現版本,其中性能最佳的內核將作為下一輪優化的種子。

(研究者也會保留一個表現優異的現有內核庫,用於提供種子)。

這種方式解鎖了大規模的並行處理能力,使他們能夠在每一輪探索截然不同的優化方向,避免陷入狹窄的優化路徑。

其結果是,這種測試時循環不再像順序修訂那般,僅僅是與編譯器「對話」,而是更接近一種結構化的探索性搜索。

這種搜索由明確的優化假設指導,並採用大規模並行評估的方式進行。

研究者運行了KernelBench第1級的10個問題,以進行測試。

他們調整了問題規模,以確保內核啟動開銷相對於問題的整體運行時間而言可以忽略不計。

然后,使用OpenAI o3和Gemini 2.5 Pro模型進行了5輪實驗。

下圖展示了首次發現性能最佳內核所在的輪次分佈情況。

可以看到,大多數最優結果出現在靠后的輪次(總共5輪),其中絕大部分出現在第4輪或第5輪。

隨着擴大搜索範圍,研究者還發現:許多高性能內核的優化策略高度相似,集中在少數幾種常見的模式上,這與他們手動編寫內核的經驗也是一致的。

主要的優化類別歸納如下——

  • 內存訪問優化:提升不同內存層級(全局內存、共享內存、寄存器)之間數據遷移的效率,並確保數據訪問方式能夠最大化帶寬、最小化衝突。

  • 異步操作與延迟隱藏:通過將耗時較長的操作(例如全局內存訪問)與計算或其他內存傳輸重疊執行,來隱藏其帶來的延迟。

  • 數據類型與精度優化:在允許的條件下,儘可能使用較低精度的數據類型(如FP16或BF16),以降低內存帶寬需求,提升緩存效率,並有望利用專門的硬件加速單元。

  • 計算與指令優化:提升算術運算本身的效率,削減指令數量,或利用專門的硬件指令。

  • 並行性與佔用率增強:最大化流式多處理器(SM)上活躍線程束(warp)的數量,以便更好地隱藏延迟,提高整體吞吐率。

  • 控制流與循環優化:減少由循環、分支及索引計算等引入的額外開銷。

總結

這次研究者採用的方法,與AI研究中一個日益顯著的趨勢不謀而合——

將強大的推理能力與對多個假設的並行探索相結合,能夠帶來性能的提升。

正如一些近期研究(例如AlphaEvolve、Gemini 2.5 Pro Deep Think)所強調的,我們並不總是需要大規模的重新訓練。

論文地址:https://storage.googleapis.com/deepmind-media/DeepMind.com/Blog/alphaevolve-a-gemini-powered-coding-agent-for-designing-advanced-algorithms/AlphaEvolve.pdf

有時,巧妙的搜索和分支策略便足以催生科學創新、攻克複雜難題,而藉助驗證器進行廣泛搜索,則可能帶來更大的收益。

然而,這並不意味着我們不需要進一步的訓練。

恰恰相反,研究者的這種方法,也有助於生成更優質的合成數據,用以改進未來的模型訓練(這需要更多的問題實例)。

因此,它既是一種強大的測試時擴展方法,也是我們邁向更智能、數據效率更高的模型開發之路的一步。

而且,這次研究者展現的僅僅是初步的成果。這些優化結果的質量看起來相當可觀,但仍有廣闊的提升空間,例如產生更優的優化思路、生成更高質量的最終代碼,以及將此方法應用於日益複雜的內核。

目前,研究者仍在積極改進的兩個具體例子包括:

  • FP16 Matmul:性能達到torch.matmul的52%

  • FP16 Flash Attention:性能達到torch.nn.functional.scaled_dot_product_attention的9%

在現代機器學習任務中,FP32的應用不如FP16或BF16普遍,並且在較新的硬件上,針對FP32的優化往往也更少。

這或許能部分解釋,為何基於FP32的內核更容易在性能上超越PyTorch。

作者介紹

Anne Ouyang 

Anne Ouyang目前是斯坦福大學計算機科學(CS)博士生,在Scaling Intelligence Lab(可擴展智能實驗室)進行研究。

她的研究興趣主要集中在可擴展的自我改進機器學習系統,同時也廣泛關注實證機器學習(empirical ML)和性能工程(performance engineering)。

此前,她在MIT獲得學士和碩士學位,並曾在NVIDIA cuDNN團隊工作,負責編寫CUDA內核,用於加速GPU上的深度學習工作負載。

Azalia Mirhoseini

Azalia Mirhoseini是斯坦福大學計算機科學助理教授,也是Scaling Intelligence Lab(可擴展智能實驗室)的創始人,並在Google DeepMind兼任高級研究科學家。

她的實驗室致力於開發可擴展的自主演進人工智能系統與方法論,以期推動通用人工智能的發展。

在加入斯坦福大學之前,她曾在Google Brain和Anthropic等業界頂尖的人工智能實驗室工作多年。

她過往的卓越成就包括:

  • 提出混合專家(MoE)神經架構——目前已被前沿的AI模型廣泛應用;

  • 領導AlphaChip項目——一項將深度強化學習用於佈局優化的開創性工作,併成功應用於谷歌AI加速器(TPU)及數據中心CPU等先進芯片的設計中;

  • 在測試時計算的Scaling方面有深入的研究

Percy Liang

Percy Liang是斯坦福大學計算機科學副教授,兼任基礎模型研究中心(CRFM)主任。同時也是CodaLab Worksheets的創建者,並藉此堅定倡導科研工作的可復現性。

他目前專注於通過開源和嚴格的基準測試,提升基礎模型(特別是大語言模型)的可及性與可理解性。

他曾圍繞機器學習和自然語言處理領域進行了廣泛研究,具體方向包括魯棒性、可解釋性、人機交互、學習理論、知識落地、語義學以及推理等。

此前,他於2004年在MIT獲得學士學位,並於2011年在UC伯克利獲得博士學位。

內核優化過程

最后,展示一個Conv2D自動生成思路的優化軌跡示例,torch參考基準時間為1.41毫秒。

第0輪:7.02毫秒,達到參考性能的20.1%

思路:給定pytorch代碼,用CUDA Kernel替換操作。

第1輪:7.54毫秒,達到參考性能的18.8% 

思路:通過使用__ldg加載不變的張量來利用只讀緩存。

第2輪:3.46毫秒,達到參考性能的41.0%

思路:將卷積轉換為FP16 Tensor-Core GEMM。

作者評論:這是一種算法優化,將卷積轉換為隱式GEMM,這對於在Tensor Core上高效運行卷積至關重要

第3輪:3.67毫秒,達到參考性能的38.7% 

思路:採用雙緩衝cp.async管線,使全局內存加載與Tensor-Core計算重疊。

第4輪:3.46毫秒,達到參考性能的41.0%

思路:給定pytorch代碼,使用隱式矩陣乘法(implicit matmul)的CUDA Kernel替換操作。給定的GEMM內核可能會有幫助。

作者評論:因為優化涉及到使用GEMM,所以在這一輪開始時,使用了一個之前生成的現有優秀GEMM內核作為種子,這個思路是手動編寫的

第5輪:1.91毫秒,達到參考性能的74.9% 

思路:在每個K-tile循環內部,預計算k_idx分解的內核/輸入索引並在共享內存中重用,以避免冗余算術運算。

第6輪:1.37毫秒,達到參考性能的103.6%

思路:在共享內存中預計算並緩存N維GEMM索引,以減少輸入數據加載循環內的冗余算術運算。

第7輪:1.38毫秒,達到參考性能的102.9%

思路:通過使用專用的每warp共享內存緩衝區來並行化CUDA內核輸出,以消除寫入全局內存時的warp級串行化。

第8輪:1.37毫秒,達到參考性能的103.6%

思路:在共享內存中預計算並緩存基礎輸入座標,以減少輸入數據加載過程中的冗余算術運算。

第9輪:1.36毫秒,達到參考性能的105.1%

思路:對B-fragment加載進行軟件流水線處理,使下一個B-tile的共享內存讀取與當前B-tile的WMMA計算重疊。

第10輪:1.07毫秒,達到參考性能的133.6%

思路:重用共享內存中預計算的N維GEMM分解來進行輸出地址計算,避免冗余且高成本的除法/模運算。

第11輪:1.21毫秒,達到參考性能的117.4% 

思路:在half WMMA運算中移除hi/lo分解,如果最終精度可接受,則依賴標準的FP16累積來提高性能。

第12輪:1.01毫秒,達到參考性能的141.2% 

思路:通過在流水線的加載階段動態計算K維索引,使用雙緩衝使Asub(權重)和Bsub(輸入)的K循環全局內存加載與MMA計算重疊。

第13輪:0.795毫秒,達到參考性能的179.9% 

思路:通過使用像half2這樣的更寬數據類型,為加載Asub_pipe和Bsub_pipe實現向量化的共享內存寫入。

最終代碼

最終生成的Conv2D內核代碼,使用了先進的CUDA技術,就是人類自己寫起來都很有挑戰性的那種!

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch.utils.cpp_extension import load_inlineconv2d_implicit_gemm_cuda_source = r"""#include <torch/extension.h>#include <ATen/cuda/CUDAContext.h> // For at::cuda::getCurrentCUDAStream()#include <mma.h>#include <cuda_fp16.h>using namespace nvcuda;// WMMA tile dimensions#define WMMA_M 16#define WMMA_N 16#define WMMA_K 16// Skew padding for shared memory to avoid bank conflicts#define SKEW_HALF 8 // 8 half elements (16 bytes)// CUDA built-in warpSize is 32 for supported architectures (sm_70+)// This constant is used for host-side configuration (e.g. blockDim)#define CUDA_WARP_SIZE_CONST 32 // Threadblock configuration#define WARPS_PER_BLOCK 8// THREADS_PER_BLOCK must be evaluatable by host compiler for blockDim configuration#define THREADS_PER_BLOCK (WARPS_PER_BLOCK * CUDA_WARP_SIZE_CONST) // Macro-tile dimensions computed by a threadblock// BLOCK_M_TILES_WMMA * WMMA_M = output channels processed by a block// BLOCK_N_TILES_WMMA * WMMA_N = output spatial elements processed by a block#define BLOCK_M_TILES_WMMA 8#define BLOCK_N_TILES_WMMA 8#define TILE_M_PER_BLOCK (BLOCK_M_TILES_WMMA * WMMA_M) // e.g., 8 * 16 = 128 (for C_out dimension)#define TILE_N_PER_BLOCK (BLOCK_N_TILES_WMMA * WMMA_N) // e.g., 8 * 16 = 128 (for N_batch * H_out * W_out dimension)// Vector size for shared memory writes (half2)#define VECTOR_SIZE_H2 2// Struct to hold precomputed N-dimension GEMM indicesstruct NDecomposed {    int ow_eff;    int oh_eff;    int n_batch_idx;    bool isValidPixel; // True if this pixel_idx is within N_gemm bounds    int h_in_base;     int w_in_base; };__global__ void conv2d_implicit_gemm_wmma_kernel(    const float* __restrict__ input_ptr,    // Input: (N, Cin, Hin, Win)    const float* __restrict__ weight_ptr,   // Weights: (Cout, Cin, Kh, Kw)    const float* __restrict__ bias_ptr,     // Bias: (Cout) or nullptr    float* __restrict__ output_ptr,         // Output: (N, Cout, Hout, Wout)    const int N_batch, const int C_in, const int H_in, const int W_in,    const int C_out, const int K_h, const int K_w,    const int stride_h, const int stride_w,    const int pad_h, const int pad_w,    const int H_out, const int W_out,    const int M_gemm, // C_out    const int N_gemm, // N_batch * H_out * W_out    const int K_gemm  // C_in * K_h * K_w) {    // Thread identification    const int warp_id = threadIdx.x / warpSize;        // 0 .. WARPS_PER_BLOCK-1    const int lane_id = threadIdx.x % warpSize;        // 0 .. 31 (or warpSize-1)    // Top-left corner of the macro-tile this block is responsible for in GEMM terms    const int block_row_gemm_start = TILE_M_PER_BLOCK * blockIdx.y;    const int block_col_gemm_start = TILE_N_PER_BLOCK * blockIdx.x;    // Shared memory for tiles of A (weights) and B (input/im2col) - Double Buffered for K-loop pipelining    __shared__ half Asub_pipe[2][TILE_M_PER_BLOCK][WMMA_K + SKEW_HALF];    __shared__ half Bsub_pipe[2][TILE_N_PER_BLOCK][WMMA_K + SKEW_HALF];    // Shared memory for precomputed N-indices    __shared__ NDecomposed n_params_sh[TILE_N_PER_BLOCK];    // Shared memory for output stage (per-warp buffers)    __shared__ float C_shmem_output_buffers[WARPS_PER_BLOCK][WMMA_M][WMMA_N];    // Accumulator fragments per warp.    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag[BLOCK_N_TILES_WMMA];    #pragma unroll    for (int i = 0; i < BLOCK_N_TILES_WMMA; ++i) {        wmma::fill_fragment(acc_frag[i], 0.0f);    }    // Populate n_params_sh once at the beginning of the kernel    if (threadIdx.x < TILE_N_PER_BLOCK) {        int r_b_tile_idx = threadIdx.x;         int current_pixel_idx = block_col_gemm_start + r_b_tile_idx;        if (current_pixel_idx < N_gemm) {            n_params_sh[r_b_tile_idx].ow_eff = current_pixel_idx % W_out;            int temp_div_wout = current_pixel_idx / W_out;            n_params_sh[r_b_tile_idx].oh_eff = temp_div_wout % H_out;            n_params_sh[r_b_tile_idx].n_batch_idx = temp_div_wout / H_out;            n_params_sh[r_b_tile_idx].isValidPixel = true;            n_params_sh[r_b_tile_idx].h_in_base = n_params_sh[r_b_tile_idx].oh_eff * stride_h - pad_h;            n_params_sh[r_b_tile_idx].w_in_base = n_params_sh[r_b_tile_idx].ow_eff * stride_w - pad_w;        } else {            n_params_sh[r_b_tile_idx].isValidPixel = false;            n_params_sh[r_b_tile_idx].ow_eff = 0;             n_params_sh[r_b_tile_idx].oh_eff = 0;            n_params_sh[r_b_tile_idx].n_batch_idx = 0;            n_params_sh[r_b_tile_idx].h_in_base = 0;             n_params_sh[r_b_tile_idx].w_in_base = 0;        }    }    __syncthreads();    // Constants for vectorized shared memory loading    // Number of half2 elements along K-dim for a shared memory tile row    const int NUM_H2_ELEMENTS_IN_K_DIM = WMMA_K / VECTOR_SIZE_H2;    // Number of thread groups, where each group has NUM_H2_ELEMENTS_IN_K_DIM threads.    // Each group is responsible for loading the K-dimension for one M-row (for A) or N-row (for B) at a time,    // iterating over M-rows or N-rows with this step size.    const int NUM_ROW_PROCESSING_GROUPS = THREADS_PER_BLOCK / NUM_H2_ELEMENTS_IN_K_DIM;    // --- K-Loop Pipelining ---    int num_k_tiles = (K_gemm + WMMA_K - 1) / WMMA_K;
    // --- Prologue: Load first k-tile (k_tile_iter = 0) into pipe_idx = 0 ---    if (num_k_tiles > 0) {         int k_tile_start_prologue = 0;         int current_pipe_idx_prologue = 0;         // Load Asub_pipe[0] for k_tile_iter = 0        {            // This thread is responsible for the 'h2_idx_in_k_dim_A'-th half2 element            // in the K-dimension of the shared memory tile.            int h2_idx_in_k_dim_A = threadIdx.x % NUM_H2_ELEMENTS_IN_K_DIM;            // Starting 'half' index in shared memory for this half2 write.            int shmem_k_start_for_h2_A = h2_idx_in_k_dim_A * VECTOR_SIZE_H2;            // Global k-indices for the two half elements.            int k_global_A_0 = k_tile_start_prologue + shmem_k_start_for_h2_A;            int k_global_A_1 = k_tile_start_prologue + shmem_k_start_for_h2_A + 1;            // Decompose k_global_A_0            int kw_eff_reg_A_0 = 0, kh_eff_reg_A_0 = 0, ic_eff_reg_A_0 = 0;            bool is_valid_k_A_0 = (k_global_A_0 < K_gemm);            if (is_valid_k_A_0) {                kw_eff_reg_A_0 = k_global_A_0 % K_w;                int temp_div_kw_A_0 = k_global_A_0 / K_w;                kh_eff_reg_A_0 = temp_div_kw_A_0 % K_h;                ic_eff_reg_A_0 = temp_div_kw_A_0 / K_h;            }            // Decompose k_global_A_1            int kw_eff_reg_A_1 = 0, kh_eff_reg_A_1 = 0, ic_eff_reg_A_1 = 0;            bool is_valid_k_A_1 = (k_global_A_1 < K_gemm);            if (is_valid_k_A_1) {                kw_eff_reg_A_1 = k_global_A_1 % K_w;                int temp_div_kw_A_1 = k_global_A_1 / K_w;                kh_eff_reg_A_1 = temp_div_kw_A_1 % K_h;                ic_eff_reg_A_1 = temp_div_kw_A_1 / K_h;            }
            // This thread belongs to 'm_row_group_id_A'-th group of threads.            // This group iterates over M-rows of the Asub_pipe tile.            int m_row_group_id_A = threadIdx.x / NUM_H2_ELEMENTS_IN_K_DIM;            for (int r_a_tile_base = m_row_group_id_A; r_a_tile_base < TILE_M_PER_BLOCK; r_a_tile_base += NUM_ROW_PROCESSING_GROUPS) {                int oc_idx = block_row_gemm_start + r_a_tile_base;                float weight_val_0 = 0.0f;                if (oc_idx < C_out && is_valid_k_A_0) {                    weight_val_0 = weight_ptr[oc_idx * C_in * K_h * K_w +                                              ic_eff_reg_A_0 * K_h * K_w +                                              kh_eff_reg_A_0 * K_w +                                              kw_eff_reg_A_0];                }                float weight_val_1 = 0.0f;                if (oc_idx < C_out && is_valid_k_A_1) {                    weight_val_1 = weight_ptr[oc_idx * C_in * K_h * K_w +                                              ic_eff_reg_A_1 * K_h * K_w +                                              kh_eff_reg_A_1 * K_w +                                              kw_eff_reg_A_1];                }                half2* smem_ptr_h2_A = reinterpret_cast<half2*>(                    &Asub_pipe[current_pipe_idx_prologue][r_a_tile_base][shmem_k_start_for_h2_A]                );                *smem_ptr_h2_A = make_half2(__float2half(weight_val_0), __float2half(weight_val_1));            }        }        // Load Bsub_pipe[0] for k_tile_iter = 0        {            int h2_idx_in_k_dim_B = threadIdx.x % NUM_H2_ELEMENTS_IN_K_DIM;            int shmem_k_start_for_h2_B = h2_idx_in_k_dim_B * VECTOR_SIZE_H2;            int k_global_B_0 = k_tile_start_prologue + shmem_k_start_for_h2_B;            int k_global_B_1 = k_tile_start_prologue + shmem_k_start_for_h2_B + 1;            int kw_eff_reg_B_0 = 0, kh_eff_reg_B_0 = 0, ic_eff_reg_B_0 = 0;            bool is_valid_k_B_0 = (k_global_B_0 < K_gemm);            if (is_valid_k_B_0) {                kw_eff_reg_B_0 = k_global_B_0 % K_w;                int temp_div_kw_B_0 = k_global_B_0 / K_w;                kh_eff_reg_B_0 = temp_div_kw_B_0 % K_h;                ic_eff_reg_B_0 = temp_div_kw_B_0 / K_h;            }            int kw_eff_reg_B_1 = 0, kh_eff_reg_B_1 = 0, ic_eff_reg_B_1 = 0;            bool is_valid_k_B_1 = (k_global_B_1 < K_gemm);            if (is_valid_k_B_1) {                kw_eff_reg_B_1 = k_global_B_1 % K_w;                int temp_div_kw_B_1 = k_global_B_1 / K_w;                kh_eff_reg_B_1 = temp_div_kw_B_1 % K_h;                ic_eff_reg_B_1 = temp_div_kw_B_1 / K_h;            }            int n_row_group_id_B = threadIdx.x / NUM_H2_ELEMENTS_IN_K_DIM;            for (int r_b_tile_base = n_row_group_id_B; r_b_tile_base < TILE_N_PER_BLOCK; r_b_tile_base += NUM_ROW_PROCESSING_GROUPS) {                float input_val_0 = 0.0f;                if (n_params_sh[r_b_tile_base].isValidPixel && is_valid_k_B_0) {                    const NDecomposed& current_n_params = n_params_sh[r_b_tile_base];                    int h_in_eff_0 = current_n_params.h_in_base + kh_eff_reg_B_0;                    int w_in_eff_0 = current_n_params.w_in_base + kw_eff_reg_B_0;                    if (h_in_eff_0 >= 0 && h_in_eff_0 < H_in && w_in_eff_0 >= 0 && w_in_eff_0 < W_in) {                        input_val_0 = input_ptr[current_n_params.n_batch_idx * C_in * H_in * W_in +                                              ic_eff_reg_B_0 * H_in * W_in +                                              h_in_eff_0 * W_in +                                              w_in_eff_0];                    }                }                float input_val_1 = 0.0f;                 if (n_params_sh[r_b_tile_base].isValidPixel && is_valid_k_B_1) {                    const NDecomposed& current_n_params = n_params_sh[r_b_tile_base];                    int h_in_eff_1 = current_n_params.h_in_base + kh_eff_reg_B_1;                    int w_in_eff_1 = current_n_params.w_in_base + kw_eff_reg_B_1;                     if (h_in_eff_1 >= 0 && h_in_eff_1 < H_in && w_in_eff_1 >= 0 && w_in_eff_1 < W_in) {                        input_val_1 = input_ptr[current_n_params.n_batch_idx * C_in * H_in * W_in +                                              ic_eff_reg_B_1 * H_in * W_in +                                              h_in_eff_1 * W_in +                                              w_in_eff_1];                    }                }                half2* smem_ptr_h2_B = reinterpret_cast<half2*>(                    &Bsub_pipe[current_pipe_idx_prologue][r_b_tile_base][shmem_k_start_for_h2_B]                );                *smem_ptr_h2_B = make_half2(__float2half(input_val_0), __float2half(input_val_1));            }        }    }    // Loop over the K_gemm dimension in tiles of WMMA_K    for (int k_tile_iter = 0; k_tile_iter < num_k_tiles; ++k_tile_iter) {        __syncthreads(); // Sync point for pipelining        int compute_pipe_idx = k_tile_iter % 2;        int load_pipe_idx = (k_tile_iter + 1) % 2;        // --- Load Stage for next k-tile (k_tile_iter + 1) into load_pipe_idx ---        int k_tile_start_for_load = (k_tile_iter + 1) * WMMA_K;        if (k_tile_start_for_load < K_gemm) {             // Load Asub_pipe[load_pipe_idx]            {                 int h2_idx_in_k_dim_A = threadIdx.x % NUM_H2_ELEMENTS_IN_K_DIM;                int shmem_k_start_for_h2_A = h2_idx_in_k_dim_A * VECTOR_SIZE_H2;                int k_global_A_0 = k_tile_start_for_load + shmem_k_start_for_h2_A;                int k_global_A_1 = k_tile_start_for_load + shmem_k_start_for_h2_A + 1;                int kw_eff_reg_A_0 = 0, kh_eff_reg_A_0 = 0, ic_eff_reg_A_0 = 0;                bool is_valid_k_A_0 = (k_global_A_0 < K_gemm);                if (is_valid_k_A_0) {                    kw_eff_reg_A_0 = k_global_A_0 % K_w;                    int temp_div_kw_A_0 = k_global_A_0 / K_w;                    kh_eff_reg_A_0 = temp_div_kw_A_0 % K_h;                    ic_eff_reg_A_0 = temp_div_kw_A_0 / K_h;                }                int kw_eff_reg_A_1 = 0, kh_eff_reg_A_1 = 0, ic_eff_reg_A_1 = 0;                bool is_valid_k_A_1 = (k_global_A_1 < K_gemm);                if (is_valid_k_A_1) {                    kw_eff_reg_A_1 = k_global_A_1 % K_w;                    int temp_div_kw_A_1 = k_global_A_1 / K_w;                    kh_eff_reg_A_1 = temp_div_kw_A_1 % K_h;                    ic_eff_reg_A_1 = temp_div_kw_A_1 / K_h;                }
                int m_row_group_id_A = threadIdx.x / NUM_H2_ELEMENTS_IN_K_DIM;                for (int r_a_tile_base = m_row_group_id_A; r_a_tile_base < TILE_M_PER_BLOCK; r_a_tile_base += NUM_ROW_PROCESSING_GROUPS) {                    int oc_idx = block_row_gemm_start + r_a_tile_base;                    float weight_val_0 = 0.0f;                    if (oc_idx < C_out && is_valid_k_A_0) {                        weight_val_0 = weight_ptr[oc_idx * C_in * K_h * K_w +                                                  ic_eff_reg_A_0 * K_h * K_w +                                                  kh_eff_reg_A_0 * K_w +                                                  kw_eff_reg_A_0];                    }                    float weight_val_1 = 0.0f;                    if (oc_idx < C_out && is_valid_k_A_1) {                        weight_val_1 = weight_ptr[oc_idx * C_in * K_h * K_w +                                                  ic_eff_reg_A_1 * K_h * K_w +                                                  kh_eff_reg_A_1 * K_w +                                                  kw_eff_reg_A_1];                    }                    half2* smem_ptr_h2_A = reinterpret_cast<half2*>(                        &Asub_pipe[load_pipe_idx][r_a_tile_base][shmem_k_start_for_h2_A]                    );                    *smem_ptr_h2_A = make_half2(__float2half(weight_val_0), __float2half(weight_val_1));                }            }             // Load Bsub_pipe[load_pipe_idx]            {                 int h2_idx_in_k_dim_B = threadIdx.x % NUM_H2_ELEMENTS_IN_K_DIM;                int shmem_k_start_for_h2_B = h2_idx_in_k_dim_B * VECTOR_SIZE_H2;                int k_global_B_0 = k_tile_start_for_load + shmem_k_start_for_h2_B;                int k_global_B_1 = k_tile_start_for_load + shmem_k_start_for_h2_B + 1;                int kw_eff_reg_B_0 = 0, kh_eff_reg_B_0 = 0, ic_eff_reg_B_0 = 0;                bool is_valid_k_B_0 = (k_global_B_0 < K_gemm);                if (is_valid_k_B_0) {                    kw_eff_reg_B_0 = k_global_B_0 % K_w;                    int temp_div_kw_B_0 = k_global_B_0 / K_w;                    kh_eff_reg_B_0 = temp_div_kw_B_0 % K_h;                    ic_eff_reg_B_0 = temp_div_kw_B_0 / K_h;                }                int kw_eff_reg_B_1 = 0, kh_eff_reg_B_1 = 0, ic_eff_reg_B_1 = 0;                bool is_valid_k_B_1 = (k_global_B_1 < K_gemm);                if (is_valid_k_B_1) {                    kw_eff_reg_B_1 = k_global_B_1 % K_w;                    int temp_div_kw_B_1 = k_global_B_1 / K_w;                    kh_eff_reg_B_1 = temp_div_kw_B_1 % K_h;                    ic_eff_reg_B_1 = temp_div_kw_B_1 / K_h;                }                int n_row_group_id_B = threadIdx.x / NUM_H2_ELEMENTS_IN_K_DIM;                for (int r_b_tile_base = n_row_group_id_B; r_b_tile_base < TILE_N_PER_BLOCK; r_b_tile_base += NUM_ROW_PROCESSING_GROUPS) {                    float input_val_0 = 0.0f;                    if (n_params_sh[r_b_tile_base].isValidPixel && is_valid_k_B_0) {                        const NDecomposed& current_n_params = n_params_sh[r_b_tile_base];                        int h_in_eff_0 = current_n_params.h_in_base + kh_eff_reg_B_0;                        int w_in_eff_0 = current_n_params.w_in_base + kw_eff_reg_B_0;                        if (h_in_eff_0 >= 0 && h_in_eff_0 < H_in && w_in_eff_0 >= 0 && w_in_eff_0 < W_in) {                            input_val_0 = input_ptr[current_n_params.n_batch_idx * C_in * H_in * W_in +                                                  ic_eff_reg_B_0 * H_in * W_in +                                                  h_in_eff_0 * W_in +                                                  w_in_eff_0];                        }                    }                    float input_val_1 = 0.0f;                    if (n_params_sh[r_b_tile_base].isValidPixel && is_valid_k_B_1) {                        const NDecomposed& current_n_params = n_params_sh[r_b_tile_base];                        int h_in_eff_1 = current_n_params.h_in_base + kh_eff_reg_B_1;                        int w_in_eff_1 = current_n_params.w_in_base + kw_eff_reg_B_1;                        if (h_in_eff_1 >= 0 && h_in_eff_1 < H_in && w_in_eff_1 >= 0 && w_in_eff_1 < W_in) {                            input_val_1 = input_ptr[current_n_params.n_batch_idx * C_in * H_in * W_in +                                                  ic_eff_reg_B_1 * H_in * W_in +                                                  h_in_eff_1 * W_in +                                                  w_in_eff_1];                        }                    }                    half2* smem_ptr_h2_B = reinterpret_cast<half2*>(                        &Bsub_pipe[load_pipe_idx][r_b_tile_base][shmem_k_start_for_h2_B]                    );                    *smem_ptr_h2_B = make_half2(__float2half(input_val_0), __float2half(input_val_1));                }            }         }        // --- Compute Stage for current k-tile (k_tile_iter) using compute_pipe_idx ---        int a_row_start_in_tile = warp_id * WMMA_M;         wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;        wmma::load_matrix_sync(a_frag, &Asub_pipe[compute_pipe_idx][a_row_start_in_tile][0], WMMA_K + SKEW_HALF);        wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag_inner_pipe[2];        if (BLOCK_N_TILES_WMMA > 0) {            int b_col_start_in_tile_current = 0 * WMMA_N;             wmma::load_matrix_sync(b_frag_inner_pipe[0], &Bsub_pipe[compute_pipe_idx][b_col_start_in_tile_current][0], WMMA_K + SKEW_HALF);        }
        int current_inner_pipe_idx = 0;        #pragma unroll        for (int n_tile = 0; n_tile < BLOCK_N_TILES_WMMA; ++n_tile) {            int next_inner_pipe_idx = 1 - current_inner_pipe_idx;            if (n_tile < BLOCK_N_TILES_WMMA - 1) {                int b_col_start_in_tile_next = (n_tile + 1) * WMMA_N;                wmma::load_matrix_sync(b_frag_inner_pipe[next_inner_pipe_idx], &Bsub_pipe[compute_pipe_idx][b_col_start_in_tile_next][0], WMMA_K + SKEW_HALF);            }            wmma::mma_sync(acc_frag[n_tile], a_frag, b_frag_inner_pipe[current_inner_pipe_idx], acc_frag[n_tile]);
            current_inner_pipe_idx = next_inner_pipe_idx;        }    }    __syncthreads();     // Store results from accumulator fragments to global memory    #pragma unroll    for (int n_tile = 0; n_tile < BLOCK_N_TILES_WMMA; ++n_tile) {        wmma::store_matrix_sync(&C_shmem_output_buffers[warp_id][0][0], acc_frag[n_tile], WMMA_N, wmma::mem_row_major);        for (int elem_idx_in_frag = lane_id; elem_idx_in_frag < WMMA_M * WMMA_N; elem_idx_in_frag += warpSize) {            int r_frag = elem_idx_in_frag / WMMA_N;            int c_frag = elem_idx_in_frag % WMMA_N;            int oc_idx = block_row_gemm_start + (warp_id * WMMA_M) + r_frag;
            int offset_in_block_N_processing = (n_tile * WMMA_N) + c_frag;            if (oc_idx < C_out && offset_in_block_N_processing < TILE_N_PER_BLOCK &&                 n_params_sh[offset_in_block_N_processing].isValidPixel) {                const NDecomposed& current_n_params = n_params_sh[offset_in_block_N_processing];                int ow_eff = current_n_params.ow_eff;                int oh_eff = current_n_params.oh_eff;                int n_batch_idx = current_n_params.n_batch_idx;                float val = C_shmem_output_buffers[warp_id][r_frag][c_frag];                if (bias_ptr != nullptr) {                    val += bias_ptr[oc_idx];                }                output_ptr[n_batch_idx * C_out * H_out * W_out +                           oc_idx * H_out * W_out +                           oh_eff * W_out +                           ow_eff] = val;            }        }    }}torch::Tensor conv2d_implicit_gemm_cuda(    torch::Tensor input, torch::Tensor weight, torch::Tensor bias,    int N_batch, int C_in, int H_in, int W_in,    int C_out, int K_h, int K_w,    int stride_h, int stride_w, int pad_h, int pad_w,    int H_out, int W_out) {    TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");    TORCH_CHECK(weight.device().is_cuda(), "Weight must be a CUDA tensor");    TORCH_CHECK(input.dtype() == torch::kFloat32, "Input must be float32");    TORCH_CHECK(weight.dtype() == torch::kFloat32, "Weight must be float32");    if (bias.defined()) {        TORCH_CHECK(bias.device().is_cuda(), "Bias must be a CUDA tensor");        TORCH_CHECK(bias.dtype() == torch::kFloat32, "Bias must be float32");        TORCH_CHECK(bias.dim() == 1 && bias.size(0) == C_out, "Bias has wrong shape");    }    TORCH_CHECK(input.dim() == 4, "Input must be 4D");    TORCH_CHECK(weight.dim() == 4, "Weight must be 4D");    TORCH_CHECK(input.size(0) == N_batch, "Input N_batch mismatch");    TORCH_CHECK(input.size(1) == C_in, "Input C_in mismatch");    TORCH_CHECK(input.size(2) == H_in, "Input H_in mismatch");    TORCH_CHECK(input.size(3) == W_in, "Input W_in mismatch");    TORCH_CHECK(weight.size(0) == C_out, "Weight C_out mismatch");    TORCH_CHECK(weight.size(1) == C_in, "Weight C_in mismatch");    TORCH_CHECK(weight.size(2) == K_h, "Weight K_h mismatch");    TORCH_CHECK(weight.size(3) == K_w, "Weight K_w mismatch");    auto output = torch::zeros({N_batch, C_out, H_out, W_out}, input.options());    const int M_gemm = C_out;    const int N_gemm = N_batch * H_out * W_out;    const int K_gemm = C_in * K_h * K_w;    if (M_gemm == 0 || N_gemm == 0) {         return output;    }    if (K_gemm == 0) {          if (bias.defined()) {             output = output + bias.reshape({1, C_out, 1, 1});        }        return output;     }    dim3 block_dim(THREADS_PER_BLOCK);    dim3 grid_dim(        (N_gemm + TILE_N_PER_BLOCK - 1) / TILE_N_PER_BLOCK,         (M_gemm + TILE_M_PER_BLOCK - 1) / TILE_M_PER_BLOCK      );    const float* bias_ptr_data = bias.defined() ? bias.data_ptr<float>() : nullptr;    cudaStream_t stream = at::cuda::getCurrentCUDAStream();    conv2d_implicit_gemm_wmma_kernel<<<grid_dim, block_dim, 0, stream>>>(        input.data_ptr<float>(),        weight.data_ptr<float>(),        bias_ptr_data,        output.data_ptr<float>(),        N_batch, C_in, H_in, W_in,        C_out, K_h, K_w,        stride_h, stride_w, pad_h, pad_w,        H_out, W_out,        M_gemm, N_gemm, K_gemm    );
    AT_CUDA_CHECK(cudaGetLastError());    return output;}"""conv2d_implicit_gemm_cuda_declaration = r"""torch::Tensor conv2d_implicit_gemm_cuda(    torch::Tensor input, torch::Tensor weight, torch::Tensor bias,    int N_batch, int C_in, int H_in, int W_in,    int C_out, int K_h, int K_w,    int stride_h, int stride_w, int pad_h, int pad_w,    int H_out, int W_out);"""# JIT compile the CUDA kernelcustom_conv2d_wmma_ops = load_inline(    name="custom_conv2d_wmma_ops_optimized_k_pipe_vec_smem", # Changed name to avoid collision    cpp_sources=conv2d_implicit_gemm_cuda_declaration,    cuda_sources=conv2d_implicit_gemm_cuda_source,    functions=["conv2d_implicit_gemm_cuda"],    verbose=True,     extra_cuda_cflags=["-arch=sm_70", "--use_fast_math", "-std=c++17"] )class ModelNew(nn.Module):    def __init__(self, num_classes=1000): # num_classes is part of original signature, kept for consistency        super(ModelNew, self).__init__()
        # Define Conv1 parameters (matching the original model)        self.in_channels = 3        self.out_channels = 96        self.kernel_size_val = 11 # Assuming square kernel        self.stride_val = 4       # Assuming square stride        self.padding_val = 2      # Assuming square padding        # Create a temporary Conv2d layer to initialize weights and bias        temp_conv = nn.Conv2d(            in_channels=self.in_channels,             out_channels=self.out_channels,             kernel_size=self.kernel_size_val,             stride=self.stride_val,             padding=self.padding_val,            bias=True # nn.Conv2d has bias=True by default        )        self.conv1_weight = nn.Parameter(temp_conv.weight.detach().clone())        if temp_conv.bias is not None:            self.conv1_bias = nn.Parameter(temp_conv.bias.detach().clone())        else:            # Correctly register 'conv1_bias' as None if not present            self.register_parameter('conv1_bias', None)         self.custom_conv_op = custom_conv2d_wmma_ops.conv2d_implicit_gemm_cuda    def forward(self, x):        N_batch = x.size(0)        # C_in_runtime = x.size(1) # Should match self.in_channels        H_in = x.size(2)        W_in = x.size(3)        # Calculate output dimensions        H_out = (H_in + 2 * self.padding_val - self.kernel_size_val) // self.stride_val + 1        W_out = (W_in + 2 * self.padding_val - self.kernel_size_val) // self.stride_val + 1
        # Bias tensor handling: pass an undefined tensor if bias is None.        # The C++ TORCH_CHECK(bias.defined()) handles this by providing nullptr to kernel.        bias_tensor = self.conv1_bias if self.conv1_bias is not None else torch.Tensor()        x = self.custom_conv_op(            x, self.conv1_weight, bias_tensor,            N_batch, self.in_channels, H_in, W_in,            self.out_channels, self.kernel_size_val, self.kernel_size_val, # K_h, K_w            self.stride_val, self.stride_val, # stride_h, stride_w            self.padding_val, self.padding_val, # pad_h, pad_w            H_out, W_out        )        return x

(聲明:本文僅代表作者觀點,不代表新浪網立場。)

風險及免責提示:以上內容僅代表作者的個人立場和觀點,不代表華盛的任何立場,華盛亦無法證實上述內容的真實性、準確性和原創性。投資者在做出任何投資決定前,應結合自身情況,考慮投資產品的風險。必要時,請諮詢專業投資顧問的意見。華盛不提供任何投資建議,對此亦不做任何承諾和保證。