7 คะแนน โดย GN⁺ 2025-02-07 | 1 ความคิดเห็น | แชร์ทาง WhatsApp
  • แม้การเพิ่มประสิทธิภาพของดีปเลิร์นนิงในสเกลขนาดใหญ่จะดูเหมือน “เล่นแร่แปรธาตุ” แต่ในความเป็นจริงสามารถเพิ่มประสิทธิภาพของโมเดลได้ด้วยหลักการง่าย ๆ ที่เข้าใจได้
  • ตั้งแต่ตัวเร่งความเร็วเพียงตัวเดียวไปจนถึงตัวเร่งความเร็วนับหมื่น หลักการที่ค่อนข้างเรียบง่ายแบบเดียวกันนี้ใช้ได้ทุกที่ และเมื่อเข้าใจแล้วก็จะช่วยให้ทำสิ่งที่มีประโยชน์ต่อไปนี้ได้:
    • ประเมินคร่าว ๆ ได้ว่าส่วนต่าง ๆ ของโมเดลเข้าใกล้ค่าที่เหมาะสมตามทฤษฎีมากน้อยเพียงใด
    • มีหลักเกณฑ์สำหรับเลือกเทคนิคการทำงานขนานหลายแบบในสเกลที่แตกต่างกัน
    • ประมาณต้นทุนและเวลาที่ต้องใช้ในการฝึกและรันโมเดล Transformer ขนาดใหญ่
    • ออกแบบอัลกอริทึมที่ใช้ประโยชน์จากลักษณะเฉพาะของฮาร์ดแวร์
    • ออกแบบฮาร์ดแวร์โดยเข้าใจขีดจำกัดของประสิทธิภาพอัลกอริทึมในปัจจุบันอย่างชัดเจน
  • ความรู้พื้นฐานที่จำเป็น
    • ควรเข้าใจแนวคิดพื้นฐานเกี่ยวกับ LLM และสถาปัตยกรรม Transformer
    • ไม่จำเป็นต้องเข้าใจการดำเนินงานในสเกลใหญ่ก็ได้
    • หากมีความรู้พื้นฐานเรื่องการฝึก LLM และมีประสบการณ์ใช้ JAX จะยิ่งดี
    • แนะนำให้อ้างอิงบล็อกโพสต์เกี่ยวกับสถาปัตยกรรม Transformer และสไลด์เกี่ยวกับการสเกล LLM ของ JAX
  • เป้าหมาย
    • พัฒนาความสามารถในการประเมินว่าโมเดลควรถูกทำให้ขนานบนฮาร์ดแวร์ที่มีอยู่อย่างไรจึงจะเหมาะสม
    • พัฒนาความสามารถในการคำนวณคร่าว ๆ ถึงเวลาและต้นทุนที่ใช้ในการฝึกและการอนุมาน

ทำไมจึงควรสนใจ

  • เมื่อ 3–4 ปีก่อน นักวิจัย ML ส่วนใหญ่ยังไม่จำเป็นต้องรู้เรื่องการเพิ่มประสิทธิภาพสเกลขนาดใหญ่เช่นนี้อย่างลึกซึ้ง
    • แต่ปัจจุบัน แม้แต่โมเดลที่ “เล็ก” ก็ทำงานใกล้ขีดจำกัดของฮาร์ดแวร์แล้ว ทำให้การเข้าใจวิธีทำงานขนาดใหญ่อย่างมีประสิทธิภาพกลายเป็นเรื่องจำเป็น
    • ประวัติศาสตร์ของ ML อาจมองได้ว่าเป็นกระแสที่นวัตกรรมด้านระบบและการปรับปรุงซอฟต์แวร์พัฒนาไปพร้อมกัน
    • เมื่อไม่นานมานี้ โมเดล Transformer ใช้ฮาร์ดแวร์ได้ถึงขีดจำกัด ทำให้หากไม่เข้าใจประสิทธิภาพของโมเดล สถาปัตยกรรมใหม่หรืองานวิจัยใหม่ก็มีโอกาสสูงที่จะล้มเหลวเมื่อนำไปใช้จริง
    • แม้จะได้ประสิทธิภาพดีขึ้น 20% บน benchmark แต่ถ้าประสิทธิภาพของฮาร์ดแวร์ลดลง 20% สุดท้ายแล้วความใช้งานจริงก็ยังต่ำอยู่ดี
  • เป้าหมายสำคัญของการสเกลโมเดลคือทำให้ throughput เพิ่มขึ้นแบบเชิงเส้นเมื่อเพิ่มจำนวนชิป (ตัวเร่งความเร็ว)
    • สิ่งนี้เรียกว่า "strong scaling"
    • การเพิ่มชิปช่วยลดเวลาในการคำนวณ แต่ก็มีต้นทุนในการสื่อสารระหว่างชิป
    • หากการสื่อสารใช้เวลานานกว่าการคำนวณ จะเข้าสู่สถานะ "communication bound" และไม่สามารถทำ strong scaling ได้
    • หากเข้าใจฮาร์ดแวร์ได้ดีพอจนคาดการณ์ได้ว่าคอขวดเหล่านี้จะเกิดขึ้นตรงไหน ก็สามารถออกแบบหรือปรับโครงสร้างโมเดลเพื่อหลีกเลี่ยงได้
  • เป้าหมายของหนังสือเล่มนี้คือ อธิบายว่าฮาร์ดแวร์ TPU (รวมถึง GPU) ทำงานอย่างไร และสถาปัตยกรรม Transformer พัฒนามาอย่างไรจนทำงานได้ดีกับฮาร์ดแวร์ในปัจจุบัน
    • ผู้เขียนหวังว่าจะเป็นประโยชน์ทั้งต่อนักวิจัยที่ออกแบบสถาปัตยกรรมใหม่ และวิศวกรที่พยายามทำให้ LLM รุ่นปัจจุบันทำงานได้รวดเร็ว

