2 คะแนน โดย GN⁺ 2024-09-24 | 1 ความคิดเห็น | แชร์ทาง WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (เส้นทางของเรา)

บทนำ

  • เมื่อโมเดลโอเพนซอร์สมีขนาดใหญ่ขึ้น ความต้องการโครงสร้างพื้นฐานที่ทรงพลังสำหรับรองรับการฝึก AI ขนาดใหญ่ก็เพิ่มขึ้น
  • Felafax ได้ปรับแต่งโมเดล LLaMA 3.1 405B แบบละเอียดบน AMD GPU เพื่อพิสูจน์ประสิทธิภาพของฮาร์ดแวร์ AMD
  • งานทั้งหมดถูกเผยแพร่เป็นโอเพนซอร์สบน GitHub
  • AMD MI300X GPU ให้ประสิทธิภาพสูงเมื่อเทียบกับฮาร์ดแวร์ AI ของ NVIDIA
  • โปรเจกต์นี้เกิดขึ้นได้ด้วยการสนับสนุนจาก TensorWave

JAX คืออะไร และทำไมจึงเลือกใช้

  • JAX เป็นไลบรารีแมชชีนเลิร์นนิงทรงพลังที่ผสาน API คล้าย NumPy, automatic differentiation และคอมไพเลอร์ XLA ของ Google
  • มี API ที่ยอดเยี่ยมสำหรับการประมวลผลแบบ model parallel จึงเหมาะอย่างยิ่งกับการฝึกโมเดลขนาดใหญ่

ข้อดีของ JAX

  • pure functions: JAX สนับสนุนให้เขียน pure functions ทำให้โค้ดจัดองค์ประกอบได้ง่าย ดีบักง่าย และอ่านง่าย
  • การประมวลผลแบบขนานขั้นสูง: JIT API ที่ยืดหยุ่นของ JAX รองรับ data parallel และ model parallel ขั้นสูง ซึ่งจำเป็นต่อการฝึกขนาดใหญ่
  • codebase ที่สะอาด: ปรัชญาการออกแบบของ JAX ส่งเสริมการเขียนโค้ดที่พกพาข้ามแพลตฟอร์มฮาร์ดแวร์ได้

เหตุใด JAX จึงโดดเด่นบนฮาร์ดแวร์ที่ไม่ใช่ NVIDIA

  • แนวทางที่เป็นอิสระจากฮาร์ดแวร์: JAX ใช้คอมไพเลอร์ XLA เพื่อคอมไพล์การคำนวณเป็นตัวแทนกลางที่ไม่ผูกกับฮาร์ดแวร์
  • การเพิ่มประสิทธิภาพที่เป็นอิสระจากแพลตฟอร์ม: คอมไพเลอร์ XLA ดำเนินการเพิ่มประสิทธิภาพโดยไม่ขึ้นกับฮาร์ดแวร์
  • การย้ายระบบที่ง่าย: เมื่อใช้ JAX การเปลี่ยนจาก NVIDIA ไปเป็น AMD ต้องแก้โค้ดเพียงเล็กน้อย

การตั้งค่า JAX บน AMD GPU

  • ดึง Docker image มาใช้งาน เริ่มต้นคอนเทนเนอร์ แล้วตรวจสอบการติดตั้ง
  • ใช้ AMD MI300x GPU จำนวน 8 ตัวเพื่อฝึกโมเดล LLaMA 405B

การฝึก LLaMA 405B: ประสิทธิภาพและการขยายขนาด

  • ใช้ JAX เพื่อฝึกโมเดล LLaMA 405B บน AMD GPU
  • ปรับแต่งแบบ LoRA โดยตั้งค่า both model weights และพารามิเตอร์ LoRA ด้วยความละเอียดแบบ bfloat16
  • ขนาดโมเดล: ใช้ VRAM ราว 800GB
  • น้ำหนัก LoRA และสถานะออปติไมเซอร์: ใช้ VRAM ราว 400GB
  • การใช้ VRAM รวม: ราว 1200GB
  • ความเร็วในการฝึก: ประมาณ 35 โทเคนต่อวินาที
  • ประสิทธิภาพด้านหน่วยความจำ: รักษาไว้ได้ราว 70%
  • การขยายขนาด: เมื่อใช้ JAX สามารถขยายบน GPU 8 ตัวได้เกือบเป็นเชิงเส้น

