กระบวนการพัฒนา LLM ตอนที่ 8 - เทคโนโลยี self-attention ที่เรียนรู้ได้
(gilesthomas.com)- ใน Transformer แบบ decoder-only สไตล์ GPT นั้น self-attention ที่เรียนรู้ได้ จะคำนวณว่าแต่ละโทเค็นควรให้ความสนใจกับโทเค็นใดในอินพุตก่อนหน้า เพื่อสร้างเวกเตอร์บริบท
- แกนสำคัญคือ scaled dot product attention ซึ่งใช้เมทริกซ์ที่เรียนรู้ได้ 3 ตัว
Wq,Wk,Wvสำหรับส่งอินพุตเอมเบดดิงไปยังพื้นที่ query, key และ value - เมทริกซ์อินพุต
Xจะถูกแปลงเป็นQ=XWq,K=XWk,V=XWvจากนั้นนำΩ=QKᵀมาหารด้วย√cแล้วใช้ softmax แบบรายแถวเพื่อให้ได้ค่าน้ำหนัก attentionA - เวกเตอร์บริบทถูกสร้างขึ้นด้วยการคูณเมทริกซ์ครั้งเดียวเป็น
C=AVและการคำนวณทั้งหมดสามารถใช้กับทุกโทเค็นได้ด้วย การคูณเมทริกซ์ 5 ครั้ง และการ transpose 1 ครั้ง - ขั้นตอนนี้ก้าวต่อจากตัวอย่างแบบง่ายที่ทำ dot product ระหว่างอินพุตเอมเบดดิงโดยตรง ไปสู่ attention ที่ฝึกได้ ซึ่งสามารถเขียนด้วย PyTorch
nn.Moduleและnn.Linear
ตำแหน่งของ self-attention ในลำดับการประมวลผลของ LLM
- LLM ที่อิง Transformer แบบ decoder-only สไตล์ GPT มีโครงสร้างที่มองโทเค็นทั้งหมดจนถึงปัจจุบันแล้วทำนายโทเค็นถัดไป
- ลำดับการประมวลผลคือแบ่งสตริงเป็นโทเค็น แปลงแต่ละโทเค็นเป็น token embedding แล้วบวก positional embedding ที่แทนข้อมูลตำแหน่งเพื่อสร้างอินพุตเอมเบดดิง
- self-attention จะสร้างรายการ คะแนน attention ที่บอกว่าสำหรับแต่ละอินพุตเอมเบดดิงควรสนใจโทเค็นอื่นมากน้อยเพียงใด
- ในประโยคตัวอย่าง
"the fat cat sat on the mat"เมื่อดู"cat"นั้น"fat"อาจมีความสำคัญ - แต่เมื่อดู
"mat"ความสำคัญของ"fat"อาจลดลงเมื่อเทียบกัน
- ในประโยคตัวอย่าง
- คะแนน attention จะผ่าน softmax กลายเป็น ค่าน้ำหนัก attention ที่มีผลรวมเท่ากับ 1 แล้วใช้ค่าน้ำหนักนี้ทำผลรวมถ่วงน้ำหนักของอินพุตเอมเบดดิงเพื่อสร้างเวกเตอร์บริบท
- เวกเตอร์บริบทถูกใช้เป็นเวกเตอร์ที่แทนความหมายของแต่ละโทเค็นภายในบริบทของอินพุตทั้งหมด
เป้าหมายของ self-attention ที่เรียนรู้ได้
- ในขั้นก่อนหน้านี้ใช้ self-attention แบบตัวอย่างง่าย ที่คำนวณ dot product ระหว่างอินพุตเอมเบดดิงโดยตรง
- เป้าหมายของขั้นนี้คือการสร้าง กลไก attention ที่เรียนรู้ได้ ซึ่งสามารถสร้างคะแนน attention จากเวกเตอร์อินพุตได้
- หัวข้อ 3.4 ของ Build a Large Language Model (from Scratch) โดย Sebastian Raschka นำสิ่งนี้ไปใช้งานด้วย scaled dot product attention
- จุดเน้นไม่ได้อยู่ที่เหตุผลว่าทำไมโครงสร้างนี้จึงมีประสิทธิภาพ แต่เน้นว่ามันทำงานผ่านการคำนวณแบบใด
เมทริกซ์ Query, Key, Value และการฉายไปยังสเปซ
- กำหนดให้ความยาวลำดับอินพุตเป็น
n, มิติของอินพุตเอมเบดดิงเป็นdและมิติของเวกเตอร์บริบทเป็นc - ลำดับอินพุตเอมเบดดิงแสดงเป็น
x1, x2, x3, ... xnและอินพุตเอมเบดดิงแต่ละตัวเป็นเวกเตอร์มิติd - กำหนดเมทริกซ์น้ำหนักที่เรียนรู้ได้ 3 ตัว
- query weights matrix:
Wq - key weights matrix:
Wk - value weights matrix:
Wv
- query weights matrix:
- แต่ละเมทริกซ์มีขนาด
d×cและใช้ฉายเวกเตอร์อินพุตมิติdไปยังสเปซมิติc - การส่งเวกเตอร์อินพุต
xmไปยังสเปซ query คำนวณได้เป็นqm=xmWq - สเปซ key และ value ก็ฉายอินพุตเอมเบดดิงไปยังสเปซมิติ
cคนละแบบด้วยวิธีเดียวกัน
วิธีมองเมทริกซ์ในฐานะการฉาย
- เมทริกซ์สามารถใช้ทำ การแปลงเชิงเรขาคณิต เช่น การหมุนจุด
- เมทริกซ์จัตุรัสทำการแปลงภายในมิติเดิม ส่วนเมทริกซ์ไม่จัตุรัสสามารถส่งเวกเตอร์ไปยังสเปซที่มีมิติแตกต่างออกไปได้
- ตัวอย่างเช่น เมทริกซ์
3×2สามารถแปลงเวกเตอร์ 3 มิติให้เป็นเวกเตอร์ 2 มิติได้ - ในกราฟิก 3D เมทริกซ์ frustum ที่ใช้แปลงจุด 3D ไปเป็นจุดบนหน้าจอ 2D ก็เป็นตัวอย่างของการฉายลักษณะนี้
- self-attention จะส่งอินพุตเอมเบดดิงไปยัง สเปซการฉายที่แตกต่างกัน 3 แบบ คือ query, key และ value ก่อนนำเวกเตอร์ที่ฉายแล้วไปคำนวณต่อ
- เนื่องจากเมทริกซ์ฉายเหล่านี้ถูกเรียนรู้ระหว่างการฝึก จึงเกิด ความเป็นทางอ้อม ที่ไม่มีใน dot product attention แบบตรงไปตรงมา
การคำนวณคะแนน attention
- เมื่อพิจารณาอินพุต
xmคะแนน attention สำหรับอินพุตอื่นxpจะนิยามเป็น dot product ของการฉายแบบ query และการฉายแบบ key - สูตรคำนวณมีดังนี้
qm=xmWqkp=xpWkωm,p=qm·kp
- สามารถคำนวณสิ่งนี้ด้วยลูปสำหรับทุกอินพุตได้ แต่ถ้าใช้การคูณเมทริกซ์จะคำนวณได้ทั้งหมดในครั้งเดียว
- หากแทนอินพุตเอมเบดดิงทั้งหมดด้วยเมทริกซ์
Xเมทริกซ์นี้จะมีขนาดn×d - เมทริกซ์ key คำนวณได้รวดเดียวเป็น
K=XWk- ผลลัพธ์
Kมีขนาดn×c - แต่ละแถวคือเวกเตอร์ที่ได้จากการฉายอินพุตเอมเบดดิงนั้นไปยังสเปซ key
- ผลลัพธ์
- เมทริกซ์ query ก็คำนวณแบบเดียวกันเป็น
Q=XWq - dot product ระหว่างทุก query กับทุก key ได้จาก
QKᵀQมีขนาดn×cKᵀมีขนาดc×n- ผลลัพธ์
Ωมีขนาดn×n
Ωm,pคือ คะแนน attention ที่บอกว่าขณะสร้างเวกเตอร์บริบทของxmควรให้ความสนใจกับxpมากเพียงใด
การสเกลและการทำให้เป็นมาตรฐานด้วย softmax
- คะแนน attention จะผ่าน softmax เช่นเดียวกับตัวอย่างก่อนหน้า เพื่อเปลี่ยนเป็นค่าน้ำหนักที่มีผลรวมเท่ากับ 1
- softmax จะขยายค่าที่มากและลดค่าที่น้อยลง พร้อมทั้งปรับให้ผลรวมของทั้งรายการเท่ากับ 1
- ใน LLM จริง
dและcอาจมีขนาดระดับหลักพัน ดังนั้นหากใช้ softmax ตรง ๆ อาจทำให้เกิด gradient ที่เล็กมาก - ในกรณีนี้ softmax อาจทำงานคล้าย “step function”
- ตีความได้ว่า ค่าที่มากที่สุดจะครอบงำ ส่วนค่าอื่น ๆ จะเล็กมาก
- เพื่อบรรเทาปัญหานี้ จึงนำคะแนน attention มาหารด้วยรากที่สองของมิติของสเปซฉาย
cก่อนค่อยใช้ softmax - ในรูปเมทริกซ์คือ
A=softmax(Ω/√c, axis=1)
axis=1เป็นสัญกรณ์แบบ PyTorch หมายถึงใช้ softmax แบบรายแถว- ผลลัพธ์
Aคือคะแนน attention ที่ผ่านการทำให้เป็นมาตรฐานแล้ว หรือก็คือ เมทริกซ์ค่าน้ำหนัก attention
การสร้างเวกเตอร์บริบท
- การฉายไปยังสเปซ value คำนวณได้เป็น
V=XWv Aคือเมทริกซ์ค่าน้ำหนัก attention ขนาดn×nAm,pคือค่าน้ำหนัก attention ที่ใช้กับอินพุตpขณะสร้างเวกเตอร์บริบทของxm
Vมีขนาดn×cและแต่ละแถวคือเวกเตอร์ที่ได้จากการฉายอินพุตเอมเบดดิงไปยังสเปซ value- เมทริกซ์เวกเตอร์บริบทคำนวณได้เป็น
C=AV- ผลลัพธ์
Cมีขนาดn×c - แถวที่
mของCคือเวกเตอร์บริบทสำหรับอินพุตxm
- ผลลัพธ์
- การคำนวณนี้ทำงานเทียบเท่ากับการนำเวกเตอร์ value ของแต่ละโทเค็นมาคูณด้วยค่าน้ำหนัก attention แล้วบวกเข้าด้วยกันสำหรับแต่ละโทเค็น แต่ทำได้ด้วยการคูณเมทริกซ์เพียงครั้งเดียว
สรุปการคำนวณทั้งหมด
- เมทริกซ์อินพุต
Xเก็บอินพุตเอมเบดดิงของลำดับโทเค็น และมีขนาดn×d - ใช้เมทริกซ์ที่เรียนรู้ได้ 3 ตัวเพื่อฉายอินพุตไปยังสเปซ query, key และ value ตามลำดับ
Q=XWqK=XWkV=XWv
- คำนวณคะแนน attention จาก dot product ของ query และ key
Ω=QKᵀ
- สเกลคะแนนแล้วใช้ softmax แบบรายแถวเพื่อสร้างค่าน้ำหนัก attention
A=softmax(Ω/√c, axis=1)
- คูณการฉายแบบ value กับค่าน้ำหนัก attention เพื่อสร้างเวกเตอร์บริบท
C=AV
- กลไก self-attention ทั้งหมดสามารถสร้างเวกเตอร์บริบทของโทเค็นอินพุตทั้งหมดได้ด้วย การคูณเมทริกซ์ 5 ครั้ง และการ transpose 1 ครั้ง
การใช้งานด้วย PyTorch และขั้นตอนถัดไป
- หัวข้อ 3.4 ของหนังสือได้นำการคำนวณข้างต้นไปเขียนเป็นโค้ด PyTorch และสร้างซับคลาส
nn.Moduleแบบง่ายที่ทำการคำนวณเมทริกซ์ชุดเดียวกัน - เวอร์ชันแรกใช้วัตถุ
nn.Parameterทั่วไปสำหรับเมทริกซ์น้ำหนักทั้งสาม - เวอร์ชันที่สองใช้
nn.Linearเพื่อให้การฝึกมีประสิทธิภาพมากขึ้น - หลังจากนี้จะมี 2 หัวข้อถัดไป
- causal self-attention: วิธีที่เมื่อดูโทเค็นหนึ่ง จะไม่สนใจโทเค็นที่อยู่ถัดไปในอนาคต
- multi-head attention: เกริ่นว่าเป็นหัวข้อที่ไม่ได้ซับซ้อนอย่างที่คิดไว้ตอนแรก
- การประมวลผลแบบแบตช์ยังเป็นประเด็นที่ต้องพิจารณาแยกต่างหาก
- แม้จะเป็นลำดับอินพุตเดี่ยว ก็ยังใช้เมทริกซ์คะแนน attention อยู่แล้ว
- หากต้องประมวลผลหลายลำดับอินพุตแบบขนาน อาจต้องใช้เทนเซอร์ที่มีอันดับสูงกว่าเมทริกซ์
- บทความถัดไปจะต่อด้วย Writing an LLM from scratch, part 9 -- causal attention
ยังไม่มีความคิดเห็น