2 คะแนน โดย GN⁺ 2024-09-24 | 1 ความคิดเห็น | แชร์ทาง WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (เส้นทางของเรา)

บทนำ

  • เมื่อโมเดลโอเพนซอร์สมีขนาดใหญ่ขึ้น ความต้องการโครงสร้างพื้นฐานที่ทรงพลังสำหรับรองรับการฝึก AI ขนาดใหญ่ก็เพิ่มขึ้น
  • Felafax ได้ปรับแต่งโมเดล LLaMA 3.1 405B แบบละเอียดบน AMD GPU เพื่อพิสูจน์ประสิทธิภาพของฮาร์ดแวร์ AMD
  • งานทั้งหมดถูกเผยแพร่เป็นโอเพนซอร์สบน GitHub
  • AMD MI300X GPU ให้ประสิทธิภาพสูงเมื่อเทียบกับฮาร์ดแวร์ AI ของ NVIDIA
  • โปรเจกต์นี้เกิดขึ้นได้ด้วยการสนับสนุนจาก TensorWave

JAX คืออะไร และทำไมจึงเลือกใช้

  • JAX เป็นไลบรารีแมชชีนเลิร์นนิงทรงพลังที่ผสาน API คล้าย NumPy, automatic differentiation และคอมไพเลอร์ XLA ของ Google
  • มี API ที่ยอดเยี่ยมสำหรับการประมวลผลแบบ model parallel จึงเหมาะอย่างยิ่งกับการฝึกโมเดลขนาดใหญ่

ข้อดีของ JAX

  • pure functions: JAX สนับสนุนให้เขียน pure functions ทำให้โค้ดจัดองค์ประกอบได้ง่าย ดีบักง่าย และอ่านง่าย
  • การประมวลผลแบบขนานขั้นสูง: JIT API ที่ยืดหยุ่นของ JAX รองรับ data parallel และ model parallel ขั้นสูง ซึ่งจำเป็นต่อการฝึกขนาดใหญ่
  • codebase ที่สะอาด: ปรัชญาการออกแบบของ JAX ส่งเสริมการเขียนโค้ดที่พกพาข้ามแพลตฟอร์มฮาร์ดแวร์ได้

เหตุใด JAX จึงโดดเด่นบนฮาร์ดแวร์ที่ไม่ใช่ NVIDIA

  • แนวทางที่เป็นอิสระจากฮาร์ดแวร์: JAX ใช้คอมไพเลอร์ XLA เพื่อคอมไพล์การคำนวณเป็นตัวแทนกลางที่ไม่ผูกกับฮาร์ดแวร์
  • การเพิ่มประสิทธิภาพที่เป็นอิสระจากแพลตฟอร์ม: คอมไพเลอร์ XLA ดำเนินการเพิ่มประสิทธิภาพโดยไม่ขึ้นกับฮาร์ดแวร์
  • การย้ายระบบที่ง่าย: เมื่อใช้ JAX การเปลี่ยนจาก NVIDIA ไปเป็น AMD ต้องแก้โค้ดเพียงเล็กน้อย

การตั้งค่า JAX บน AMD GPU

  • ดึง Docker image มาใช้งาน เริ่มต้นคอนเทนเนอร์ แล้วตรวจสอบการติดตั้ง
  • ใช้ AMD MI300x GPU จำนวน 8 ตัวเพื่อฝึกโมเดล LLaMA 405B

การฝึก LLaMA 405B: ประสิทธิภาพและการขยายขนาด

  • ใช้ JAX เพื่อฝึกโมเดล LLaMA 405B บน AMD GPU
  • ปรับแต่งแบบ LoRA โดยตั้งค่า both model weights และพารามิเตอร์ LoRA ด้วยความละเอียดแบบ bfloat16
  • ขนาดโมเดล: ใช้ VRAM ราว 800GB
  • น้ำหนัก LoRA และสถานะออปติไมเซอร์: ใช้ VRAM ราว 400GB
  • การใช้ VRAM รวม: ราว 1200GB
  • ความเร็วในการฝึก: ประมาณ 35 โทเคนต่อวินาที
  • ประสิทธิภาพด้านหน่วยความจำ: รักษาไว้ได้ราว 70%
  • การขยายขนาด: เมื่อใช้ JAX สามารถขยายบน GPU 8 ตัวได้เกือบเป็นเชิงเส้น

