Llama จากศูนย์: วิธีอิมพลีเมนต์เปเปอร์โดยไม่ต้องเสียน้ำตา
(blog.briankitano.com)- Brian Kitano สร้าง Llama ฉบับย่อด้วย TinyShakespeare ด้วยตัวเอง และสรุปว่าการอิมพลีเมนต์เปเปอร์ให้ปลอดภัยควรเริ่มจากโมเดลขนาดเล็ก ค่อย ๆ เปลี่ยนชิ้นส่วนทีละส่วน พร้อมเทรนและประเมินผลทุกครั้ง
- เริ่มจากเตรียม ฟังก์ชันช่วยตรวจสอบ เช่น การแบ่งข้อมูล การสร้าง batch การประเมิน loss และฟังก์ชัน generate จากนั้นใช้โมเดลง่าย ๆ ตรวจว่าคอมไพล์และเทรนได้จริง ก่อนเพิ่มองค์ประกอบของ Llama เข้าไป
- เพิ่ม RMSNorm, RoPE, SwiGLU ตามลำดับ พร้อมตรวจสอบว่าแต่ละเลเยอร์ทำงานตามที่คาดหรือไม่ผ่าน tensor shape, คุณสมบัติของสูตร และ attention map
- ใน RoPE attention หากเอา causal mask ออก validation loss จะลดลงถึง 0.16 แต่คุณภาพการ generate แย่ลง สาเหตุคือ ข้อมูลรั่วไหล จากการมองเห็น token ในอนาคต
- Llama ฉบับย่อสุดท้ายมี 4 บล็อก พารามิเตอร์ราว 2.37 ล้านตัว ลด validation loss ลงได้ประมาณ 1.0 และยังต้องตรวจสอบทั้งการไหลของ gradient และ learning-rate schedule ควบคู่กันด้วย
เริ่มจากเล็ก ๆ แล้วค่อยสร้างความมั่นใจแบบวนซ้ำ
- หัวใจของการอิมพลีเมนต์เปเปอร์คือเริ่มจาก โมเดลขนาดเล็ก ค่อย ๆ เปลี่ยนองค์ประกอบทีละส่วน และทำซ้ำการเทรนกับการประเมินผลทุกครั้งที่เปลี่ยน
- ก่อนอื่นเตรียมฟังก์ชันช่วยสำหรับตรวจสอบโมเดลในเชิงปริมาณ
- การแบ่งข้อมูล
- training loop
- การแสดงภาพ loss
- การประเมิน validation loss
- แทนที่จะย้ายองค์ประกอบทั้งหมดจากเปเปอร์มาทีเดียว ให้เตรียม ฟังก์ชันประเมินเชิงคุณภาพ สำหรับดูผลลัพธ์การ generate ด้วยโมเดลที่เรียบง่าย รันเร็ว และเคยมีประสบการณ์อิมพลีเมนต์มาก่อน
- ตรวจสอบ tensor layer ด้วย
.shape,assert,plt.imshowและแทนที่จะรีบไปทำ optimization ของ matrix multiplication ตั้งแต่แรก ให้คำนวณตรวจผลลัพธ์ที่คาดไว้ด้วยมือก่อน แล้วค่อยทำให้มีประสิทธิภาพขึ้นด้วยฟังก์ชันของtorch - ต้องทดสอบโดยเปลี่ยน batch size, sequence length และ embedding dimension เพราะโค้ดที่ถูกต้องแค่กับขนาดเดียวอาจพังในตอน inference ได้
Dataset และการตั้งค่าพื้นฐาน
- เป้าหมายการอิมพลีเมนต์คือ Llama ของ Meta AI ในเวอร์ชันที่ย่อขนาดลงมาก และข้อมูลเทรนคือ TinyShakespeare
- Llama เทรนด้วย token 1.4T ตัว แต่ที่นี่ใช้ TinyShakespeare ขนาดประมาณ 1.11 ล้านอักขระ
- Llama ต้นฉบับใช้ tokenizer แบบ byte-pair encoding ของ SentencePiece แต่การอิมพลีเมนต์นี้ใช้ tokenizer ระดับอักขระ แบบเรียบง่าย
- vocabulary size คือ 65
- dataset มีขนาดเล็ก จึงไม่ได้ optimize วิธีจัดเก็บในหน่วยความจำเป็นพิเศษ
- จัดการการตั้งค่าโมเดล เช่น
vocab_size,batch_size,context_window,d_modelด้วย dictionary ชื่อMASTER_CONFIG- จุดประสงค์คือเพื่อลดค่าคงที่และ magic number ทำให้โค้ดอ่านง่ายขึ้น
- ฟังก์ชัน
get_batchesแบ่งข้อมูลเป็น train 80%, val 10%, test 10% และสร้างอินพุตxกับ labelyที่เลื่อนไปข้างหลังหนึ่งตัวอักษรจากจุดเริ่มต้นแบบสุ่ม
ตรวจสอบการคอมไพล์และการเทรนด้วยโมเดลพื้นฐาน
- โมเดลแรกคือ
SimpleBrokenModelซึ่งประกอบด้วย embedding และ feed-forward network แบบง่ายnn.EmbeddingLinearReLULinear
- ในการอิมพลีเมนต์เปเปอร์ คำว่าโมเดล “ทำงานได้” ต้องผ่านเงื่อนไขทั้งสองอย่าง
- คอมไพล์ได้: tensor shape ตรงกันระหว่างเลเยอร์
- เทรนได้: loss ลดลงจริง
- ฟังก์ชัน
evaluate_lossสุ่ม batch 10 ครั้งจาก split train และ val แล้วคำนวณ loss เฉลี่ย - หลังเทรน 1000 epochs
SimpleBrokenModelมี validation loss อยู่ราว 3.94 แทบไม่ลดจาก cross-entropy เริ่มต้นที่ 4.17 - สาเหตุคือส่งค่าที่ผ่าน softmax แล้วเข้าไปใน
F.cross_entropyF.cross_entropyของ PyTorch รับ logits ที่ยังไม่ถูก normalize โดยตรงSimpleModelที่เอา softmax ออก ลด validation loss ลงได้ถึงราว 2.51
- จากนั้นเพิ่มฟังก์ชัน
generateเพื่อดูตัวอักษรที่โมเดลสร้างโดยตรง และโมเดลพื้นฐานแม้ยังไม่สมบูรณ์ แต่ก็เข้าสู่สถานะที่ validation loss ลดลงแล้ว
องค์ประกอบของ Llama 1: RMSNorm
- เมื่อเทียบกับ Transformer ดั้งเดิม Llama ใช้การปรับสถาปัตยกรรมหลัก 3 อย่าง
- RMSNorm pre-normalization
- Rotary embeddings
- SwiGLU activation function
- Transformer ดั้งเดิมใช้ BatchNormalization แต่ Llama ใช้ RMSNorm ซึ่ง scale vector ด้วย variance โดยไม่ทำ centering
- ขณะที่ Transformer ดั้งเดิมใช้ post-normalization โดยนำ normalization ไปใช้กับเอาต์พุตของ attention layer แต่ Llama ใช้ pre-normalization โดยนำไปใช้กับอินพุตก่อน
RMSNormที่อิมพลีเมนต์สมมติว่า input shape เป็น(batch, seq_len, d_model)- ผลลัพธ์ของ RMSNorm ทดสอบด้วยคุณสมบัติที่ว่า layer norm จะเท่ากับรากที่สองของจำนวน element ในเลเยอร์
assert- row-wise comparison
torch.allclose
SimpleModel_RMSที่เพิ่ม RMSNorm เข้าไปในโมเดลพื้นฐาน ลด validation loss ลงเล็กน้อยถึงราว 2.5015
องค์ประกอบของ Llama 2: RoPE และ causal mask
- RoPE เป็นวิธี positional encoding สำหรับ Transformer โดยแทนตำแหน่ง token ด้วยการหมุน embedding
get_rotary_matrixสร้างเมทริกซ์การหมุนตามตำแหน่งสำหรับ context window และ embedding dimension- การอิมพลีเมนต์ RoPE ทดสอบด้วยคุณสมบัติต่อไปนี้
- inner product ของเวกเตอร์สองตัวที่ถูกหมุนที่ตำแหน่ง
m,nต้องตรงกับการหมุนตามตำแหน่งสัมพัทธ์n-m
- inner product ของเวกเตอร์สองตัวที่ถูกหมุนที่ตำแหน่ง
RoPEAttentionHeadสร้างw_q,w_k,w_vใช้ RoPE rotation กับ query และ key แล้วใช้F.scaled_dot_product_attention- ต้องระวังความต่างของ tensor shape ระหว่างตอนเทรนกับตอน inference
- ตอนเทรนมักตรงกับค่าตั้งต้น เช่น
(config['batch_size'], config['context_window'], config['d_model']) - ตอน inference อาจต้องจัดการตัวอย่างเดียว เช่น
(1, 1, config['d_model']) - ภายใน
forwardควร index โดยอิง shape ที่ได้จากอินพุต ไม่ใช่ค่าการตั้งค่าของโมเดล
- ตอนเทรนมักตรงกับค่าตั้งต้น เช่น
- โมเดลที่เพิ่ม RoPE multi-head attention โดยไม่มี causal mask มี validation loss ลดฮวบถึง 0.1623 แต่ผลลัพธ์การ generate ไม่ดี เช่น
OOOO...,IIII... - เมื่อตรวจ attention map พบว่าทุกตำแหน่งอ้างอิงถึงทุกตำแหน่ง และเกิด ข้อมูลรั่วไหล จากการมองเห็น token ในอนาคตระหว่างทำนาย token ถัดไป
- เมื่อเปลี่ยนเป็น
RoPEMaskedAttentionHeadที่ใช้is_causal=TrueกับF.scaled_dot_product_attentionค่า upper triangular attention ที่สอดคล้องกับอนาคตแทบกลายเป็น 0 - หลังใช้ causal mask validation loss เป็น 2.0815 และเมื่อเทรนนานขึ้นก็ลดลงถึง 1.8985
องค์ประกอบของ Llama 3: SwiGLU และการซ้อนบล็อก
- Llama เปลี่ยน nonlinearity แบบ ReLU เป็น SwiGLU activation function
SwiGLUที่อิมพลีเมนต์คือ Swish-gated linear unit และใช้การแปลง linear สองชุดกับพารามิเตอร์betaที่เรียนรู้ได้- RopeModel ที่ใส่ SwiGLU ในส่วน feed-forward มีพารามิเตอร์ 592,706 ตัว และ validation loss อยู่ราว 1.8963
- ต่อมาสร้าง
LlamaBlockเพื่อรวมองค์ประกอบต่อไปนี้เป็นหนึ่งบล็อก- RMSNorm pre-normalization
- masked RoPE multi-head attention
- residual connection
- RMSNorm pre-normalization
- SwiGLU feed-forward
- residual connection
- โมเดล
Llamaสุดท้ายตั้งค่าn_layers=4และซ้อนLlamaBlock4 บล็อกด้วยnn.Sequentialที่อิงOrderedDict - โมเดลสุดท้ายมีพารามิเตอร์ 2,370,246 ตัว และผลการเทรนเป็นดังนี้
- หลังเทรน 4-layer ช่วงแรก validation loss 1.5532
- หลังเทรนเพิ่มเป็น 10,000 epochs validation loss 1.1479
- หลังเทรนเพิ่มเติม validation loss 0.9997
- loss ของหนึ่ง batch ใน test split คือ 1.2358
ผลลัพธ์การ generate และรายการตรวจสอบการดีบัก
- โมเดลสุดท้ายสร้างชื่อ การขึ้นบรรทัด และชิ้นส่วนคำที่คล้ายรูปแบบ Shakespeare ได้ แต่คุณภาพประโยคจริงยังจำกัด
- cross-entropy loss สามารถทำให้เข้าใจได้ในมุมมองการเลือก token
- loss เริ่มต้น 4.17 ใกล้เคียงกับการสุ่มเลือกจาก vocabulary size 65
- loss 1.08 ตีความได้ว่าเทียบเท่ากับการสุ่มเลือกจาก token ประมาณ 2.9 ตัว
- ตรวจสอบการไหลของ gradient ด้วยฟังก์ชัน
show_grads- คำนวณสัดส่วนของ gradient ที่มีค่าสัมบูรณ์เล็กในแต่ละพารามิเตอร์
- หาก gradient ของพารามิเตอร์ส่วนใหญ่ไม่ได้ใกล้ 0 แสดงว่าการไหลยังอยู่ในสภาพดี
- Llama ต้นฉบับใช้ learning schedule แบบ Cosine Annealing แต่ในการอิมพลีเมนต์นี้ ผลการทดลองแย่กว่า
- ในการทดลอง Cosine Annealing แม้ใช้ tolerance ต่ำมาก attention bias ก็แทบไม่ได้รับสัญญาณ และยังไม่แน่ชัดว่าเพราะเหตุใด ดังนั้นในการอิมพลีเมนต์จริง การเริ่มแบบเรียบง่ายจึงปลอดภัยกว่า
1 ความคิดเห็น
ความคิดเห็นบน Hacker News
ดูเหมือนว่า การใช้งาน SwiGLU มีบั๊ก: ในเปเปอร์อ้างอิง beta ของ feed-forward network เป็นค่าคงที่ ไม่ใช่ค่าที่เรียนรู้ได้ และกำหนดไว้เป็น
FFnSwiGLU = Swish1...อ้างอิงสมการที่ 6 ใน https://arxiv.org/pdf/2002.05202.pdf
ใน implementation อย่างเป็นทางการของ llama ก็มีการตัด beta แบบค่าคงที่ออกไปแล้ว: https://github.com/facebookresearch/llama/blob/main/llama/mo...
ดูจากบรรทัด
"feedforward.1.beta', 0.0"ใน log ของบล็อกแล้ว beta เสื่อมลงเป็น 0 ระหว่างการเทรน แต่เดิมควรเป็นค่าคงที่ 1เครือข่ายมักปรับตัวเข้ากับการเปลี่ยนแปลงได้ไม่ว่าจะตั้งใจหรือไม่ และหลังเทรนแล้วสถาปัตยกรรมหลายแบบที่ดัดแปลงมาก็อาจทำงานคล้ายกัน จึงมีบางกรณีที่ไม่ชัดเจนว่าจำเป็นต้องตรงกับต้นฉบับเป๊ะหรือไม่
วิธีหนึ่งในการหาข้อผิดพลาดแบบนี้คือ เทียบค่า output ให้ตรงกับ reference implementation แบบเป๊ะ ๆ แม้จะใช้ weight แบบสุ่มเหมือนโมเดล tiny-random ของ HuggingFace ค่า output ก็ต้องตรงกันทุกประการ และถ้าไม่ตรงก็เป็นสัญญาณว่ามีบั๊ก
อย่างไรก็ตาม วิธีนี้ใช้ได้ดีกับบั๊กที่เกิดระหว่าง inference เท่านั้น ส่วนปัญหาที่เกิดเฉพาะระหว่างการประมวลผลข้อมูล, optimizer หรือการเทรนจะจับได้ยากกว่า
ส่วนตัวคิดว่าเป็นเพราะคุณสมบัติแบบ autoregressive และคล้าย ODE แต่ก็ยังไม่มั่นใจถึงขั้นฟันธง
งานนี้ยอดเยี่ยม แต่
SimpleBrokenModelและSimpleModelช่วงแรกมี การคำนวณที่สูญเปล่า อยู่พอสมควร ลำดับคือembedding 65 -> 128,linear 128 -> 128,ReLU,linear 128 -> 65โดยระหว่างสองเลเยอร์แรกไม่มี non-linearity และทั้งคู่เป็น linear ดังนั้น linear layer ตัวที่สองจึงแทบไม่มีประโยชน์สุดท้ายโมเดลนี้เทียบได้กับ MLP แบบ hidden layer เดียวแบบคลาสสิก และถ้าคิดตาม FLOPS ก็เท่ากับสูญเสียการคำนวณ
128*128=16kจากทั้งหมด128*128+65*128=24kembedding layer เป็นโครงสร้างพิเศษที่แปลง token index เป็น embedding vector จึงน่าจะลบออกไม่ได้
โดยรวมแสดงหลักการพื้นฐานได้ดีมาก โดยเฉพาะประโยคที่ว่า “ใช้
.shapeอย่างเคร่งครัดราวกับศาสนาassertกับplt.imshowคือเพื่อนของคุณ” นั้นชอบมาก และควร assert precondition/postcondition ของ shape เสมอก็สงสัยเหมือนกันว่า
bearหรือtypeguardรองรับการตรวจแบบนี้ผ่าน decorator หรือไม่แต่ส่วนที่บอกว่า “เลือกโมเดลที่เล็ก เรียบง่าย และเร็ว แล้วทำ helper สำหรับประเมินเชิงคุณภาพ” ผมคิดว่าน่าจะหมายถึง การประเมินเชิงปริมาณ มากกว่า เพราะจะได้มี baseline เป็นตัวเลขสำหรับเทียบกับเทคนิคขั้นสูงกว่า
คำแนะนำให้ implement องค์ประกอบของเปเปอร์ทีละอย่างก็ควรแม่นยำกว่านี้ด้วย โดยปกติเปเปอร์มักลองเปลี่ยนหลายอย่างพร้อมกัน แล้วใช้ ablation study แสดง contribution ของแต่ละองค์ประกอบ ดังนั้นผมคิดว่าควรเริ่มจากการเปลี่ยนสถาปัตยกรรมหลัก แล้วทำตามลำดับผลกระทบที่มากใน ablation study โดยรักษา dependency และประเมินทุกการเปลี่ยนแปลงแบบ atomic จะดีกว่า
bearหรือtypeguardบางส่วนสามารถยัดเข้าไปตรง ๆ ใน type annotation ของ Python ได้ด้วย https://peps.python.org/pep-0646/เช่น แสดง shape รายแกนไว้ใน type ด้วยรูปแบบอย่าง
ndarray[float, Dim1, *Shape]และ overload shape ที่คืนค่าตามค่าaxisได้bear/typeguardถึงอย่างนั้น Python ก็ดูยากที่จะดีเท่า Julia ระบบ type ของ Julia ทำให้รับประกันว่าขนาดเมทริกซ์เข้ากันได้ง่ายกว่ามาก
อยากรู้ว่า หลักการในการใช้ SwiGLU แทน ReLU คืออะไร ไม่แน่ใจว่าผู้เขียนแค่ลอง non-linear function ที่เป็นไปได้ทั้งหมด หรือมีเหตุผลที่ลึกกว่านั้น
bearblog กำลังโดน DDoS อยู่ จึงฝาก repository ไว้: https://github.com/bkitano/llama-from-scratch
ในฐานะคนที่กำลังเรียน AI เลยลองสรุปคำศัพท์ที่อยู่ในบทความแบบคร่าว ๆ ดู โทเค็นคือ identifier แบบจำนวนเต็มที่แทนชิ้นส่วนของข้อความ และใน LLM จะใช้โดยจับกลุ่มชิ้นส่วนตัวอักษรที่พบบ่อยภายในขนาด vocabulary ที่จำกัด
loss function คือค่าที่วัดความต่างระหว่างผลทำนายกับคำตอบที่ถูกต้อง ยิ่งต่ำยิ่งดี PyTorch เป็นไลบรารีสำหรับจัดการ tensor และ neural network ส่วน tensor คืออาร์เรย์ตัวเลขหลายมิติที่ครอบคลุมทั้ง scalar, vector และ matrix
neural network คือโครงสร้างการเชื่อมต่อของ neuron ที่มี weight และ bias ส่วน linear layer คือโครงสร้างเรียบง่ายที่ input และ output ทั้งหมดเชื่อมถึงกัน ReLU เป็น activation function แบบ
Math.max(0, x)และเพราะถ้าวางซ้อนแต่ linear layer สุดท้ายก็จะเทียบเท่ากับ linear function ตัวเดียว จึงต้องใส่ nonlinearity เพื่อเพิ่มความสามารถในการเรียนรู้gradient คือค่าการเปลี่ยนแปลงเชิงตัวเลขที่คำนวณระหว่างการเรียนรู้เพื่อทำให้โมเดลแม่นยำขึ้น ส่วน batch normalization เป็นวิธีช่วยการเรียนรู้ด้วยการปรับตัวเลขที่ไหลผ่านไปมา positional encoding บอกตำแหน่งสัมพัทธ์ของ token ต่าง ๆ ในรูป vector
operator
@ของ Python เป็น alias ของ__matmul__ใช้สำหรับการคูณ matrix epoch คือการเรียนรู้จาก dataset ทั้งหมดครบหนึ่งรอบ และ batch คือจำนวนข้อมูลที่ใส่เข้าไปพร้อมกันก่อนอัปเดต parameterattention เป็นแกนหลักที่ทำให้ LLM ทำงาน โดยประมวลผล input token แบบขนานเพื่อสร้าง tensor ระหว่างทาง แล้วนำไปใช้สร้าง output token
เช่น
writที่พบร่วมกันในwriting,written,writerอาจกลายเป็น token หนึ่งได้ และwriterอาจถูก tokenize เป็นwritกับerembedding คือขั้นตอนที่แปลง token เหล่านี้ให้เป็น representation เชิงตัวเลขเฉพาะตัว
ถ้ามี implementation เดิมของโมเดลกับ checkpoint อยู่ วิธีที่มีประสิทธิภาพที่สุดในการตรวจว่าการ implement ของตัวเองถูกต้องหรือไม่ คือ โหลด checkpoint นั้นมาแล้วเปรียบเทียบ output
ถ้า output ไม่ตรง ส่วนใหญ่แปลว่า implement รายละเอียดบางอย่างผิด และสามารถไล่ตรวจแต่ละ layer อย่างเป็นระบบเพื่อหาความต่างจริง ๆ ได้ ระหว่างนั้นอาจเจอจุดแปลก ๆ ใน implementation เดิมด้วย
เรื่องนี้เป็นเรื่องของตัวโมเดลเอง ส่วนการฝึกเป็นอีกแกนหนึ่งต่างหาก ถึงอย่างนั้นถ้าปรับ hyperparameter ให้ใกล้เคียงกันพอสมควร เมื่อ implementation ของโมเดลถูกต้อง โดยรวมมักจะออกมาโอเค
ทั้งวิธีอ่าน paper และเนื้อหาของ paper นั้นดีมาก และขอแนะนำ Makemore series ของ Karpathy ด้วย
คำแนะนำสรุปดีมาก และผมคิดว่าคำแนะนำให้ assert shape ของ tensor ใช้ได้กับไลบรารี linear algebra ทั่วไปทุกตัว เวลาเขียนโค้ด linear algebra ที่ซับซ้อน การไปทีละขั้นเล็ก ๆ และเขียนโค้ดแบบป้องกันตัวเองเป็นเรื่องสำคัญมาก
การเขียนโปรแกรม linear algebra ในภาษากระแสหลักเป็นเรื่องแย่มาก เพราะไม่มี การตรวจ shape ตอน compile time shape ของ tensor ควรเป็นส่วนหนึ่งของ type และถ้าพยายามคูณ
3x4กับ3x4โดยไม่ transpose ก็ควร compile ไม่ผ่านตั้งแต่แรกการรันการคำนวณยาว ๆ แล้วไปพังที่ operation ที่มิติไม่ตรงกันนี่เลวร้ายจริง ๆ
ผมคิดว่า tensor ของ PyTorch ก็ควรมีการกำหนด type ของ device แบบ static ด้วย ตอนนี้ถ้าพยายามคูณ tensor ในหน่วยความจำ CPU กับ tensor ในหน่วยความจำ GPU จะเกิด runtime error