1 คะแนน โดย GN⁺ 2024-07-12 | 1 ความคิดเห็น | แชร์ทาง WhatsApp
  • ความสำคัญของ Attention

    • Attention เป็นเลเยอร์แกนหลักของสถาปัตยกรรม Transformer และเป็นคอขวดในโมเดลภาษาขนาดใหญ่และแอปพลิเคชันที่ใช้บริบทยาว
    • FlashAttention และ FlashAttention-2 เป็นผู้บุกเบิกแนวทางเร่งความเร็ว Attention โดยลดการอ่าน/เขียนหน่วยความจำบน GPU ให้น้อยที่สุด
    • ส่งผลให้ความยาวบริบทของ LLM เพิ่มขึ้นอย่างมาก
  • เทคโนโลยีหลักของ FlashAttention-3

    • การใช้ความไม่ซิงโครนัส: ใช้ประโยชน์จากความไม่ซิงโครนัสของ Tensor Cores และ TMA เพื่อซ้อนทับการคำนวณทั้งหมดกับการเคลื่อนย้ายข้อมูล
    • การประมวลผลแบบบล็อก: สลับทำการคูณเมทริกซ์และการคำนวณ softmax ในระดับบล็อก
    • การประมวลผลความแม่นยำต่ำ: ใช้การรองรับความแม่นยำต่ำแบบ FP8 เพื่อเพิ่มประสิทธิภาพ
  • การเพิ่มประสิทธิภาพของ FlashAttention-3

    • ประสิทธิภาพการใช้ GPU: ใช้สมรรถนะสูงสุดของ H100 GPU ได้ถึง 75% และเร็วกว่าเวอร์ชันก่อนหน้า 1.5-2 เท่า
    • ประสิทธิภาพของความแม่นยำต่ำ: ใช้ FP8 เพื่อเพิ่มความเร็วในการประมวลผลและลดการใช้หน่วยความจำ
    • การจัดการบริบทยาว: เร่งกลไก Attention เพื่อให้ประมวลผลข้อความที่ยาวขึ้นได้อย่างมีประสิทธิภาพ
  • สรุป FlashAttention

    • FlashAttention จัดเรียงการคำนวณ Attention ใหม่ และใช้ tiling กับการคำนวณซ้ำเพื่อเพิ่มความเร็วอย่างมากและลดการใช้หน่วยความจำ
    • ผ่านการทำ tiling จะโหลดบล็อกอินพุต ทำ Attention กับบล็อกนั้น แล้วอัปเดตเอาต์พุต
    • ลดปริมาณการอ่าน/เขียนหน่วยความจำด้วยการไม่เขียนเมทริกซ์ Attention ระหว่างทางลงหน่วยความจำ
  • ฟีเจอร์ฮาร์ดแวร์ใหม่ของ Hopper GPU

    • WGMMA: ใช้ Tensor Cores แบบใหม่เพื่อให้ได้ throughput สูง
    • TMA: ยูนิตฮาร์ดแวร์ที่ช่วยเร่งการถ่ายโอนข้อมูลระหว่าง global memory และ shared memory
    • FP8 ความแม่นยำต่ำ: ใช้ FP8 เพื่อเพิ่ม throughput ของ Tensor Core เป็นสองเท่า
  • ความไม่ซิงโครนัส: การซ้อนทับ GEMM และ Softmax

    • ความจำเป็นของการซ้อนทับ: ทำ GEMM และ softmax แบบขนานเพื่อดึงประสิทธิภาพสูงสุด
    • การจัดตารางแบบ ping-pong: ใช้สอง warp group สลับกันทำ GEMM และ softmax เพื่อเพิ่มประสิทธิภาพ
    • การซ้อนทับภายใน warp group: ทำ GEMM และ softmax แบบขนานภายใน warp group เดียวกันเพื่อเพิ่ม throughput
  • ความแม่นยำต่ำ: ลดข้อผิดพลาดจากการควอนไทซ์ด้วยการประมวลผลแบบ incoherent

    • การประมวลผลแบบ incoherent: ใช้การแปลง Hadamard เพื่อลดข้อผิดพลาดจากการควอนไทซ์
    • ผลการทดลอง: การประมวลผลแบบ incoherent ช่วยลดข้อผิดพลาดจากการควอนไทซ์ได้ 2.6 เท่า
  • เบนช์มาร์กของ Attention

    • FP16: เร็วกว่า FlashAttention-2 ประมาณ 1.6-1.8 เท่า
    • FP8: ทำได้สูงสุดถึง 1.2 PFLOPS

