DeepGEMM: เคอร์เนล FP8 GEMM ที่สะอาดและมีประสิทธิภาพด้วยการสเกลแบบละเอียด
(github.com/deepseek-ai)DeepGEMM
DeepGEMM เป็นไลบรารีสำหรับการคูณเมทริกซ์ทั่วไปแบบ FP8 (GEMM) ที่รองรับการสเกลแบบละเอียดซึ่งถูกนำเสนอใน DeepSeek-V3 ไลบรารีนี้รองรับทั้ง GEMM แบบทั่วไปและ GEMM แบบจัดกลุ่มสำหรับ Mix-of-Experts (MoE) เขียนด้วย CUDA และไม่ต้องคอมไพล์ระหว่างการติดตั้ง รองรับ NVIDIA Hopper tensor core และใช้การสะสมผล 2 ขั้นด้วย CUDA core เพื่อแก้ปัญหาความไม่แม่นยำของการสะสมผลบน FP8 tensor core อาศัยแนวคิดบางส่วนจาก CUTLASS และ CuTe แต่ลดการพึ่งพาเทมเพลตหรือพีชคณิตให้น้อยที่สุดเพื่อคงความเรียบง่ายไว้ ด้วยฟังก์ชันเคอร์เนลหลักเพียงตัวเดียวที่มีโค้ดราว 300 บรรทัด จึงเป็นแหล่งเรียนรู้ที่เหมาะสำหรับการศึกษาเมทริกซ์คูณ FP8 บน Hopper และเทคนิคการปรับแต่งประสิทธิภาพ แม้จะมีการออกแบบแบบน้ำหนักเบา แต่ก็ให้ประสิทธิภาพเทียบเท่าหรือดีกว่าไลบรารีที่ผู้เชี่ยวชาญปรับแต่งไว้ในรูปแบบเมทริกซ์ที่หลากหลาย
ประสิทธิภาพ
ทดสอบทุกรูปแบบที่สามารถใช้ในการอนุมาน DeepSeek-V3/R1 บน H800 SXM5 ด้วย NVCC 12.8 ตัวชี้วัดการเพิ่มความเร็วทั้งหมดคำนวณโดยเทียบกับอิมพลีเมนเทชันที่ปรับแต่งภายในซึ่งอ้างอิง CUTLASS 3.6 บางรูปแบบอาจยังให้ประสิทธิภาพไม่ดีนัก และยินดีรับ PR สำหรับการปรับแต่งเพิ่มเติม
GEMM ทั่วไป (โมเดลแบบหนาแน่น)
- จากการวัดประสิทธิภาพของ DeepGEMM ในขนาดเมทริกซ์ที่หลากหลาย พบว่าบางขนาดให้ความเร็วเพิ่มขึ้นสูงสุด 2.7 เท่า
GEMM แบบจัดกลุ่มสำหรับโมเดล MoE (เลย์เอาต์แบบต่อเนื่อง)
- ให้ความเร็วเพิ่มขึ้นสูงสุด 1.2 เท่า ขึ้นอยู่กับจำนวนกลุ่มและขนาดเมทริกซ์ของแต่ละกลุ่ม
GEMM แบบจัดกลุ่มสำหรับโมเดล MoE (เลย์เอาต์แบบมาสก์)
- ใช้เลย์เอาต์แบบมาสก์เพื่อให้ความเร็วเพิ่มขึ้นสูงสุด 1.2 เท่า
เริ่มต้นอย่างรวดเร็ว
ข้อกำหนด
- GPU สถาปัตยกรรม Hopper และต้องรองรับ
sm_90a - Python 3.8 ขึ้นไป
- CUDA 12.3 ขึ้นไป (แนะนำ 12.8 ขึ้นไปเพื่อประสิทธิภาพสูงสุด)
- PyTorch 2.1 ขึ้นไป
- CUTLASS 3.6 ขึ้นไป
การพัฒนา
- อธิบายกระบวนการพัฒนา รวมถึงการ clone submodule การสร้าง symbolic link การคอมไพล์แบบ JIT และการทดสอบอิมพลีเมนเทชัน GEMM ทั้งหมด
การติดตั้ง
- สามารถนำ
deep_gemmไปใช้ในโปรเจกต์ Python ได้
อินเทอร์เฟซ
ข้อควรระวัง
- ไลบรารีนี้มีเฉพาะเคอร์เนล GEMM และรองรับเฉพาะรูปแบบ NT เท่านั้น งานอย่างการ transpose หรือการ cast FP8 แบบอื่น ๆ ต้องนำไปพัฒนาแยกเอง
GEMM แบบหนาแน่นทั่วไป (ไม่จัดกลุ่ม)
- มีฟังก์ชันสำหรับรัน FP8 GEMM แบบพื้นฐานที่ไม่จัดกลุ่ม
GEMM แบบจัดกลุ่ม (เลย์เอาต์แบบต่อเนื่อง)
- ออกแบบมาสำหรับสถานการณ์ในโมเดล MoE ที่ expert ใช้รูปทรงเดียวกัน
GEMM แบบจัดกลุ่ม (เลย์เอาต์แบบมาสก์)
- ในขั้นตอน inference decoding จะส่งเทนเซอร์มาสก์มาเพื่อคำนวณเฉพาะส่วนที่มีผลเท่านั้น
ยูทิลิตี
- มีฟังก์ชันยูทิลิตีและตัวแปรสภาพแวดล้อมหลากหลายรายการเพื่อช่วยในการปรับแต่งประสิทธิภาพ
การปรับแต่งประสิทธิภาพ
การทำ warp specialization แบบต่อเนื่อง
- ทำตามแนวทางการออกแบบของ CUTLASS โดยซ้อนทับการย้ายข้อมูล คำสั่ง tensor core MMA และการยกระดับด้วย CUDA core เข้าด้วยกัน
ความสามารถ TMA ของ Hopper
- ใช้ TMA เพื่อเร่งการย้ายข้อมูล
การปรับแต่งรายละเอียดร่วมกัน
- ปรับปรุงประสิทธิภาพผ่านเทคนิคการปรับแต่งหลายรูปแบบ
ตัวจัดตารางบล็อกแบบรวมและปรับแต่งแล้ว
- มี scheduler สำหรับเคอร์เนลทั้งแบบไม่จัดกลุ่มและแบบจัดกลุ่มทั้งหมด
การออกแบบ JIT แบบสมบูรณ์
- ปรับปรุงประสิทธิภาพด้วยการออกแบบ JIT ที่ไม่ต้องคอมไพล์ระหว่างติดตั้ง
ขนาดบล็อกที่ไม่จัดแนว
- รองรับขนาดบล็อกที่ไม่จัดแนวเพื่อเพิ่มการใช้งาน SM ให้สูงสุดในบางรูปแบบ
FFMA SASS interleaving
- ปรับคำสั่ง FFMA เพื่อเพิ่ม parallelism ระดับ warp และยกระดับประสิทธิภาพ
คำขอบคุณ
- DeepGEMM ได้รับแรงบันดาลใจจากโปรเจกต์ CUTLASS และขอแสดงความขอบคุณและความเคารพต่อผู้พัฒนา
ไลเซนส์
- เผยแพร่ภายใต้สัญญาอนุญาต MIT
ยังไม่มีความคิดเห็น