ภาพรวมทั้งหมด

  • บทความนี้ประกอบด้วยส่วนต่าง ๆ ดังนี้
  • ส่วนที่ 1 อธิบายปัจจัยที่กำหนดขีดจำกัดของประสิทธิภาพโมเดล (การสื่อสาร การคำนวณ หน่วยความจำ) ผ่านการวิเคราะห์ roofline
  • ส่วนที่ 2, ส่วนที่ 3 กล่าวถึงโครงสร้างภายในของ TPU และ GPU รวมถึงวิธีเชื่อมต่อระหว่างชิป
    • ซึ่งช่วยตอบคำถามต่อไปนี้
      • ตามทฤษฎีแล้ว การคูณเมทริกซ์ขนาดหนึ่ง ๆ สามารถทำได้เร็วแค่ไหน
      • ณ จุดใดการคำนวณจะถูกจำกัดด้วยแบนด์วิดท์หน่วยความจำหรือแบนด์วิดท์การสื่อสาร
      • คลัสเตอร์ TPU เชื่อมต่อกันด้วยโครงสร้างแบบใด และโดยคร่าว ๆ ต้องใช้เวลานานเท่าไรในการย้ายข้อมูลจากชิปหนึ่งไปยังอีกชิปหนึ่ง
      • จะคูณเมทริกซ์แบบกระจายอย่างมีประสิทธิภาพได้อย่างไร
  • ส่วนที่ 4 ลงรายละเอียดเกี่ยวกับสูตรของสถาปัตยกรรม Transformer (ขนาดเมทริกซ์ จำนวนพารามิเตอร์ FLOPs)
  • ส่วนที่ 5 และ ส่วนที่ 7 คือแกนหลัก โดยแนะนำวิธีต่าง ๆ ในการทำโมเดลให้ขนานบนหลายชิป
    • Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
    • รวมถึงเทคนิคประหยัดหน่วยความจำ เช่น ZeRO, Rematerialisation, Host offload, Gradient accumulation
  • ส่วนที่ 6, ส่วนที่ 8 ใช้กรณีตัวอย่างการฝึกและการอนุมานโมเดล LLaMA-3 บน TPU เพื่อแสดงต้นทุน เวลา และรูปแบบการจัดวางจริง
  • สุดท้าย ส่วนที่ 9, ส่วนที่ 10 กล่าวถึงวิธีโปรไฟล์โมเดล ดีบัก และใช้การประมวลผลแบบขนานใน JAX ในทางปฏิบัติ

รายละเอียดเพิ่มเติม: สรุปส่วนสำคัญของหนังสือ

1 ความคิดเห็น

 
GN⁺ 2025-02-07
ความคิดเห็นบน Hacker News
  • มีความคาดหวังว่า JAX จะมาแทนที่ pytorch/cuda ในอีกไม่กี่ปีข้างหน้า ปัญหา PTX กับทีม Deepseek แสดงให้เห็นถึงคุณค่าของการลงทุนกับแนวทางที่ลงไปในระดับล่างกว่าเพื่อดึงประสิทธิภาพฮาร์ดแวร์ออกมาให้ได้สูงสุด
    • เคยถูกใช้ภายใน Google เป็นคู่มือสำหรับงานด้านประสิทธิภาพ น่าแปลกใจที่ถูกเผยแพร่ออกสู่สาธารณะ แต่ดูเหมือนว่ารายละเอียดเกี่ยวกับ Gemini จะถูกตัดออกไป
    • จุดที่ดีของไกด์นี้คือสามารถย้ายไปใช้ GPU ได้โดยตรงด้วย JAX/XLA
    • มีความเห็นที่สงสัยว่าทำไม JAX ถึงใช้การ tracing แทน AST
    • มีการแชร์ลิงก์ไปยังเธรดทวีตของผู้เขียน
    • มีคนกำลังหาวิธีแปลงเว็บไซต์ Jekyll เป็น PDF
    • มีคำชื่นชมว่าเป็นบทความที่ยอดเยี่ยมและแสดงความขอบคุณ
    • มีความเห็นที่สงสัยว่าทำแอนิเมชันสวย ๆ แบบนี้ได้อย่างไร