การตั้งค่าการฝึกของเรา

  • แปลง LLaMA 3.1 จาก PyTorch ไปเป็น JAX
  • กระจายโมเดลอย่างมีประสิทธิภาพผ่านการโหลดโมเดลและการ sharding ของพารามิเตอร์

การ sharding พารามิเตอร์ใน JAX

  • ใช้ความสามารถ device mesh ของ JAX เพื่อกระจายโมเดลไปยัง AMD GPU 8 ตัวอย่างมีประสิทธิภาพ
  • กำหนดกฎการ sharding ของพารามิเตอร์เพื่อ shard มิติของแต่ละเทนเซอร์ตามแกนของ mesh

การใช้งานการฝึก LoRA

  • LoRA ลดจำนวนพารามิเตอร์ที่ต้องฝึกด้วยการแยกการอัปเดตน้ำหนักออกเป็นเมทริกซ์ low-rank
  • มีการสร้างเลเยอร์ LoRADense เพื่อรวมพารามิเตอร์ LoRA เข้าไป
  • กระจายพารามิเตอร์ LoRA อย่างมีประสิทธิภาพเพื่อเพิ่มประสิทธิภาพการใช้หน่วยความจำและการคำนวณ

บทสรุป

  • ประสบการณ์ในการปรับแต่งโมเดล LLaMA 3.1 405B แบบละเอียดด้วย AMD GPU และ JAX เป็นไปในทางบวกอย่างมาก
  • ใช้ประโยชน์จากความสามารถด้านการประมวลผลแบบขนานอันทรงพลังของ JAX และแนวทางที่เป็นอิสระจากฮาร์ดแวร์เพื่อกระจายโมเดลอย่างมีประสิทธิภาพ
  • พิสูจน์ให้เห็นว่า AMD GPU เป็นทางเลือกที่ทรงพลังสำหรับการฝึก AI ขนาดใหญ่
  • สามารถดูโค้ดทั้งหมดและรันได้ด้วยตนเองจาก GitHub repository

สรุปโดย GN⁺

  • บทความนี้อธิบายวิธีฝึกโมเดล AI ขนาดใหญ่ให้มีประสิทธิภาพด้วย AMD GPU และ JAX
  • เน้นว่า AMD hardware เป็นทางเลือกที่คุ้มค่ากว่าเมื่อเทียบกับ NVIDIA
  • แนวทางที่เป็นอิสระจากฮาร์ดแวร์ของ JAX ช่วยเพิ่มความสามารถในการพกพาโค้ดและทำให้ดูแลรักษาได้ง่ายขึ้น
  • ให้ข้อมูลที่มีประโยชน์และโค้ดสำหรับลงมือทำแก่ผู้ที่สนใจการฝึกโมเดลขนาดใหญ่
  • โปรเจกต์ที่มีความสามารถคล้ายกัน ได้แก่ CUDA และ PyTorch ของ NVIDIA