การตั้งค่าการฝึกของเรา

  • แปลง LLaMA 3.1 จาก PyTorch ไปเป็น JAX
  • กระจายโมเดลอย่างมีประสิทธิภาพผ่านการโหลดโมเดลและการ sharding ของพารามิเตอร์

การ sharding พารามิเตอร์ใน JAX

  • ใช้ความสามารถ device mesh ของ JAX เพื่อกระจายโมเดลไปยัง AMD GPU 8 ตัวอย่างมีประสิทธิภาพ
  • กำหนดกฎการ sharding ของพารามิเตอร์เพื่อ shard มิติของแต่ละเทนเซอร์ตามแกนของ mesh

การใช้งานการฝึก LoRA

  • LoRA ลดจำนวนพารามิเตอร์ที่ต้องฝึกด้วยการแยกการอัปเดตน้ำหนักออกเป็นเมทริกซ์ low-rank
  • มีการสร้างเลเยอร์ LoRADense เพื่อรวมพารามิเตอร์ LoRA เข้าไป
  • กระจายพารามิเตอร์ LoRA อย่างมีประสิทธิภาพเพื่อเพิ่มประสิทธิภาพการใช้หน่วยความจำและการคำนวณ

บทสรุป

  • ประสบการณ์ในการปรับแต่งโมเดล LLaMA 3.1 405B แบบละเอียดด้วย AMD GPU และ JAX เป็นไปในทางบวกอย่างมาก
  • ใช้ประโยชน์จากความสามารถด้านการประมวลผลแบบขนานอันทรงพลังของ JAX และแนวทางที่เป็นอิสระจากฮาร์ดแวร์เพื่อกระจายโมเดลอย่างมีประสิทธิภาพ
  • พิสูจน์ให้เห็นว่า AMD GPU เป็นทางเลือกที่ทรงพลังสำหรับการฝึก AI ขนาดใหญ่
  • สามารถดูโค้ดทั้งหมดและรันได้ด้วยตนเองจาก GitHub repository

สรุปโดย GN⁺

  • บทความนี้อธิบายวิธีฝึกโมเดล AI ขนาดใหญ่ให้มีประสิทธิภาพด้วย AMD GPU และ JAX
  • เน้นว่า AMD hardware เป็นทางเลือกที่คุ้มค่ากว่าเมื่อเทียบกับ NVIDIA
  • แนวทางที่เป็นอิสระจากฮาร์ดแวร์ของ JAX ช่วยเพิ่มความสามารถในการพกพาโค้ดและทำให้ดูแลรักษาได้ง่ายขึ้น
  • ให้ข้อมูลที่มีประโยชน์และโค้ดสำหรับลงมือทำแก่ผู้ที่สนใจการฝึกโมเดลขนาดใหญ่
  • โปรเจกต์ที่มีความสามารถคล้ายกัน ได้แก่ CUDA และ PyTorch ของ NVIDIA

