20 คะแนน โดย GN⁺ 2024-05-17 | 1 ความคิดเห็น | แชร์ทาง WhatsApp
  • ทำความเข้าใจโครงสร้างอย่างแม่นยำผ่านการติดตั้งใช้งานโมเดล Llama 3 ที่รันได้จริง

ภาพรวม

  • โมเดล Llama 3 ที่ Meta เปิดเผยกำลังได้รับความสนใจอย่างมาก
  • โดดเด่นด้วยสเกลและประสิทธิภาพระดับมหาศาล เช่น 24K GPUs, ข้อมูลฝึก 15T, ข้อมูลคำสั่ง 10M และเวลา GPU 1.3M ชั่วโมง
  • โครงสร้างของโมเดลไม่ได้เปลี่ยนไปมากนัก โดย Llama 3 ใช้ GQA แต่สิ่งนี้ก็เคยถูกนำไปใช้แล้วใน Llama 2 70B
  • มีการติดตั้งใช้งานโดยใช้เพียง NumPy เพื่อให้เข้าใจโครงสร้างของโมเดลได้อย่างเป็นธรรมชาติและตรงไปตรงมา
  • ใช้โมเดล stories15M ที่ Andrej Karpathy ฝึกด้วยโครงสร้าง Llama 2 และแปลงเป็นรูปแบบบีบอัดของ NumPy

โครงสร้าง

  • โครงสร้างของโมเดล Llama 3 เหมือนกับ 42dot LLM
  • พารามิเตอร์ของโมเดล:
    • dim: 288
    • n_layers: 6
    • n_heads: 6
    • vocab_size: 32000
    • max_seq_len: 256
    • max_new_tokens: 50

RoPE #1

  • มีการคำนวณ cos และ sin ล่วงหน้าสำหรับ RoPE embedding
  • ค่าชุดนี้ถูกนำไปใช้กับ Q และ K
  • ผลการคำนวณได้จากการคูณด้วย np.outer แล้วคำนวณ cos และ sin

RMSNorm

  • RMSNorm จะทำ normalization ของ activation value ด้วย Root Mean Square แทนการใช้สถิติแบบ Mini Batch หรือ Layer แบบดั้งเดิม
  • ช่วยให้การสเกล activation มีความสม่ำเสมอ

QKV

  • การคำนวณ QKV ของ Llama แตกต่างจาก GPT ที่ทำ matmul กับน้ำหนักชุดเดียวแล้วค่อยแยกภายหลัง โดย Llama มีน้ำหนักแยกสำหรับ QKV แต่ละตัว
  • จากนั้นจึงจัดรูปแต่ละค่าใหม่เพื่อใช้กับ Multi-Head Attention

RoPE #2

  • RoPE มีคุณสมบัติของทั้ง absolute และ relative positional encoding
  • ใช้กับ Q และ K เท่านั้น โดยแบ่งอินพุต คูณกับ cos และ sin แล้วนำผลมาบวกและลบก่อนจัดรูปใหม่

KV cache

  • โมเดลสร้างข้อความแบบ GPT ใช้ Masked Attention จึงสามารถใช้ KV cache ได้
  • เนื่องจากผลลัพธ์ก่อนหน้าจะคงเดิมเสมอ จึง cache ค่า K และ V ไว้ และคำนวณ Q เฉพาะค่าล่าสุด

GQA(Grouped-Query Attention)

  • GQA เป็นเทคนิคที่นำมาใช้ใน Llama 2 เพื่อช่วยประหยัดหน่วยความจำและเพิ่มประสิทธิภาพ
  • ใน Llama 3 มีการใช้ GQA กับทุกโมเดลที่มีขนาด 8B ขึ้นไป

Scaled Dot-Product Attention

  • คำนวณ Attention แต่ละตัวด้วย Multi-Head Attention
  • ผลลัพธ์ได้มาจาก softmax และ matmul

Feed Forward

  • Feed Forward ของโมเดล Llama ใช้ linear layer 3 ชั้น และไม่มี bias
  • สร้างค่า swish แล้วคูณกับ x_V ก่อนลดสเกลลงอีกครั้ง

SwiGLU

  • SwiGLU เป็นการผสมผสานที่มีเอกลักษณ์ของ feed forward หลายชั้น ซึ่งช่วยเพิ่มประสิทธิภาพของโมเดล

Linear

  • เอาต์พุตสุดท้ายจะคำนวณเฉพาะ logit ตัวสุดท้ายด้วย matmul เพื่อเพิ่มความเร็ว

การสร้างข้อความ

  • ใช้ logit ที่ดึงออกมาเพื่อสร้างโทเคนทีละตัว
  • แบ่งเป็น Prefill Phase และ Decode Phase
  • ใน Prefill Phase จะส่งอินพุตทั้งหมดเข้าไป ส่วนใน Decode Phase จะส่งเฉพาะ token ID ตัวสุดท้ายเพื่อรับผลลัพธ์

ตัวอย่าง

  • สามารถรันได้ดังนี้:
    $ python llama3.py "I have a dream"  
    

GitHub

  • ดูซอร์สโค้ดทั้งหมดได้ที่ likejazz/llama3.np

เอกสารอ้างอิง

  1. Exploring and Building the Llama 3 Architecture
  2. Rotation Matrix
  3. Mastering LLM Techniques: Inference Optimization
  4. arXiv:2305.13245

ความเห็นจาก GN⁺

  • โครงสร้างและประสิทธิภาพของโมเดล Llama 3: โมเดล Llama 3 ยังคงโครงสร้างของ Llama 2 เดิมไว้ แต่ยกระดับประสิทธิภาพขึ้นอย่างมาก ซึ่งสะท้อนถึงการออกแบบที่คำนึงถึงทั้งการขยายขนาดและประสิทธิภาพไปพร้อมกัน
  • เหตุผลที่ใช้ NumPy ในการติดตั้งใช้งาน: การสร้างโมเดลด้วย NumPy ช่วยให้เข้าใจโครงสร้างและการทำงานของโมเดลได้อย่างเป็นธรรมชาติมากขึ้น ซึ่งเป็นประโยชน์อย่างมากต่อผู้เรียนและนักวิจัย
  • การนำ GQA มาใช้: GQA เป็นเทคโนโลยีที่ช่วยทั้งประหยัดหน่วยความจำและเพิ่มประสิทธิภาพ และเมื่อถูกใช้กับทุกโมเดลใน Llama 3 ก็ยิ่งทำให้ประสิทธิภาพโดยรวมของโมเดลสูงสุดขึ้น
  • ความสำคัญของ KV cache: KV cache มีบทบาทสำคัญในโมเดลสร้างข้อความสไตล์ GPT และช่วยเพิ่มประสิทธิภาพในการคำนวณของโมเดลได้อย่างมาก
  • กรณีใช้งานจริง: สามารถทดลองรันโมเดลจริงได้ผ่านโค้ดตัวอย่าง ซึ่งเป็นโอกาสที่ดีในการตรวจสอบประสิทธิภาพของโมเดลด้วยตนเอง

1 ความคิดเห็น

 
xguru 2024-05-17

เดิมทีสิ่งที่ถูกโพสต์บน Hacker News เป็นภาษาอังกฤษ แต่ได้เปลี่ยนเป็นลิงก์ที่คุณ Likejazz ผู้เขียนต้นฉบับเขียนไว้เป็นภาษาเกาหลีแล้ว