วิธีสเกลโมเดลของคุณ: มุมมองเชิงระบบต่อ LLM บน TPU
(jax-ml.github.io)- แม้การเพิ่มประสิทธิภาพของดีปเลิร์นนิงในสเกลขนาดใหญ่จะดูเหมือน “เล่นแร่แปรธาตุ” แต่ในความเป็นจริงสามารถเพิ่มประสิทธิภาพของโมเดลได้ด้วยหลักการง่าย ๆ ที่เข้าใจได้
- ตั้งแต่ตัวเร่งความเร็วเพียงตัวเดียวไปจนถึงตัวเร่งความเร็วนับหมื่น หลักการที่ค่อนข้างเรียบง่ายแบบเดียวกันนี้ใช้ได้ทุกที่ และเมื่อเข้าใจแล้วก็จะช่วยให้ทำสิ่งที่มีประโยชน์ต่อไปนี้ได้:
- ประเมินคร่าว ๆ ได้ว่าส่วนต่าง ๆ ของโมเดลเข้าใกล้ค่าที่เหมาะสมตามทฤษฎีมากน้อยเพียงใด
- มีหลักเกณฑ์สำหรับเลือกเทคนิคการทำงานขนานหลายแบบในสเกลที่แตกต่างกัน
- ประมาณต้นทุนและเวลาที่ต้องใช้ในการฝึกและรันโมเดล Transformer ขนาดใหญ่
- ออกแบบอัลกอริทึมที่ใช้ประโยชน์จากลักษณะเฉพาะของฮาร์ดแวร์
- ออกแบบฮาร์ดแวร์โดยเข้าใจขีดจำกัดของประสิทธิภาพอัลกอริทึมในปัจจุบันอย่างชัดเจน
- ความรู้พื้นฐานที่จำเป็น
- ควรเข้าใจแนวคิดพื้นฐานเกี่ยวกับ LLM และสถาปัตยกรรม Transformer
- ไม่จำเป็นต้องเข้าใจการดำเนินงานในสเกลใหญ่ก็ได้
- หากมีความรู้พื้นฐานเรื่องการฝึก LLM และมีประสบการณ์ใช้ JAX จะยิ่งดี
- แนะนำให้อ้างอิงบล็อกโพสต์เกี่ยวกับสถาปัตยกรรม Transformer และสไลด์เกี่ยวกับการสเกล LLM ของ JAX
- เป้าหมาย
- พัฒนาความสามารถในการประเมินว่าโมเดลควรถูกทำให้ขนานบนฮาร์ดแวร์ที่มีอยู่อย่างไรจึงจะเหมาะสม
- พัฒนาความสามารถในการคำนวณคร่าว ๆ ถึงเวลาและต้นทุนที่ใช้ในการฝึกและการอนุมาน
ทำไมจึงควรสนใจ
- เมื่อ 3–4 ปีก่อน นักวิจัย ML ส่วนใหญ่ยังไม่จำเป็นต้องรู้เรื่องการเพิ่มประสิทธิภาพสเกลขนาดใหญ่เช่นนี้อย่างลึกซึ้ง
- แต่ปัจจุบัน แม้แต่โมเดลที่ “เล็ก” ก็ทำงานใกล้ขีดจำกัดของฮาร์ดแวร์แล้ว ทำให้การเข้าใจวิธีทำงานขนาดใหญ่อย่างมีประสิทธิภาพกลายเป็นเรื่องจำเป็น
- ประวัติศาสตร์ของ ML อาจมองได้ว่าเป็นกระแสที่นวัตกรรมด้านระบบและการปรับปรุงซอฟต์แวร์พัฒนาไปพร้อมกัน
- เมื่อไม่นานมานี้ โมเดล Transformer ใช้ฮาร์ดแวร์ได้ถึงขีดจำกัด ทำให้หากไม่เข้าใจประสิทธิภาพของโมเดล สถาปัตยกรรมใหม่หรืองานวิจัยใหม่ก็มีโอกาสสูงที่จะล้มเหลวเมื่อนำไปใช้จริง
- แม้จะได้ประสิทธิภาพดีขึ้น 20% บน benchmark แต่ถ้าประสิทธิภาพของฮาร์ดแวร์ลดลง 20% สุดท้ายแล้วความใช้งานจริงก็ยังต่ำอยู่ดี
- เป้าหมายสำคัญของการสเกลโมเดลคือทำให้ throughput เพิ่มขึ้นแบบเชิงเส้นเมื่อเพิ่มจำนวนชิป (ตัวเร่งความเร็ว)
- สิ่งนี้เรียกว่า "strong scaling"
- การเพิ่มชิปช่วยลดเวลาในการคำนวณ แต่ก็มีต้นทุนในการสื่อสารระหว่างชิป
- หากการสื่อสารใช้เวลานานกว่าการคำนวณ จะเข้าสู่สถานะ "communication bound" และไม่สามารถทำ strong scaling ได้
- หากเข้าใจฮาร์ดแวร์ได้ดีพอจนคาดการณ์ได้ว่าคอขวดเหล่านี้จะเกิดขึ้นตรงไหน ก็สามารถออกแบบหรือปรับโครงสร้างโมเดลเพื่อหลีกเลี่ยงได้
- เป้าหมายของหนังสือเล่มนี้คือ อธิบายว่าฮาร์ดแวร์ TPU (รวมถึง GPU) ทำงานอย่างไร และสถาปัตยกรรม Transformer พัฒนามาอย่างไรจนทำงานได้ดีกับฮาร์ดแวร์ในปัจจุบัน
- ผู้เขียนหวังว่าจะเป็นประโยชน์ทั้งต่อนักวิจัยที่ออกแบบสถาปัตยกรรมใหม่ และวิศวกรที่พยายามทำให้ LLM รุ่นปัจจุบันทำงานได้รวดเร็ว
ภาพรวมทั้งหมด
- บทความนี้ประกอบด้วยส่วนต่าง ๆ ดังนี้
- ส่วนที่ 1 อธิบายปัจจัยที่กำหนดขีดจำกัดของประสิทธิภาพโมเดล (การสื่อสาร การคำนวณ หน่วยความจำ) ผ่านการวิเคราะห์ roofline
- ส่วนที่ 2, ส่วนที่ 3 กล่าวถึงโครงสร้างภายในของ TPU และ GPU รวมถึงวิธีเชื่อมต่อระหว่างชิป
- ซึ่งช่วยตอบคำถามต่อไปนี้
- ตามทฤษฎีแล้ว การคูณเมทริกซ์ขนาดหนึ่ง ๆ สามารถทำได้เร็วแค่ไหน
- ณ จุดใดการคำนวณจะถูกจำกัดด้วยแบนด์วิดท์หน่วยความจำหรือแบนด์วิดท์การสื่อสาร
- คลัสเตอร์ TPU เชื่อมต่อกันด้วยโครงสร้างแบบใด และโดยคร่าว ๆ ต้องใช้เวลานานเท่าไรในการย้ายข้อมูลจากชิปหนึ่งไปยังอีกชิปหนึ่ง
- จะคูณเมทริกซ์แบบกระจายอย่างมีประสิทธิภาพได้อย่างไร
- ซึ่งช่วยตอบคำถามต่อไปนี้
- ส่วนที่ 4 ลงรายละเอียดเกี่ยวกับสูตรของสถาปัตยกรรม Transformer (ขนาดเมทริกซ์ จำนวนพารามิเตอร์ FLOPs)
- ส่วนที่ 5 และ ส่วนที่ 7 คือแกนหลัก โดยแนะนำวิธีต่าง ๆ ในการทำโมเดลให้ขนานบนหลายชิป
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- รวมถึงเทคนิคประหยัดหน่วยความจำ เช่น ZeRO, Rematerialisation, Host offload, Gradient accumulation
- ส่วนที่ 6, ส่วนที่ 8 ใช้กรณีตัวอย่างการฝึกและการอนุมานโมเดล LLaMA-3 บน TPU เพื่อแสดงต้นทุน เวลา และรูปแบบการจัดวางจริง
- สุดท้าย ส่วนที่ 9, ส่วนที่ 10 กล่าวถึงวิธีโปรไฟล์โมเดล ดีบัก และใช้การประมวลผลแบบขนานใน JAX ในทางปฏิบัติ
รายละเอียดเพิ่มเติม: สรุปส่วนสำคัญของหนังสือ
-
พาร์ต 1: Preliminaries
-
ส่วนที่ 1: บทนำสู่การวิเคราะห์ Roofline แบบง่าย
- ปัจจัย 3 ประการที่จำกัดอัลกอริทึม: การคำนวณ การสื่อสาร และหน่วยความจำ
- จากนั้นเรียนรู้วิธีประเมินขีดบนของความเร็วในการคำนวณ
-
- TPU ทำการคำนวณอย่างไร
- โครงสร้าง systolic array คืออะไร
- ทำความเข้าใจพื้นฐานว่า TPU ให้แบนด์วิดท์หน่วยความจำและการสื่อสารอย่างไร
-
ส่วนที่ 3: เมทริกซ์แบบกระจายและการคูณแบบกระจาย
- เทคนิคการเก็บพารามิเตอร์ของโมเดลโดยแบ่งกระจายไว้บนหลายชิป (Sharding)
- วิธีจัดการการสื่อสารและคอขวดที่เกิดขึ้นระหว่างการคำนวณเมทริกซ์แบบกระจาย
-
-
พาร์ต 2: Transformers
-
ส่วนที่ 4: รวมสูตร Transformer ที่จำเป็น
- การคูณเมทริกซ์ใน Transformer มีรูปแบบเฉพาะอย่างไร
- วิธีคำนวณจำนวนพารามิเตอร์ FLOPs ขนาดของ KV cache เป็นต้น
- ทำความเข้าใจว่า Attention ต้องใช้การคำนวณมากเพียงใดเมื่อเทียบกับบล็อก Feed-Forward
-
ส่วนที่ 5: กลยุทธ์การทำขนานสำหรับการฝึก Transformer
- แนะนำเทคนิค Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- แนวทางลดการใช้หน่วยความจำ เช่น ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload
- วางกรอบแนวคิดในการจัดการทำขนานให้เหมาะกับขนาดโมเดลและจำนวนชิปที่กำหนด
-
ส่วนที่ 6: การประยุกต์ฝึก LLaMA 3 บน TPU
- ประเมินเวลาและต้นทุนเมื่อสมมติว่าฝึกโมเดล LLaMA 3 บนสภาพแวดล้อม TPU จริง
- นำเสนอตัวอย่างที่เป็นรูปธรรมเกี่ยวกับ batch size วิธีทำขนาน การใช้หน่วยความจำ เป็นต้น
-
ส่วนที่ 7: ทุกอย่างเกี่ยวกับการอนุมานของ Transformer
- ในการอนุมาน ปัจจัยใหม่ที่สำคัญคือ latency
- ปัญหาการใช้หน่วยความจำและการสื่อสารจาก KV cache เป็นต้น
- การอภิปรายว่าจะจัดสรรและเชื่อมต่อหลายชิปอย่างไรเพื่อให้บริการโมเดล
-
ส่วนที่ 8: การประยุกต์เสิร์ฟ LLaMA 3 บน TPU
- วิเคราะห์ trade-off โดยคร่าว ๆ ระหว่างต้นทุน latency และ throughput เมื่อสมมติว่าเสิร์ฟ LLaMA 3 บน TPU v5e
-
-
พาร์ต 3: Practical Tutorials
-
ส่วนที่ 9: วิธีโปรไฟล์โค้ด TPU
- ทำความเข้าใจสแตก JAX+XLA
- ระบุปัญหาประสิทธิภาพตกในสถานการณ์จริงและแนวทางแก้ไข
- วิธีใช้ profiler ของ JAX/TensorBoard
-
ส่วนที่ 10: เขียนโปรแกรม TPU ด้วย JAX
- วิธีใช้ API สำหรับการทำขนานของ JAX (primitives)
- เรียนรู้แนวคิดการคำนวณแบบขนานผ่านตัวอย่างและแบบฝึกหัด
-
ส่วนที่ 11: บทสรุปและแหล่งข้อมูลเพิ่มเติม
- แหล่งอ่านเพิ่มเติมเกี่ยวกับ TPU และ LLM
- ปิดท้ายเนื้อหาโดยสรุปภาพรวมสั้น ๆ พร้อมกล่าวถึงแนวโน้มในอนาคต
-
1 ความคิดเห็น
ความคิดเห็นบน Hacker News