Felafax BlogTune Llama3 405B on AMD MI300x (เส้นทางของเรา)
บทนำ
- เมื่อโมเดลโอเพนซอร์สมีขนาดใหญ่ขึ้น ความต้องการโครงสร้างพื้นฐานที่ทรงพลังสำหรับรองรับการฝึก AI ขนาดใหญ่ก็เพิ่มขึ้น
- Felafax ได้ปรับแต่งโมเดล LLaMA 3.1 405B แบบละเอียดบน AMD GPU เพื่อพิสูจน์ประสิทธิภาพของฮาร์ดแวร์ AMD
- งานทั้งหมดถูกเผยแพร่เป็นโอเพนซอร์สบน GitHub
- AMD MI300X GPU ให้ประสิทธิภาพสูงเมื่อเทียบกับฮาร์ดแวร์ AI ของ NVIDIA
- โปรเจกต์นี้เกิดขึ้นได้ด้วยการสนับสนุนจาก TensorWave
JAX คืออะไร และทำไมจึงเลือกใช้
- JAX เป็นไลบรารีแมชชีนเลิร์นนิงทรงพลังที่ผสาน API คล้าย NumPy, automatic differentiation และคอมไพเลอร์ XLA ของ Google
- มี API ที่ยอดเยี่ยมสำหรับการประมวลผลแบบ model parallel จึงเหมาะอย่างยิ่งกับการฝึกโมเดลขนาดใหญ่
ข้อดีของ JAX
- pure functions: JAX สนับสนุนให้เขียน pure functions ทำให้โค้ดจัดองค์ประกอบได้ง่าย ดีบักง่าย และอ่านง่าย
- การประมวลผลแบบขนานขั้นสูง: JIT API ที่ยืดหยุ่นของ JAX รองรับ data parallel และ model parallel ขั้นสูง ซึ่งจำเป็นต่อการฝึกขนาดใหญ่
- codebase ที่สะอาด: ปรัชญาการออกแบบของ JAX ส่งเสริมการเขียนโค้ดที่พกพาข้ามแพลตฟอร์มฮาร์ดแวร์ได้
เหตุใด JAX จึงโดดเด่นบนฮาร์ดแวร์ที่ไม่ใช่ NVIDIA
- แนวทางที่เป็นอิสระจากฮาร์ดแวร์: JAX ใช้คอมไพเลอร์ XLA เพื่อคอมไพล์การคำนวณเป็นตัวแทนกลางที่ไม่ผูกกับฮาร์ดแวร์
- การเพิ่มประสิทธิภาพที่เป็นอิสระจากแพลตฟอร์ม: คอมไพเลอร์ XLA ดำเนินการเพิ่มประสิทธิภาพโดยไม่ขึ้นกับฮาร์ดแวร์
- การย้ายระบบที่ง่าย: เมื่อใช้ JAX การเปลี่ยนจาก NVIDIA ไปเป็น AMD ต้องแก้โค้ดเพียงเล็กน้อย
การตั้งค่า JAX บน AMD GPU
- ดึง Docker image มาใช้งาน เริ่มต้นคอนเทนเนอร์ แล้วตรวจสอบการติดตั้ง
- ใช้ AMD MI300x GPU จำนวน 8 ตัวเพื่อฝึกโมเดล LLaMA 405B
การฝึก LLaMA 405B: ประสิทธิภาพและการขยายขนาด
- ใช้ JAX เพื่อฝึกโมเดล LLaMA 405B บน AMD GPU
- ปรับแต่งแบบ LoRA โดยตั้งค่า both model weights และพารามิเตอร์ LoRA ด้วยความละเอียดแบบ bfloat16
- ขนาดโมเดล: ใช้ VRAM ราว 800GB
- น้ำหนัก LoRA และสถานะออปติไมเซอร์: ใช้ VRAM ราว 400GB
- การใช้ VRAM รวม: ราว 1200GB
- ความเร็วในการฝึก: ประมาณ 35 โทเคนต่อวินาที
- ประสิทธิภาพด้านหน่วยความจำ: รักษาไว้ได้ราว 70%
- การขยายขนาด: เมื่อใช้ JAX สามารถขยายบน GPU 8 ตัวได้เกือบเป็นเชิงเส้น
การตั้งค่าการฝึกของเรา
- แปลง LLaMA 3.1 จาก PyTorch ไปเป็น JAX
- กระจายโมเดลอย่างมีประสิทธิภาพผ่านการโหลดโมเดลและการ sharding ของพารามิเตอร์
การ sharding พารามิเตอร์ใน JAX
- ใช้ความสามารถ device mesh ของ JAX เพื่อกระจายโมเดลไปยัง AMD GPU 8 ตัวอย่างมีประสิทธิภาพ
- กำหนดกฎการ sharding ของพารามิเตอร์เพื่อ shard มิติของแต่ละเทนเซอร์ตามแกนของ mesh
การใช้งานการฝึก LoRA
- LoRA ลดจำนวนพารามิเตอร์ที่ต้องฝึกด้วยการแยกการอัปเดตน้ำหนักออกเป็นเมทริกซ์ low-rank
- มีการสร้างเลเยอร์ LoRADense เพื่อรวมพารามิเตอร์ LoRA เข้าไป
- กระจายพารามิเตอร์ LoRA อย่างมีประสิทธิภาพเพื่อเพิ่มประสิทธิภาพการใช้หน่วยความจำและการคำนวณ
บทสรุป
- ประสบการณ์ในการปรับแต่งโมเดล LLaMA 3.1 405B แบบละเอียดด้วย AMD GPU และ JAX เป็นไปในทางบวกอย่างมาก
- ใช้ประโยชน์จากความสามารถด้านการประมวลผลแบบขนานอันทรงพลังของ JAX และแนวทางที่เป็นอิสระจากฮาร์ดแวร์เพื่อกระจายโมเดลอย่างมีประสิทธิภาพ
- พิสูจน์ให้เห็นว่า AMD GPU เป็นทางเลือกที่ทรงพลังสำหรับการฝึก AI ขนาดใหญ่
- สามารถดูโค้ดทั้งหมดและรันได้ด้วยตนเองจาก GitHub repository
สรุปโดย GN⁺
- บทความนี้อธิบายวิธีฝึกโมเดล AI ขนาดใหญ่ให้มีประสิทธิภาพด้วย AMD GPU และ JAX
- เน้นว่า AMD hardware เป็นทางเลือกที่คุ้มค่ากว่าเมื่อเทียบกับ NVIDIA
- แนวทางที่เป็นอิสระจากฮาร์ดแวร์ของ JAX ช่วยเพิ่มความสามารถในการพกพาโค้ดและทำให้ดูแลรักษาได้ง่ายขึ้น
- ให้ข้อมูลที่มีประโยชน์และโค้ดสำหรับลงมือทำแก่ผู้ที่สนใจการฝึกโมเดลขนาดใหญ่
- โปรเจกต์ที่มีความสามารถคล้ายกัน ได้แก่ CUDA และ PyTorch ของ NVIDIA
1 ความคิดเห็น
ความคิดเห็นบน Hacker News
แชร์ผลงานการทำ fine-tune โมเดล Llama3.1 405B บน GPU 8xAMD MI300x โดยใช้ JAX
มีข้อเสนอให้สำรวจวิธีเอาชนะข้อจำกัดด้านหน่วยความจำและรันเวอร์ชันที่คอมไพล์ด้วย JIT
แชร์ประสบการณ์เกี่ยวกับ AMD GPU และการรองรับ ROCm
แชร์ประสบการณ์ทดลองในด้าน inference ของโมเดล 405B
torch.cudaก็ไม่ได้แย่ขนาดนั้นrocm:pytorchก็ง่ายพอ ๆ กับการใช้คอนเทนเนอร์rocm:jaxตั้งคำถามถึงการไม่มีข้อมูลประสิทธิภาพ
สงสัยว่าทำไม Obsidian (แอปจดโน้ต) ถึงมาทำเรื่องนี้
ขอให้ @dang ใส่ชื่อผู้ใช้ไว้ใน URL