- HipKittens คือ ชุดโปรแกรมมิงพริมิตีฟ ที่ออกแบบมาเพื่อดึงศักยภาพของ AMD GPU ออกมาให้ได้มากที่สุด โดยปรับแต่งการเข้าถึงหน่วยความจำ การจัดตารางทำงาน และการนำแคชกลับมาใช้ซ้ำให้เหมาะสม
- AMD MI355X GPU มีโครงสร้าง 256 compute units และ 8 chiplets (XCD) พร้อม ไฟล์รีจิสเตอร์ขนาดใหญ่ และ ชุดคำสั่ง matrix core แบบละเอียด
- ต่างจาก NVIDIA ตรงที่ AMD ไม่มี การจัดสรรรีจิสเตอร์ใหม่, คำสั่งเมทริกซ์แบบอะซิงโครนัส และ mbarrier ทำให้ 8-wave ping-pong และ 4-wave interleave มีประสิทธิภาพมากกว่า wave specialization
- HipKittens ใช้ การจัดตารางแบบรับรู้ชิปเล็ต (grid) เพื่อปรับปรุง locality ของแคช L2 และ LLC และเพิ่ม แบนด์วิดท์สูงสุดและ TFLOPS ในงาน GEMM และ Attention
- แนวทางนี้ช่วยชดเชย ความไม่สมบูรณ์ของซอฟต์แวร์ใน ecosystem ของ AMD GPU และเป็นฐานสำหรับเพิ่ม ความสามารถในการขยายตัวของ AI computing บนฮาร์ดแวร์ที่หลากหลาย
สถาปัตยกรรมและลักษณะประสิทธิภาพของ AMD CDNA GPU
- AMD MI355X GPU มี 256 compute units (CU) โดยแต่ละ CU ประกอบด้วย SIMD 4 ชุด
- หนึ่ง SIMD จะรัน wave ที่มี 64 เธรด ซึ่งต่างจาก warp ของ NVIDIA ที่มี 32 เธรด
- MI355X มี SRAM 165KB หรือราว 70% ของ B200 และไม่มีฟีเจอร์ คำสั่งคูณเมทริกซ์แบบอะซิงโครนัส, การจัดสรรรีจิสเตอร์ใหม่, tensor memory accelerator, mbarrier
- ในทางกลับกัน มันมี ไฟล์รีจิสเตอร์ใหญ่กว่า 2 เท่า และ จำนวนโปรเซสเซอร์มากกว่า 60% (256 CU เทียบกับ 160 SM)
- รองรับ ชุดคำสั่ง matrix core ขนาดเล็กและละเอียด และมีฟังก์ชัน โหลดจาก global ไป shared memory โดยตรง (คล้าย TMA)
- AMD ใช้ สถาปัตยกรรมแบบชิปเล็ต ที่ประกอบด้วย 8 chiplets (XCD) โดยแต่ละ XCD มีแคช L2 ของตัวเอง และมี แคช LLC อยู่ชั้นบน
- ตามตาราง MI355X มีสมรรถนะประมวลผล BF16 2.5 PFLOPs, MXFP8 5.0 PFLOPs, MXFP6 10.1 PFLOPs พร้อม หน่วยความจำ 288GB และแบนด์วิดท์ 8TB/s
ความท้าทายในการออกแบบเคอร์เนลสำหรับ AMD
- การปรับแต่งการเข้าถึงหน่วยความจำ: เนื่องจากข้อจำกัดของคอมไพเลอร์ HIPCC และพฤติกรรม I/O ที่ไม่เปิดเผย จึงต้องให้ความสำคัญกับการออกแบบ การจัดวางข้อมูลและแพตเทิร์น swizzle
- การจัดตารางภายในโปรเซสเซอร์: AMD ต้องอาศัย ไฟล์รีจิสเตอร์และชุดคำสั่งเมทริกซ์ขนาดเล็ก แทน shared memory
- การจัดตารางระหว่างโปรเซสเซอร์: โครงสร้างแบบชิปเล็ตทำให้จำเป็นต้องกระจายงานโดยคำนึงถึง ผลของ NUMA ในระดับแคช
แพตเทิร์นการเข้าถึงหน่วยความจำของ HipKittens
- HipKittens(HK) ใช้ tile เป็นหน่วยข้อมูลพื้นฐาน และมี ฟังก์ชันโอเปอเรชันคล้าย PyTorch
- tile ถูกนิยามด้วย ชนิดข้อมูล ขนาด และเลย์เอาต์ และรองรับอินพุตที่หลากหลายด้วย C++ template metaprogramming
- การจัดตารางรีจิสเตอร์: HIPCC ไม่สามารถใช้รีจิสเตอร์บางตัวเป็นอินพุตของ MFMA ได้ ดังนั้น HK จึงมี ฟังก์ชันตรึงรีจิสเตอร์แบบชัดเจน
- นักพัฒนาสามารถกำหนดรีจิสเตอร์ด้วยตนเองเพื่อเขียน เคอร์เนลที่ให้ประสิทธิภาพสูงสุด
- เลย์เอาต์ของรีจิสเตอร์: AMD มีเลย์เอาต์ที่ต่างกันตามชนิดข้อมูลและรูปแบบเมทริกซ์ จึง ไม่สามารถใช้แพตเทิร์น swizzle แบบเดียวได้
- ตัวอย่างเช่น tile bf16 ขนาด 16×16 และ tile bf16 ขนาด 16×32 ต้องใช้แพตเทิร์น swizzle คนละแบบ
- โครงสร้างเฟสของคำสั่ง: คำสั่ง shared memory ของ AMD มี กลุ่มเฟสที่ไม่ต่อเนื่องกัน และ มีเอกสารภายในไม่เพียงพอ
- HK จึงมี solver ที่ย้อนวิศวกรรมขึ้นมา เพื่อจัดการเรื่องนี้
- การสร้างแอดเดรส: AMD รองรับ การโหลดแบบอะซิงโครนัสจาก HBM ไป shared memory และทำ optimization ด้วย HBM address swizzle
การจัดตารางภายในโปรเซสเซอร์: แพตเทิร์นของ wave
- Wave specialization มีประสิทธิภาพบน NVIDIA แต่สำหรับ AMD กลับทำให้ประสิทธิภาพลดลงเพราะ ไม่มีการจัดสรรรีจิสเตอร์ใหม่
- wave ฝั่ง producer ยึดรีจิสเตอร์ที่ไม่จำเป็นไว้ ขณะที่ wave ฝั่ง consumer มีรีจิสเตอร์ไม่พอจนเกิด spill
- จากผลทดลองของ HK, wave specialization บน AMD ทำให้เกิด ความเข้มข้นเชิงคณิตศาสตร์ลดลงและคอขวดด้านหน่วยความจำ
- ตัวอย่าง: ใน GEMM การจัดแบบ HK 0/8 ได้ 1605 TFLOPs ส่วน CUTLASS ได้ 1570 TFLOPs
- แพตเทิร์นการจัดตารางทางเลือก
- 8-wave ping-pong: สอง wave สลับกันรัน คลัสเตอร์หน่วยความจำ/การคำนวณ
- 4-wave interleave: หนึ่ง wave สลับการทำงานของ หน่วยความจำและการคำนวณ อย่างละเอียด
- 8-wave เขียนโค้ดได้กระชับกว่า ส่วน 4-wave ละเอียดกว่าแต่โค้ดยาวขึ้น
- ใน GEMM และ Attention Forward, 8-wave ทำประสิทธิภาพได้ในระดับ SoTA
การจัดตารางระหว่างโปรเซสเซอร์: แนวทางแบบรับรู้ชิปเล็ต
- AMD MI355X มี ชิปเล็ต XCD 8 ตัว และแต่ละตัวมี แคช L2 แยกอิสระ
- thread block ถูกจัดสรรให้ชิปเล็ตแบบ round-robin ทำให้ ลำดับของ grid ส่งผลโดยตรงต่อประสิทธิภาพการใช้แคชซ้ำ
- การจัดวางแบบ row-major อย่างง่ายทำให้อัตราการใช้แคช L2 ซ้ำต่ำและเกิด การสูญเสียแบนด์วิดท์
- ตัวอย่าง: L2 55%, LLC 95%, 15.1 TB/s, 1113 TFLOPs
- HK จึงนำ การจัดตารางแบบรับรู้ชิปเล็ต (grid) มาใช้ เพื่อใช้ประโยชน์จาก locality ของแคช L2 และ LLC พร้อมกัน
- โดยจัดกลุ่ม thread block ตาม พื้นที่ใกล้เคียงกันของเมทริกซ์ผลลัพธ์ เพื่อเพิ่มการนำข้อมูลอินพุตกลับมาใช้ซ้ำให้สูงสุด
ตัวอย่างเคอร์เนลจริง
- hot loop ของเคอร์เนล Attention Forward และ BF16 GEMM ใช้ ตาราง 8-wave ping-pong ของ HK
- แต่ละลูปจะสลับรัน คลัสเตอร์ Compute–Memory และซิงก์กันด้วย schedule barrier
- ในตัวอย่างโค้ดมีการใช้โอเปอเรชันของ HK เช่น mma_AtB, load, exp2, col_sum ซ้ำ ๆ
บทสรุป: AMD ในยุค Multi-silicon AI
- HipKittens ทำประสิทธิภาพได้ แข่งขันได้ บน AMD CDNA3 และ CDNA4
- มีแกนหลัก 3 อย่างคือ การเข้าถึงหน่วยความจำที่ปรับแต่งแล้ว, การจัดตาราง wave ที่ออกแบบเพื่อ AMD, และ การจัดตาราง grid แบบรับรู้ชิปเล็ต
- เคอร์เนลของ HK ทำ ประสิทธิภาพสูงสุดในฝั่ง AMD และยังแข่งขันได้กับ เคอร์เนลของ NVIDIA Blackwell
- เพื่อความหลากหลายของ AI computing จำเป็นต้อง เพิ่มการเข้าถึง AMD GPU และ HipKittens ก็เป็น ฐานซอฟต์แวร์สำคัญ สำหรับเป้าหมายนั้น
- การปรับปรุง HIPCC register scheduling ของ AMD ถูกชี้ว่าเป็นพื้นที่พัฒนาสำคัญในอนาคต
1 ความคิดเห็น
ความคิดเห็นจาก Hacker News