สรุปของ GN⁺

  • FlashAttention-3 ใช้ฟีเจอร์ฮาร์ดแวร์ใหม่ของ GPU เพื่อยกระดับประสิทธิภาพของกลไก Attention อย่างมาก
  • สามารถจัดการบริบทยาวได้อย่างมีประสิทธิภาพ จึงดึงศักยภาพของโมเดลภาษาขนาดใหญ่ได้สูงสุด
  • มีโอกาสสูงที่จะถูกรวมเข้ากับเฟรมเวิร์กหลักอย่าง PyTorch และจะส่งผลอย่างมากต่อการวิจัยและการใช้งาน AI ในอนาคต
  • โครงการที่มีฟังก์ชันคล้ายกัน ได้แก่ Triton และ cuDNN

1 ความคิดเห็น

 
GN⁺ 2024-07-12
ความคิดเห็นจาก Hacker News
  • ดูเหมือนว่า Tri Dao จะเริ่มทำงานกับ FA3 ตั้งแต่เดือนเมษายน 2022

    • เหตุผลที่โค้ดเพิ่งถูกเปิดเผยหลังจากการประกาศ Hopper/H100 ไปแล้ว 2 ปี อาจเป็นเพราะมีโซลูชันที่ดีกว่าพร้อมแล้ว
    • งานวิจัยช่วงหลังของ Tri มุ่งเน้นไปที่สถาปัตยกรรมสไตล์ SSM และ Mamba
    • Flash Attention มีความซับซ้อนเชิงเวลาระดับกำลังสองตามความยาวลำดับ แต่ขั้นตอนวิธีสมัยใหม่มีความซับซ้อนระดับกำลังสองหรือต่ำกว่า
    • Dao และ Gu ได้เผยแพร่บทความในปีนี้ที่ทำให้ Mamba/SSM สามารถรับการเร่งความเร็วด้วยฮาร์ดแวร์แบบเดียวกับ Transformer ได้
  • สงสัยว่าอัลกอริทึม Flash Attention พึ่งพาฮาร์ดแวร์มากแค่ไหน

    • มีการกล่าวถึงว่ามันใช้ประโยชน์จากความสามารถแบบ asynchronous ของ GPU H100
    • ไลบรารี Flash Attention ต้องใช้ CUDA แต่ดูเหมือนว่าจะมีการพอร์ตไปยัง Metal แล้ว
    • หากอัลกอริทึมเป็น pure function ก็น่าจะจินตนาการได้ว่าสามารถนำไปใช้งานบน GPU/ML framework ใดก็ได้
  • สงสัยว่าคอมไพเลอร์จะสามารถค้นหาการปรับแต่งประสิทธิภาพแบบ FlashAttention ได้ด้วยตัวเองหรือไม่

    • TVM และ tinygrad กำลังทำงานไปในทิศทางนั้น แต่ยังตั้งคำถามถึงความเป็นไปได้จริง
  • ผู้ที่ต้องการพอร์ตไปยัง ROCm/AMD MI300x ให้ติดต่อมา

    • ยินดีบริจาคเวลาในการประมวลผลให้
  • TMA (Tensor Memory Accelerator) เป็นหน่วยฮาร์ดแวร์ที่ช่วยเร่งการถ่ายโอนข้อมูลระหว่าง global memory กับ shared memory

    • ช่วยปลดปล่อยรีจิสเตอร์ ทำให้เพิ่มขนาดไทล์และประสิทธิภาพได้
  • FlashAttention-3 ถูกปรับให้เหมาะกับ GPU Hopper (เช่น H100)

    • สงสัยว่ามันทำงานอย่างไรบน GPU สำหรับผู้บริโภค (เช่น 3090, 4090)
  • มีการกล่าวว่าฟังก์ชันกระตุ้นอย่าง sigmoid ช้ามากใน LLM สมัยใหม่

    • มีการใช้ฟังก์ชันกระตุ้นอย่าง SiLU, Swish และ SOLU กันมาก
    • หาก Relu ทำให้ประสิทธิภาพลดลงน้อยกว่า การกลับไปใช้ Relu ก็อาจดีกว่า
  • สงสัยว่าทำไม Flash Attention ถึงช้ากว่าเดิม 5 เท่าเมื่อมี variable masking เทียบกับกรณีที่ไม่มี

    • การรองรับ masking ที่ดีไม่พอทำให้การปรับแต่งประสิทธิภาพแทบไร้ผล
  • สงสัยว่า FlashAttention สามารถมาแทนที่การคำนวณ attention ของ LLM ได้หรือไม่

    • สงสัยว่า LLM จำเป็นต้องถูกฝึกมาโดยเฉพาะเพื่อให้ใช้ FA ได้หรือไม่
    • สงสัยว่า FA เกี่ยวข้องอย่างไรกับกลยุทธ์อย่าง GQA (grouped query attention) หรือ sliding window attention
    • เมื่อ llama.cpp เพิ่มการรองรับ Flash Attention ก็สงสัยว่ามันเพียงแค่ใช้ CUDA kernel ที่ Flash Attention มีให้หรือไม่
    • เข้าใจได้ยากว่าการเปรียบเทียบ FlashAttention กับ Triton หมายถึงอะไร
  • ต้องใช้ฮาร์ดแวร์ราคาแพง