ปรับแต่ง Llama 405B แบบละเอียดด้วย AMD GPU
(publish.obsidian.md)Felafax BlogTune Llama3 405B on AMD MI300x (เส้นทางของเรา)
บทนำ
- เมื่อโมเดลโอเพนซอร์สมีขนาดใหญ่ขึ้น ความต้องการโครงสร้างพื้นฐานที่ทรงพลังสำหรับรองรับการฝึก AI ขนาดใหญ่ก็เพิ่มขึ้น
- Felafax ได้ปรับแต่งโมเดล LLaMA 3.1 405B แบบละเอียดบน AMD GPU เพื่อพิสูจน์ประสิทธิภาพของฮาร์ดแวร์ AMD
- งานทั้งหมดถูกเผยแพร่เป็นโอเพนซอร์สบน GitHub
- AMD MI300X GPU ให้ประสิทธิภาพสูงเมื่อเทียบกับฮาร์ดแวร์ AI ของ NVIDIA
- โปรเจกต์นี้เกิดขึ้นได้ด้วยการสนับสนุนจาก TensorWave
JAX คืออะไร และทำไมจึงเลือกใช้
- JAX เป็นไลบรารีแมชชีนเลิร์นนิงทรงพลังที่ผสาน API คล้าย NumPy, automatic differentiation และคอมไพเลอร์ XLA ของ Google
- มี API ที่ยอดเยี่ยมสำหรับการประมวลผลแบบ model parallel จึงเหมาะอย่างยิ่งกับการฝึกโมเดลขนาดใหญ่
ข้อดีของ JAX
- pure functions: JAX สนับสนุนให้เขียน pure functions ทำให้โค้ดจัดองค์ประกอบได้ง่าย ดีบักง่าย และอ่านง่าย
- การประมวลผลแบบขนานขั้นสูง: JIT API ที่ยืดหยุ่นของ JAX รองรับ data parallel และ model parallel ขั้นสูง ซึ่งจำเป็นต่อการฝึกขนาดใหญ่
- codebase ที่สะอาด: ปรัชญาการออกแบบของ JAX ส่งเสริมการเขียนโค้ดที่พกพาข้ามแพลตฟอร์มฮาร์ดแวร์ได้
เหตุใด JAX จึงโดดเด่นบนฮาร์ดแวร์ที่ไม่ใช่ NVIDIA
- แนวทางที่เป็นอิสระจากฮาร์ดแวร์: JAX ใช้คอมไพเลอร์ XLA เพื่อคอมไพล์การคำนวณเป็นตัวแทนกลางที่ไม่ผูกกับฮาร์ดแวร์
- การเพิ่มประสิทธิภาพที่เป็นอิสระจากแพลตฟอร์ม: คอมไพเลอร์ XLA ดำเนินการเพิ่มประสิทธิภาพโดยไม่ขึ้นกับฮาร์ดแวร์
- การย้ายระบบที่ง่าย: เมื่อใช้ JAX การเปลี่ยนจาก NVIDIA ไปเป็น AMD ต้องแก้โค้ดเพียงเล็กน้อย
การตั้งค่า JAX บน AMD GPU
- ดึง Docker image มาใช้งาน เริ่มต้นคอนเทนเนอร์ แล้วตรวจสอบการติดตั้ง
- ใช้ AMD MI300x GPU จำนวน 8 ตัวเพื่อฝึกโมเดล LLaMA 405B
การฝึก LLaMA 405B: ประสิทธิภาพและการขยายขนาด
- ใช้ JAX เพื่อฝึกโมเดล LLaMA 405B บน AMD GPU
- ปรับแต่งแบบ LoRA โดยตั้งค่า both model weights และพารามิเตอร์ LoRA ด้วยความละเอียดแบบ bfloat16
- ขนาดโมเดล: ใช้ VRAM ราว 800GB
- น้ำหนัก LoRA และสถานะออปติไมเซอร์: ใช้ VRAM ราว 400GB
- การใช้ VRAM รวม: ราว 1200GB
- ความเร็วในการฝึก: ประมาณ 35 โทเคนต่อวินาที
- ประสิทธิภาพด้านหน่วยความจำ: รักษาไว้ได้ราว 70%
- การขยายขนาด: เมื่อใช้ JAX สามารถขยายบน GPU 8 ตัวได้เกือบเป็นเชิงเส้น
การตั้งค่าการฝึกของเรา
- แปลง LLaMA 3.1 จาก PyTorch ไปเป็น JAX
- กระจายโมเดลอย่างมีประสิทธิภาพผ่านการโหลดโมเดลและการ sharding ของพารามิเตอร์
การ sharding พารามิเตอร์ใน JAX
- ใช้ความสามารถ device mesh ของ JAX เพื่อกระจายโมเดลไปยัง AMD GPU 8 ตัวอย่างมีประสิทธิภาพ
- กำหนดกฎการ sharding ของพารามิเตอร์เพื่อ shard มิติของแต่ละเทนเซอร์ตามแกนของ mesh
การใช้งานการฝึก LoRA
- LoRA ลดจำนวนพารามิเตอร์ที่ต้องฝึกด้วยการแยกการอัปเดตน้ำหนักออกเป็นเมทริกซ์ low-rank
- มีการสร้างเลเยอร์ LoRADense เพื่อรวมพารามิเตอร์ LoRA เข้าไป
- กระจายพารามิเตอร์ LoRA อย่างมีประสิทธิภาพเพื่อเพิ่มประสิทธิภาพการใช้หน่วยความจำและการคำนวณ
บทสรุป
- ประสบการณ์ในการปรับแต่งโมเดล LLaMA 3.1 405B แบบละเอียดด้วย AMD GPU และ JAX เป็นไปในทางบวกอย่างมาก
- ใช้ประโยชน์จากความสามารถด้านการประมวลผลแบบขนานอันทรงพลังของ JAX และแนวทางที่เป็นอิสระจากฮาร์ดแวร์เพื่อกระจายโมเดลอย่างมีประสิทธิภาพ
- พิสูจน์ให้เห็นว่า AMD GPU เป็นทางเลือกที่ทรงพลังสำหรับการฝึก AI ขนาดใหญ่
- สามารถดูโค้ดทั้งหมดและรันได้ด้วยตนเองจาก GitHub repository
สรุปโดย GN⁺
- บทความนี้อธิบายวิธีฝึกโมเดล AI ขนาดใหญ่ให้มีประสิทธิภาพด้วย AMD GPU และ JAX
- เน้นว่า AMD hardware เป็นทางเลือกที่คุ้มค่ากว่าเมื่อเทียบกับ NVIDIA
- แนวทางที่เป็นอิสระจากฮาร์ดแวร์ของ JAX ช่วยเพิ่มความสามารถในการพกพาโค้ดและทำให้ดูแลรักษาได้ง่ายขึ้น
- ให้ข้อมูลที่มีประโยชน์และโค้ดสำหรับลงมือทำแก่ผู้ที่สนใจการฝึกโมเดลขนาดใหญ่
- โปรเจกต์ที่มีความสามารถคล้ายกัน ได้แก่ CUDA และ PyTorch ของ NVIDIA
1 ความคิดเห็น
ความคิดเห็นจาก Hacker News
เมื่อไม่นานมานี้ได้ทำการ fine-tune โมเดล llama3.1 405B บน 8xAMD MI300x GPU โดยใช้ JAX แทน PyTorch
ด้วย API สำหรับ sharding ขั้นสูงของ JAX จึงได้ประสิทธิภาพที่ดี และได้สรุปเทคนิค sharding ที่ใช้ไว้ในบล็อกแล้ว นอกจากนี้ยังเปิดเผยโค้ดด้วย: https://github.com/felafax/felafax
เราเป็นสตาร์ตอัปเล็ก ๆ ที่สร้างโครงสร้างพื้นฐาน AI สำหรับการ fine-tune และ serving ของ LLM บนฮาร์ดแวร์ที่ไม่ใช่ NVIDIA เช่น TPU, AMD, Trainium
หลายบริษัทพยายามรัน PyTorch บน AMD GPU แต่เราเห็นว่า PyTorch ผูกกับ ecosystem ของ NVIDIA อย่างลึกซึ้ง ผ่านสิ่งอย่าง
torch.cudaหรือscaled_dot_product_attentionจึงต้องทำ “de-NVIDIA-fication” ค่อนข้างมากเราคิดว่า JAX เหมาะกับฮาร์ดแวร์ที่ไม่ใช่ NVIDIA มากกว่า เพราะโค้ดโมเดลจะถูกคอมไพล์เป็นกราฟ HLO ที่เป็นอิสระจากฮาร์ดแวร์ก่อน จากนั้นคอมไพเลอร์ XLA จะทำการ optimize แล้วค่อยใช้ optimization เฉพาะฮาร์ดแวร์ภายหลัง โดยโค้ด JAX ของ LLaMA3 เดียวกันสามารถทำงานได้ทั้งบน Google TPU และ AMD GPU โดยไม่ต้องแก้ไข
กลยุทธ์ของบริษัทคือพอร์ตโมเดลไปยัง JAX ก่อน แล้วใช้ประโยชน์จากเฟรมเวิร์ก JAX และเคอร์เนล XLA เพื่อดึงประสิทธิภาพสูงสุดจากแบ็กเอนด์ที่ไม่ใช่ NVIDIA ดังนั้นเราจึงย้าย Llama 3.1 จาก PyTorch ไปเป็น JAX ก่อน และโมเดล JAX ตัวเดียวกันก็ทำงานได้ดีทั้งบน TPU และ AMD GPU
โดยส่วนตัวเหตุผลหลักที่ผมใช้ PyTorch ก็เพราะโมเดลต้นฉบับถูกสร้างด้วย PyTorch แม้ลอจิกจะดูเหมือนกันในโมเดลคนละเวอร์ชัน แต่เมื่ออยู่ในสเกลข้อมูลมหาศาล ความคลาดเคลื่อนของ floating point เพียงเล็กน้อยก็อาจสะสมจนเกิด model drift ได้
การดีบักความไม่ตรงกันของความแม่นยำในโมเดลใหญ่ ๆ แบบนี้ แทบจะทรมานยิ่งกว่าวงนรกชั้นที่สิบเสียอีก
hipblaslt, Composable Kernel FA อะไรทำนองนั้นผมไม่ได้รู้จัก JAX ลึกมากนัก แต่คิดว่าส่วนสำคัญที่ทำให้ประสิทธิภาพการเทรนของ PyTorch บน MI300x ย่ำแย่ ก็คือประสิทธิภาพของไลบรารี ROCm ที่ใช้ภายในมันช้า
ที่ว่าทำงานได้ในที่นี้ไม่ได้หมายถึงใช้เวลา 2 สัปดาห์ไล่แก้ไดรเวอร์ แล้วจากนั้นก็ไม่กล้าอัปเดตเซิร์ฟเวอร์อีกเลย
และก็อยากรู้ด้วยว่ามีปัญหาทางเทคนิคอะไรที่เจอบ้าง
พูดกันตรง ๆ ประสิทธิภาพนี้ค่อนข้างแย่ น่าจะเป็นเพราะ คอมไพล์ ยังทำงานได้ไม่ถูกต้อง
สำหรับโมเดล 405B ได้ 35 โทเคนต่อวินาที ซึ่งคิดเป็นประมาณ 85 เทราฟลอปส์ ขณะที่ GPU MI300x จำนวน 8 ตัวอยู่ที่ระดับ 10.4 เพตาฟลอปส์ ดังนั้น MFU อยู่ราว 0.8%
นี่ต่ำกว่าประสิทธิภาพการเทรนที่พอใช้ได้ซึ่งอยู่ที่ 30~40% MFU ถึง 40~50 เท่า ดังนั้นในมุมของ AMD ก็คงหวังว่าคอขวดจะอยู่ที่ซอฟต์แวร์สแตก
ในหน้า GitHub เขียนว่า “สามารถปรับจูน LLaMa3.1 บน Google Cloud TPU ได้ด้วย ต้นทุนต่ำลง 30%” แต่ไม่ได้พูดถึงประสิทธิภาพเลย
งานยอดเยี่ยมมาก เมื่อประมาณปีก่อนผมลองจับ AMD GPU และ การรองรับ ROCm มานิดหน่อย แล้วเห็นได้ชัดว่า AMD ยังต้องไปอีกไกลกว่าจะไล่ทัน Nvidia
แนวทางที่เลือก JAX น่าสนใจ แต่ก็สงสัยว่าการออกห่างจาก PyTorch ซึ่งแทบเป็นไลบรารีมาตรฐานของวงการแมชชีนเลิร์นนิงนั้นมีความยากอะไรบ้าง
ตอนแรกเป้าหมายคือการ fine-tune LLaMA 3 บน TPU แต่ PyTorch XLA ยังเทอะทะอยู่ เราเลยตัดสินใจเขียนโมเดลใหม่ด้วย JAX
อย่างที่บอกไปก่อนหน้านี้ เรามองว่า JAX เป็นแพลตฟอร์มที่ดีกว่าสำหรับ GPU ที่ไม่ใช่ NVIDIA และอยากสร้างอินฟราสตรักเจอร์สำหรับ GPU ที่ไม่ใช่ NVIDIA บน JAX+openXLA
งานดีมาก ช่วงสุดสัปดาห์ที่แล้วผมก็กำลังลองฝั่ง inference ของ 405B อยู่เหมือนกัน [0]
ผมยังไม่แน่ใจว่า
torch.cudaแย่ขนาดนั้นจริงไหม เพราะ PyTorch สำหรับ AMD จะแปลงส่วนนั้นให้เองอยู่แล้ว ดูเป็นเรื่องชื่อเรียกมากกว่าจะเป็นปัญหาเชิงสาระจริง ๆ แล้วการดึงคอนเทนเนอร์
rocm:pytorchก็ง่ายพอ ๆ กับการดึงคอนเทนเนอร์rocm:jaxยังมีตัวเลขที่เผยแพร่ออกมาไม่มากนัก เลยอยากรู้ว่า MFU ได้เท่าไร
[0] https://x.com/HotAisle/status/1837580046732874026
MFU ต้องคำนวณอีกที รายละเอียดของ GPU และ VRAM ดูได้จากรีโพซิทอรี: https://dub.sh/amd-405b-res
สุดสัปดาห์หน้ามีแผนจะลองรันการเทรนอีกครั้งโดยทำ JIT compile ทั้งขั้นตอนการเทรน แล้วค่อยคำนวณ MFU ตอนนั้น
ตอนที่เราวัดกันที่ ZML พบว่า MI300X เร็วกว่า H100 อยู่ 30% เป็นชิปที่ยอดเยี่ยมมาก
สงสัยว่ามี ผู้ให้บริการคลาวด์ เจ้าไหนให้เช่าโฮสต์ 8xAMD MI300 บ้าง
ในงานผมใช้ AWS ค่อนข้างเยอะ แต่อยากลองใช้ AMD GPU ดูสักครั้ง
ข้อมูลประสิทธิภาพ อยู่ที่ไหน?
เนื่องจากข้อจำกัดของโค้ดและ VRAM จึงยังไม่สามารถรันเวอร์ชัน JIT compile ของโมเดล 405B ได้ เรื่องนี้ยังต้องตรวจสอบเพิ่มเติม
การรันเทรนทั้งหมดทำในโหมด eager execution ของ JAX จึงยังมีพื้นที่ให้ปรับปรุงประสิทธิภาพอีกมาก
แม้ในโหมด eager execution การใช้งาน GPU โดยรวมก็ยังอยู่ราว 30~40% ซึ่งถือว่าใช้ได้พอสมควร และคิดว่าถ้าใช้ JIT การใช้งาน GPU ก็น่าจะขึ้นไปถึง 50~60% ได้ไม่ยาก
ถ้าเป็นไปได้ น่าจะน่าสนใจถ้าลองหาวิธีเอาชนะข้อจำกัดด้านหน่วยความจำเพื่อรัน เวอร์ชัน JIT compile เพราะอาจนำไปสู่การปรับปรุงประสิทธิภาพเพิ่มเติมได้
เราจำเป็นต้องมีขั้นตอนการเทรนที่คอมไพล์ด้วย JIT, การโหลดข้อมูลและ sharding ที่ปรับจูนมากขึ้น, gradient accumulation และ activation checkpointing
เรายังคงสร้างต่อไป และตั้งใจว่าจะเขียนบล็อกอีกครั้งในเร็ว ๆ นี้หลังจากทำการปรับปรุงทั้งหมดเสร็จ
สงสัยว่า AMD เข้าใกล้การดึงมูลค่าจากเรื่องนี้ได้บ้างหรือยัง ทั้งในแง่คำสั่งซื้อ GPU จำนวนมากและภาวะขาดแคลนอุปทาน
ความรู้สึกของผมยังเอนไปทาง “ยังไม่ใช่”
อีกฝ่ายมีความได้เปรียบจากการออกตัวก่อนมหาศาล และฝั่งซอฟต์แวร์ก็ยังมีงานให้ทำอีกชัดเจน มันต้องใช้เวลา
ทำไมแอปจดโน้ตอย่าง Obsidian ถึงมาทำเรื่องนี้?