1 ความคิดเห็น

 
GN⁺ 2024-09-24
ความคิดเห็นบน Hacker News
  • แชร์ผลงานการทำ fine-tune โมเดล Llama3.1 405B บน GPU 8xAMD MI300x โดยใช้ JAX

    • ทำประสิทธิภาพได้ยอดเยี่ยมด้วย sharding API ระดับสูงของ JAX
    • มีลิงก์ไปยังบล็อกโพสต์และโค้ดโอเพนซอร์ส: ลิงก์ GitHub
    • เป็นสตาร์ตอัปที่กำลังสร้าง AI infrastructure สำหรับ fine-tune และให้บริการ LLM บน TPU, AMD และ Trainium แทนฮาร์ดแวร์ NVIDIA
    • มองว่าหลายบริษัทพยายามทำให้ PyTorch รันบน AMD GPU แต่เป็นเส้นทางที่ยาก
    • PyTorch ผูกกับ ecosystem ของ NVIDIA อย่างลึกซึ้ง จึงต้องแก้ไขหลายอย่างหากจะให้ทำงานบนฮาร์ดแวร์ที่ไม่ใช่ NVIDIA
    • เชื่อว่า JAX เหมาะกับฮาร์ดแวร์ที่ไม่ใช่ NVIDIA มากกว่า
    • ใน JAX โค้ดโมเดล ML จะถูกคอมไพล์เป็นกราฟ HLO ที่เป็นอิสระจากฮาร์ดแวร์ และ XLA compiler จะทำ optimization เฉพาะฮาร์ดแวร์
    • สามารถรันโค้ด JAX เดียวกันบน Google TPU และ AMD GPU ได้โดยไม่ต้องแก้ไข
    • กลยุทธ์ของบริษัทคือพอร์ตโมเดลไปยัง JAX และใช้ XLA kernel เพื่อดึงประสิทธิภาพสูงสุดจาก backend ที่ไม่ใช่ NVIDIA
    • ได้พอร์ต Llama 3.1 จาก PyTorch ไปยัง JAX เป็นครั้งแรก และตอนนี้โมเดล JAX เดียวกันก็ทำงานได้ดีทั้งบน TPU และ AMD GPU
    • อยากรับฟังความคิดเห็นเกี่ยวกับวิสัยทัศน์และรีโพซิทอรี
  • มีข้อเสนอให้สำรวจวิธีเอาชนะข้อจำกัดด้านหน่วยความจำและรันเวอร์ชันที่คอมไพล์ด้วย JIT

    • น่าจะช่วยเพิ่มประสิทธิภาพได้อีก
  • แชร์ประสบการณ์เกี่ยวกับ AMD GPU และการรองรับ ROCm

    • เคยลองใช้ AMD GPU และการรองรับ ROCm เมื่อปีที่แล้ว แต่รู้สึกว่า AMD ยังห่างไกลจากการไล่ทัน NVIDIA
    • การเลือก JAX เป็นแนวทางที่น่าสนใจ แต่สงสัยว่าการออกจาก PyTorch มีความยากลำบากอะไรบ้าง
  • แชร์ประสบการณ์ทดลองในด้าน inference ของโมเดล 405B

    • คิดว่า torch.cuda ก็ไม่ได้แย่ขนาดนั้น
    • มองว่าเป็นแค่ปัญหาเรื่องชื่อ เพราะ PyTorch เวอร์ชันของ AMD จะแปลส่วนนี้ให้เอง
    • การใช้คอนเทนเนอร์ rocm:pytorch ก็ง่ายพอ ๆ กับการใช้คอนเทนเนอร์ rocm:jax
    • ชี้ว่ามีการเผยแพร่ข้อมูลประสิทธิภาพออกมาน้อยมาก
    • สงสัยเกี่ยวกับตัวเลข MFU (อัตราการใช้ประโยชน์ของโมเดล)
  • ตั้งคำถามถึงการไม่มีข้อมูลประสิทธิภาพ

    • ตั้งข้อสงสัยว่าการสั่งซื้อ AMD GPU จำนวนมากจะสามารถดึงมูลค่าออกมาได้จริงหรือไม่
    • ให้ความรู้สึกว่าคำตอบคือ "ไม่"
  • สงสัยว่าทำไม Obsidian (แอปจดโน้ต) ถึงมาทำเรื่องนี้

    • ตอนแรกคิดว่าเป็นโพสต์ของ Obsidian
    • สงสัยว่าทำไมถึงยังแยกไม่ออกระหว่าง GitHub.com กับ GitHub.io
  • ขอให้ @dang ใส่ชื่อผู้ใช้ไว้ใน URL

    • โพสต์นี้เกี่ยวกับบล็อกที่ผู้ใช้สร้างขึ้น ไม่ใช่ของ Obsidian เอง