- ทำความเข้าใจโครงสร้างอย่างแม่นยำผ่านการติดตั้งใช้งานโมเดล Llama 3 ที่รันได้จริง
ภาพรวม
- โมเดล Llama 3 ที่ Meta เปิดเผยกำลังได้รับความสนใจอย่างมาก
- โดดเด่นด้วยสเกลและประสิทธิภาพระดับมหาศาล เช่น 24K GPUs, ข้อมูลฝึก 15T, ข้อมูลคำสั่ง 10M และเวลา GPU 1.3M ชั่วโมง
- โครงสร้างของโมเดลไม่ได้เปลี่ยนไปมากนัก โดย Llama 3 ใช้ GQA แต่สิ่งนี้ก็เคยถูกนำไปใช้แล้วใน Llama 2 70B
- มีการติดตั้งใช้งานโดยใช้เพียง NumPy เพื่อให้เข้าใจโครงสร้างของโมเดลได้อย่างเป็นธรรมชาติและตรงไปตรงมา
- ใช้โมเดล stories15M ที่ Andrej Karpathy ฝึกด้วยโครงสร้าง Llama 2 และแปลงเป็นรูปแบบบีบอัดของ NumPy
โครงสร้าง
- โครงสร้างของโมเดล Llama 3 เหมือนกับ 42dot LLM
- พารามิเตอร์ของโมเดล:
dim: 288
n_layers: 6
n_heads: 6
vocab_size: 32000
max_seq_len: 256
max_new_tokens: 50
RoPE #1
- มีการคำนวณ
cos และ sin ล่วงหน้าสำหรับ RoPE embedding
- ค่าชุดนี้ถูกนำไปใช้กับ
Q และ K
- ผลการคำนวณได้จากการคูณด้วย
np.outer แล้วคำนวณ cos และ sin
RMSNorm
- RMSNorm จะทำ normalization ของ activation value ด้วย Root Mean Square แทนการใช้สถิติแบบ Mini Batch หรือ Layer แบบดั้งเดิม
- ช่วยให้การสเกล activation มีความสม่ำเสมอ
QKV
- การคำนวณ QKV ของ Llama แตกต่างจาก GPT ที่ทำ
matmul กับน้ำหนักชุดเดียวแล้วค่อยแยกภายหลัง โดย Llama มีน้ำหนักแยกสำหรับ QKV แต่ละตัว
- จากนั้นจึงจัดรูปแต่ละค่าใหม่เพื่อใช้กับ Multi-Head Attention
RoPE #2
- RoPE มีคุณสมบัติของทั้ง absolute และ relative positional encoding
- ใช้กับ Q และ K เท่านั้น โดยแบ่งอินพุต คูณกับ
cos และ sin แล้วนำผลมาบวกและลบก่อนจัดรูปใหม่
KV cache
- โมเดลสร้างข้อความแบบ GPT ใช้ Masked Attention จึงสามารถใช้ KV cache ได้
- เนื่องจากผลลัพธ์ก่อนหน้าจะคงเดิมเสมอ จึง cache ค่า K และ V ไว้ และคำนวณ Q เฉพาะค่าล่าสุด
GQA(Grouped-Query Attention)
- GQA เป็นเทคนิคที่นำมาใช้ใน Llama 2 เพื่อช่วยประหยัดหน่วยความจำและเพิ่มประสิทธิภาพ
- ใน Llama 3 มีการใช้ GQA กับทุกโมเดลที่มีขนาด 8B ขึ้นไป
Scaled Dot-Product Attention
- คำนวณ Attention แต่ละตัวด้วย Multi-Head Attention
- ผลลัพธ์ได้มาจาก
softmax และ matmul
Feed Forward
- Feed Forward ของโมเดล Llama ใช้ linear layer 3 ชั้น และไม่มี bias
- สร้างค่า swish แล้วคูณกับ
x_V ก่อนลดสเกลลงอีกครั้ง
SwiGLU
- SwiGLU เป็นการผสมผสานที่มีเอกลักษณ์ของ feed forward หลายชั้น ซึ่งช่วยเพิ่มประสิทธิภาพของโมเดล
Linear
- เอาต์พุตสุดท้ายจะคำนวณเฉพาะ logit ตัวสุดท้ายด้วย
matmul เพื่อเพิ่มความเร็ว
การสร้างข้อความ
- ใช้ logit ที่ดึงออกมาเพื่อสร้างโทเคนทีละตัว
- แบ่งเป็น Prefill Phase และ Decode Phase
- ใน Prefill Phase จะส่งอินพุตทั้งหมดเข้าไป ส่วนใน Decode Phase จะส่งเฉพาะ token ID ตัวสุดท้ายเพื่อรับผลลัพธ์
ตัวอย่าง
GitHub
เอกสารอ้างอิง
- Exploring and Building the Llama 3 Architecture
- Rotation Matrix
- Mastering LLM Techniques: Inference Optimization
- arXiv:2305.13245
ความเห็นจาก GN⁺
- โครงสร้างและประสิทธิภาพของโมเดล Llama 3: โมเดล Llama 3 ยังคงโครงสร้างของ Llama 2 เดิมไว้ แต่ยกระดับประสิทธิภาพขึ้นอย่างมาก ซึ่งสะท้อนถึงการออกแบบที่คำนึงถึงทั้งการขยายขนาดและประสิทธิภาพไปพร้อมกัน
- เหตุผลที่ใช้ NumPy ในการติดตั้งใช้งาน: การสร้างโมเดลด้วย NumPy ช่วยให้เข้าใจโครงสร้างและการทำงานของโมเดลได้อย่างเป็นธรรมชาติมากขึ้น ซึ่งเป็นประโยชน์อย่างมากต่อผู้เรียนและนักวิจัย
- การนำ GQA มาใช้: GQA เป็นเทคโนโลยีที่ช่วยทั้งประหยัดหน่วยความจำและเพิ่มประสิทธิภาพ และเมื่อถูกใช้กับทุกโมเดลใน Llama 3 ก็ยิ่งทำให้ประสิทธิภาพโดยรวมของโมเดลสูงสุดขึ้น
- ความสำคัญของ KV cache: KV cache มีบทบาทสำคัญในโมเดลสร้างข้อความสไตล์ GPT และช่วยเพิ่มประสิทธิภาพในการคำนวณของโมเดลได้อย่างมาก
- กรณีใช้งานจริง: สามารถทดลองรันโมเดลจริงได้ผ่านโค้ดตัวอย่าง ซึ่งเป็นโอกาสที่ดีในการตรวจสอบประสิทธิภาพของโมเดลด้วยตนเอง
1 ความคิดเห็น
เดิมทีสิ่งที่ถูกโพสต์บน Hacker News เป็นภาษาอังกฤษ แต่ได้เปลี่ยนเป็นลิงก์ที่คุณ Likejazz ผู้เขียนต้นฉบับเขียนไว้เป็นภาษาเกาหลีแล้ว