1 ความคิดเห็น

 
GN⁺ 2024-09-24
ความคิดเห็นจาก Hacker News
  • เมื่อไม่นานมานี้ได้ทำการ fine-tune โมเดล llama3.1 405B บน 8xAMD MI300x GPU โดยใช้ JAX แทน PyTorch
    ด้วย API สำหรับ sharding ขั้นสูงของ JAX จึงได้ประสิทธิภาพที่ดี และได้สรุปเทคนิค sharding ที่ใช้ไว้ในบล็อกแล้ว นอกจากนี้ยังเปิดเผยโค้ดด้วย: https://github.com/felafax/felafax
    เราเป็นสตาร์ตอัปเล็ก ๆ ที่สร้างโครงสร้างพื้นฐาน AI สำหรับการ fine-tune และ serving ของ LLM บนฮาร์ดแวร์ที่ไม่ใช่ NVIDIA เช่น TPU, AMD, Trainium
    หลายบริษัทพยายามรัน PyTorch บน AMD GPU แต่เราเห็นว่า PyTorch ผูกกับ ecosystem ของ NVIDIA อย่างลึกซึ้ง ผ่านสิ่งอย่าง torch.cuda หรือ scaled_dot_product_attention จึงต้องทำ “de-NVIDIA-fication” ค่อนข้างมาก
    เราคิดว่า JAX เหมาะกับฮาร์ดแวร์ที่ไม่ใช่ NVIDIA มากกว่า เพราะโค้ดโมเดลจะถูกคอมไพล์เป็นกราฟ HLO ที่เป็นอิสระจากฮาร์ดแวร์ก่อน จากนั้นคอมไพเลอร์ XLA จะทำการ optimize แล้วค่อยใช้ optimization เฉพาะฮาร์ดแวร์ภายหลัง โดยโค้ด JAX ของ LLaMA3 เดียวกันสามารถทำงานได้ทั้งบน Google TPU และ AMD GPU โดยไม่ต้องแก้ไข
    กลยุทธ์ของบริษัทคือพอร์ตโมเดลไปยัง JAX ก่อน แล้วใช้ประโยชน์จากเฟรมเวิร์ก JAX และเคอร์เนล XLA เพื่อดึงประสิทธิภาพสูงสุดจากแบ็กเอนด์ที่ไม่ใช่ NVIDIA ดังนั้นเราจึงย้าย Llama 3.1 จาก PyTorch ไปเป็น JAX ก่อน และโมเดล JAX ตัวเดียวกันก็ทำงานได้ดีทั้งบน TPU และ AMD GPU

    • การรัน PyTorch บน AMD GPU โดยไม่ต้องแก้โค้ด CUDA แทบไม่มีปัญหาอะไรเลย บล็อกของ MosaicML ก็น่าอ่านเช่นกัน: https://www.databricks.com/blog/training-llms-scale-amd-mi25...
    • สงสัยว่าตรวจสอบ ความแม่นยำของการพอร์ต JAX สำหรับ Llama 3.1 อย่างไร
      โดยส่วนตัวเหตุผลหลักที่ผมใช้ PyTorch ก็เพราะโมเดลต้นฉบับถูกสร้างด้วย PyTorch แม้ลอจิกจะดูเหมือนกันในโมเดลคนละเวอร์ชัน แต่เมื่ออยู่ในสเกลข้อมูลมหาศาล ความคลาดเคลื่อนของ floating point เพียงเล็กน้อยก็อาจสะสมจนเกิด model drift ได้
      การดีบักความไม่ตรงกันของความแม่นยำในโมเดลใหญ่ ๆ แบบนี้ แทบจะทรมานยิ่งกว่าวงนรกชั้นที่สิบเสียอีก
    • สงสัยว่า JAX มี implementation ของ matrix multiplication หรือ FlashAttention เป็นของตัวเอง หรือใช้ implementation ของ ROCm แบบเดียวกับ PyTorch เช่น hipblaslt, Composable Kernel FA อะไรทำนองนั้น
      ผมไม่ได้รู้จัก JAX ลึกมากนัก แต่คิดว่าส่วนสำคัญที่ทำให้ประสิทธิภาพการเทรนของ PyTorch บน MI300x ย่ำแย่ ก็คือประสิทธิภาพของไลบรารี ROCm ที่ใช้ภายในมันช้า
    • สงสัยว่าจะทำงานได้บน การ์ดผู้บริโภค อย่าง 7900 XTX หรือไม่
      ที่ว่าทำงานได้ในที่นี้ไม่ได้หมายถึงใช้เวลา 2 สัปดาห์ไล่แก้ไดรเวอร์ แล้วจากนั้นก็ไม่กล้าอัปเดตเซิร์ฟเวอร์อีกเลย
    • ถ้าเป็นการย้ายระบบ ก็อยากเห็น ตัวเลขจริงเมื่อเทียบกับเวอร์ชัน PyTorch ของโมเดลเดียวกัน ตารางเปรียบเทียบในบทความดูใกล้เคียงกับด้านเทคนิคมากกว่า
      และก็อยากรู้ด้วยว่ามีปัญหาทางเทคนิคอะไรที่เจอบ้าง
  • พูดกันตรง ๆ ประสิทธิภาพนี้ค่อนข้างแย่ น่าจะเป็นเพราะ คอมไพล์ ยังทำงานได้ไม่ถูกต้อง
    สำหรับโมเดล 405B ได้ 35 โทเคนต่อวินาที ซึ่งคิดเป็นประมาณ 85 เทราฟลอปส์ ขณะที่ GPU MI300x จำนวน 8 ตัวอยู่ที่ระดับ 10.4 เพตาฟลอปส์ ดังนั้น MFU อยู่ราว 0.8%
    นี่ต่ำกว่าประสิทธิภาพการเทรนที่พอใช้ได้ซึ่งอยู่ที่ 30~40% MFU ถึง 40~50 เท่า ดังนั้นในมุมของ AMD ก็คงหวังว่าคอขวดจะอยู่ที่ซอฟต์แวร์สแตก

    • ผมก็อยากถามเรื่องนี้เหมือนกัน
      ในหน้า GitHub เขียนว่า “สามารถปรับจูน LLaMa3.1 บน Google Cloud TPU ได้ด้วย ต้นทุนต่ำลง 30%” แต่ไม่ได้พูดถึงประสิทธิภาพเลย
  • งานยอดเยี่ยมมาก เมื่อประมาณปีก่อนผมลองจับ AMD GPU และ การรองรับ ROCm มานิดหน่อย แล้วเห็นได้ชัดว่า AMD ยังต้องไปอีกไกลกว่าจะไล่ทัน Nvidia
    แนวทางที่เลือก JAX น่าสนใจ แต่ก็สงสัยว่าการออกห่างจาก PyTorch ซึ่งแทบเป็นไลบรารีมาตรฐานของวงการแมชชีนเลิร์นนิงนั้นมีความยากอะไรบ้าง

    • เมื่อไม่กี่สัปดาห์ก่อนเราโพสต์ Show HN อธิบายเส้นทางของเราไว้: https://news.ycombinator.com/item?id=41512142
      ตอนแรกเป้าหมายคือการ fine-tune LLaMA 3 บน TPU แต่ PyTorch XLA ยังเทอะทะอยู่ เราเลยตัดสินใจเขียนโมเดลใหม่ด้วย JAX
      อย่างที่บอกไปก่อนหน้านี้ เรามองว่า JAX เป็นแพลตฟอร์มที่ดีกว่าสำหรับ GPU ที่ไม่ใช่ NVIDIA และอยากสร้างอินฟราสตรักเจอร์สำหรับ GPU ที่ไม่ใช่ NVIDIA บน JAX+openXLA
    • ผมยังทำให้ AMD ROCm ทำงานบนระบบ Debian 12 ของตัวเองไม่ได้เลย เลยคิดว่า Ollama กำลังใช้ CPU แทน GPU ดูแล้วคงต้องไปอีกไกล
  • งานดีมาก ช่วงสุดสัปดาห์ที่แล้วผมก็กำลังลองฝั่ง inference ของ 405B อยู่เหมือนกัน [0]
    ผมยังไม่แน่ใจว่า torch.cuda แย่ขนาดนั้นจริงไหม เพราะ PyTorch สำหรับ AMD จะแปลงส่วนนั้นให้เองอยู่แล้ว ดูเป็นเรื่องชื่อเรียกมากกว่าจะเป็นปัญหาเชิงสาระ
    จริง ๆ แล้วการดึงคอนเทนเนอร์ rocm:pytorch ก็ง่ายพอ ๆ กับการดึงคอนเทนเนอร์ rocm:jax
    ยังมีตัวเลขที่เผยแพร่ออกมาไม่มากนัก เลยอยากรู้ว่า MFU ได้เท่าไร
    [0] https://x.com/HotAisle/status/1837580046732874026

    • ดีเลย
      MFU ต้องคำนวณอีกที รายละเอียดของ GPU และ VRAM ดูได้จากรีโพซิทอรี: https://dub.sh/amd-405b-res
      สุดสัปดาห์หน้ามีแผนจะลองรันการเทรนอีกครั้งโดยทำ JIT compile ทั้งขั้นตอนการเทรน แล้วค่อยคำนวณ MFU ตอนนั้น
  • ตอนที่เราวัดกันที่ ZML พบว่า MI300X เร็วกว่า H100 อยู่ 30% เป็นชิปที่ยอดเยี่ยมมาก

  • สงสัยว่ามี ผู้ให้บริการคลาวด์ เจ้าไหนให้เช่าโฮสต์ 8xAMD MI300 บ้าง
    ในงานผมใช้ AWS ค่อนข้างเยอะ แต่อยากลองใช้ AMD GPU ดูสักครั้ง

    • FYI บริษัทของเราเปิดให้เช่า 8xMI300x อยู่ ดังนั้นติดต่อมาได้
    • Oracle มีให้บริการ ที่อื่นก็น่าจะทยอยตามมา แต่ผมคิดว่าคงสมเหตุสมผลกว่าถ้าจะคุยกับผู้ให้บริการรายเล็ก
  • ข้อมูลประสิทธิภาพ อยู่ที่ไหน?

    • ได้เพิ่มข้อมูลการใช้งาน GPU และ VRAM ลงในรีโพซิทอรี GitHub แล้ว: https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...
      เนื่องจากข้อจำกัดของโค้ดและ VRAM จึงยังไม่สามารถรันเวอร์ชัน JIT compile ของโมเดล 405B ได้ เรื่องนี้ยังต้องตรวจสอบเพิ่มเติม
      การรันเทรนทั้งหมดทำในโหมด eager execution ของ JAX จึงยังมีพื้นที่ให้ปรับปรุงประสิทธิภาพอีกมาก
      แม้ในโหมด eager execution การใช้งาน GPU โดยรวมก็ยังอยู่ราว 30~40% ซึ่งถือว่าใช้ได้พอสมควร และคิดว่าถ้าใช้ JIT การใช้งาน GPU ก็น่าจะขึ้นไปถึง 50~60% ได้ไม่ยาก
  • ถ้าเป็นไปได้ น่าจะน่าสนใจถ้าลองหาวิธีเอาชนะข้อจำกัดด้านหน่วยความจำเพื่อรัน เวอร์ชัน JIT compile เพราะอาจนำไปสู่การปรับปรุงประสิทธิภาพเพิ่มเติมได้

    • เห็นด้วย ยังมีประสิทธิภาพให้รีดออกมาอีกมาก
      เราจำเป็นต้องมีขั้นตอนการเทรนที่คอมไพล์ด้วย JIT, การโหลดข้อมูลและ sharding ที่ปรับจูนมากขึ้น, gradient accumulation และ activation checkpointing
      เรายังคงสร้างต่อไป และตั้งใจว่าจะเขียนบล็อกอีกครั้งในเร็ว ๆ นี้หลังจากทำการปรับปรุงทั้งหมดเสร็จ
  • สงสัยว่า AMD เข้าใกล้การดึงมูลค่าจากเรื่องนี้ได้บ้างหรือยัง ทั้งในแง่คำสั่งซื้อ GPU จำนวนมากและภาวะขาดแคลนอุปทาน
    ความรู้สึกของผมยังเอนไปทาง “ยังไม่ใช่”

    • เข้าใจว่าเป็นคำพูดเชิงเหน็บแนม แต่ถ้าตอนนี้คุณไม่ได้คิดจะฝากฮาร์ดแวร์และซอฟต์แวร์ AI ทั้งหมดไว้กับผู้ขายรายเดียว ก็ควรเริ่มขยับไปหาทางเลือกอื่นได้แล้ว
      อีกฝ่ายมีความได้เปรียบจากการออกตัวก่อนมหาศาล และฝั่งซอฟต์แวร์ก็ยังมีงานให้ทำอีกชัดเจน มันต้องใช้เวลา
  • ทำไมแอปจดโน้ตอย่าง Obsidian ถึงมาทำเรื่องนี้?

    • ไม่ใช่แบบนั้น บริษัทนี้แค่ใช้ Obsidian Publish สำหรับเผยแพร่เอกสาร