-
ความสำคัญของ Attention
- Attention เป็นเลเยอร์แกนหลักของสถาปัตยกรรม Transformer และเป็นคอขวดในโมเดลภาษาขนาดใหญ่และแอปพลิเคชันที่ใช้บริบทยาว
- FlashAttention และ FlashAttention-2 เป็นผู้บุกเบิกแนวทางเร่งความเร็ว Attention โดยลดการอ่าน/เขียนหน่วยความจำบน GPU ให้น้อยที่สุด
- ส่งผลให้ความยาวบริบทของ LLM เพิ่มขึ้นอย่างมาก
-
เทคโนโลยีหลักของ FlashAttention-3
- การใช้ความไม่ซิงโครนัส: ใช้ประโยชน์จากความไม่ซิงโครนัสของ Tensor Cores และ TMA เพื่อซ้อนทับการคำนวณทั้งหมดกับการเคลื่อนย้ายข้อมูล
- การประมวลผลแบบบล็อก: สลับทำการคูณเมทริกซ์และการคำนวณ softmax ในระดับบล็อก
- การประมวลผลความแม่นยำต่ำ: ใช้การรองรับความแม่นยำต่ำแบบ FP8 เพื่อเพิ่มประสิทธิภาพ
-
การเพิ่มประสิทธิภาพของ FlashAttention-3
- ประสิทธิภาพการใช้ GPU: ใช้สมรรถนะสูงสุดของ H100 GPU ได้ถึง 75% และเร็วกว่าเวอร์ชันก่อนหน้า 1.5-2 เท่า
- ประสิทธิภาพของความแม่นยำต่ำ: ใช้ FP8 เพื่อเพิ่มความเร็วในการประมวลผลและลดการใช้หน่วยความจำ
- การจัดการบริบทยาว: เร่งกลไก Attention เพื่อให้ประมวลผลข้อความที่ยาวขึ้นได้อย่างมีประสิทธิภาพ
-
สรุป FlashAttention
- FlashAttention จัดเรียงการคำนวณ Attention ใหม่ และใช้ tiling กับการคำนวณซ้ำเพื่อเพิ่มความเร็วอย่างมากและลดการใช้หน่วยความจำ
- ผ่านการทำ tiling จะโหลดบล็อกอินพุต ทำ Attention กับบล็อกนั้น แล้วอัปเดตเอาต์พุต
- ลดปริมาณการอ่าน/เขียนหน่วยความจำด้วยการไม่เขียนเมทริกซ์ Attention ระหว่างทางลงหน่วยความจำ
-
ฟีเจอร์ฮาร์ดแวร์ใหม่ของ Hopper GPU
- WGMMA: ใช้ Tensor Cores แบบใหม่เพื่อให้ได้ throughput สูง
- TMA: ยูนิตฮาร์ดแวร์ที่ช่วยเร่งการถ่ายโอนข้อมูลระหว่าง global memory และ shared memory
- FP8 ความแม่นยำต่ำ: ใช้ FP8 เพื่อเพิ่ม throughput ของ Tensor Core เป็นสองเท่า
-
ความไม่ซิงโครนัส: การซ้อนทับ GEMM และ Softmax
- ความจำเป็นของการซ้อนทับ: ทำ GEMM และ softmax แบบขนานเพื่อดึงประสิทธิภาพสูงสุด
- การจัดตารางแบบ ping-pong: ใช้สอง warp group สลับกันทำ GEMM และ softmax เพื่อเพิ่มประสิทธิภาพ
- การซ้อนทับภายใน warp group: ทำ GEMM และ softmax แบบขนานภายใน warp group เดียวกันเพื่อเพิ่ม throughput
-
ความแม่นยำต่ำ: ลดข้อผิดพลาดจากการควอนไทซ์ด้วยการประมวลผลแบบ incoherent
- การประมวลผลแบบ incoherent: ใช้การแปลง Hadamard เพื่อลดข้อผิดพลาดจากการควอนไทซ์
- ผลการทดลอง: การประมวลผลแบบ incoherent ช่วยลดข้อผิดพลาดจากการควอนไทซ์ได้ 2.6 เท่า
-
เบนช์มาร์กของ Attention
- FP16: เร็วกว่า FlashAttention-2 ประมาณ 1.6-1.8 เท่า
- FP8: ทำได้สูงสุดถึง 1.2 PFLOPS
สรุปของ GN⁺
- FlashAttention-3 ใช้ฟีเจอร์ฮาร์ดแวร์ใหม่ของ GPU เพื่อยกระดับประสิทธิภาพของกลไก Attention อย่างมาก
- สามารถจัดการบริบทยาวได้อย่างมีประสิทธิภาพ จึงดึงศักยภาพของโมเดลภาษาขนาดใหญ่ได้สูงสุด
- มีโอกาสสูงที่จะถูกรวมเข้ากับเฟรมเวิร์กหลักอย่าง PyTorch และจะส่งผลอย่างมากต่อการวิจัยและการใช้งาน AI ในอนาคต
- โครงการที่มีฟังก์ชันคล้ายกัน ได้แก่ Triton และ cuDNN
1 ความคิดเห็น
ความคิดเห็นจาก Hacker News
ดูเหมือนว่า Tri Dao จะเริ่มทำงานกับ FA3 ตั้งแต่เดือนเมษายน 2022
สงสัยว่าอัลกอริทึม Flash Attention พึ่งพาฮาร์ดแวร์มากแค่ไหน
สงสัยว่าคอมไพเลอร์จะสามารถค้นหาการปรับแต่งประสิทธิภาพแบบ FlashAttention ได้ด้วยตัวเองหรือไม่
ผู้ที่ต้องการพอร์ตไปยัง ROCm/AMD MI300x ให้ติดต่อมา
TMA (Tensor Memory Accelerator) เป็นหน่วยฮาร์ดแวร์ที่ช่วยเร่งการถ่ายโอนข้อมูลระหว่าง global memory กับ shared memory
FlashAttention-3 ถูกปรับให้เหมาะกับ GPU Hopper (เช่น H100)
มีการกล่าวว่าฟังก์ชันกระตุ้นอย่าง sigmoid ช้ามากใน LLM สมัยใหม่
สงสัยว่าทำไม Flash Attention ถึงช้ากว่าเดิม 5 เท่าเมื่อมี variable masking เทียบกับกรณีที่ไม่มี
สงสัยว่า FlashAttention สามารถมาแทนที่การคำนวณ attention ของ LLM ได้หรือไม่
ต้องใช้ฮาร์ดแวร์ราคาแพง