• ใน Transformer แบบ decoder-only สไตล์ GPT นั้น self-attention ที่เรียนรู้ได้ จะคำนวณว่าแต่ละโทเค็นควรให้ความสนใจกับโทเค็นใดในอินพุตก่อนหน้า เพื่อสร้างเวกเตอร์บริบท
  • แกนสำคัญคือ scaled dot product attention ซึ่งใช้เมทริกซ์ที่เรียนรู้ได้ 3 ตัว Wq, Wk, Wv สำหรับส่งอินพุตเอมเบดดิงไปยังพื้นที่ query, key และ value
  • เมทริกซ์อินพุต X จะถูกแปลงเป็น Q=XWq, K=XWk, V=XWv จากนั้นนำ Ω=QKᵀ มาหารด้วย √c แล้วใช้ softmax แบบรายแถวเพื่อให้ได้ค่าน้ำหนัก attention A
  • เวกเตอร์บริบทถูกสร้างขึ้นด้วยการคูณเมทริกซ์ครั้งเดียวเป็น C=AV และการคำนวณทั้งหมดสามารถใช้กับทุกโทเค็นได้ด้วย การคูณเมทริกซ์ 5 ครั้ง และการ transpose 1 ครั้ง
  • ขั้นตอนนี้ก้าวต่อจากตัวอย่างแบบง่ายที่ทำ dot product ระหว่างอินพุตเอมเบดดิงโดยตรง ไปสู่ attention ที่ฝึกได้ ซึ่งสามารถเขียนด้วย PyTorch nn.Module และ nn.Linear

ตำแหน่งของ self-attention ในลำดับการประมวลผลของ LLM

  • LLM ที่อิง Transformer แบบ decoder-only สไตล์ GPT มีโครงสร้างที่มองโทเค็นทั้งหมดจนถึงปัจจุบันแล้วทำนายโทเค็นถัดไป
  • ลำดับการประมวลผลคือแบ่งสตริงเป็นโทเค็น แปลงแต่ละโทเค็นเป็น token embedding แล้วบวก positional embedding ที่แทนข้อมูลตำแหน่งเพื่อสร้างอินพุตเอมเบดดิง
  • self-attention จะสร้างรายการ คะแนน attention ที่บอกว่าสำหรับแต่ละอินพุตเอมเบดดิงควรสนใจโทเค็นอื่นมากน้อยเพียงใด
    • ในประโยคตัวอย่าง "the fat cat sat on the mat" เมื่อดู "cat" นั้น "fat" อาจมีความสำคัญ
    • แต่เมื่อดู "mat" ความสำคัญของ "fat" อาจลดลงเมื่อเทียบกัน
  • คะแนน attention จะผ่าน softmax กลายเป็น ค่าน้ำหนัก attention ที่มีผลรวมเท่ากับ 1 แล้วใช้ค่าน้ำหนักนี้ทำผลรวมถ่วงน้ำหนักของอินพุตเอมเบดดิงเพื่อสร้างเวกเตอร์บริบท
  • เวกเตอร์บริบทถูกใช้เป็นเวกเตอร์ที่แทนความหมายของแต่ละโทเค็นภายในบริบทของอินพุตทั้งหมด

เป้าหมายของ self-attention ที่เรียนรู้ได้

  • ในขั้นก่อนหน้านี้ใช้ self-attention แบบตัวอย่างง่าย ที่คำนวณ dot product ระหว่างอินพุตเอมเบดดิงโดยตรง
  • เป้าหมายของขั้นนี้คือการสร้าง กลไก attention ที่เรียนรู้ได้ ซึ่งสามารถสร้างคะแนน attention จากเวกเตอร์อินพุตได้
  • หัวข้อ 3.4 ของ Build a Large Language Model (from Scratch) โดย Sebastian Raschka นำสิ่งนี้ไปใช้งานด้วย scaled dot product attention
  • จุดเน้นไม่ได้อยู่ที่เหตุผลว่าทำไมโครงสร้างนี้จึงมีประสิทธิภาพ แต่เน้นว่ามันทำงานผ่านการคำนวณแบบใด

เมทริกซ์ Query, Key, Value และการฉายไปยังสเปซ

  • กำหนดให้ความยาวลำดับอินพุตเป็น n, มิติของอินพุตเอมเบดดิงเป็น d และมิติของเวกเตอร์บริบทเป็น c
  • ลำดับอินพุตเอมเบดดิงแสดงเป็น x1, x2, x3, ... xn และอินพุตเอมเบดดิงแต่ละตัวเป็นเวกเตอร์มิติ d
  • กำหนดเมทริกซ์น้ำหนักที่เรียนรู้ได้ 3 ตัว
    • query weights matrix: Wq
    • key weights matrix: Wk
    • value weights matrix: Wv
  • แต่ละเมทริกซ์มีขนาด d×c และใช้ฉายเวกเตอร์อินพุตมิติ d ไปยังสเปซมิติ c
  • การส่งเวกเตอร์อินพุต xm ไปยังสเปซ query คำนวณได้เป็น qm=xmWq
  • สเปซ key และ value ก็ฉายอินพุตเอมเบดดิงไปยังสเปซมิติ c คนละแบบด้วยวิธีเดียวกัน

วิธีมองเมทริกซ์ในฐานะการฉาย

  • เมทริกซ์สามารถใช้ทำ การแปลงเชิงเรขาคณิต เช่น การหมุนจุด
  • เมทริกซ์จัตุรัสทำการแปลงภายในมิติเดิม ส่วนเมทริกซ์ไม่จัตุรัสสามารถส่งเวกเตอร์ไปยังสเปซที่มีมิติแตกต่างออกไปได้
  • ตัวอย่างเช่น เมทริกซ์ 3×2 สามารถแปลงเวกเตอร์ 3 มิติให้เป็นเวกเตอร์ 2 มิติได้
  • ในกราฟิก 3D เมทริกซ์ frustum ที่ใช้แปลงจุด 3D ไปเป็นจุดบนหน้าจอ 2D ก็เป็นตัวอย่างของการฉายลักษณะนี้
  • self-attention จะส่งอินพุตเอมเบดดิงไปยัง สเปซการฉายที่แตกต่างกัน 3 แบบ คือ query, key และ value ก่อนนำเวกเตอร์ที่ฉายแล้วไปคำนวณต่อ
  • เนื่องจากเมทริกซ์ฉายเหล่านี้ถูกเรียนรู้ระหว่างการฝึก จึงเกิด ความเป็นทางอ้อม ที่ไม่มีใน dot product attention แบบตรงไปตรงมา

การคำนวณคะแนน attention

  • เมื่อพิจารณาอินพุต xm คะแนน attention สำหรับอินพุตอื่น xp จะนิยามเป็น dot product ของการฉายแบบ query และการฉายแบบ key
  • สูตรคำนวณมีดังนี้
    • qm=xmWq
    • kp=xpWk
    • ωm,p=qm·kp
  • สามารถคำนวณสิ่งนี้ด้วยลูปสำหรับทุกอินพุตได้ แต่ถ้าใช้การคูณเมทริกซ์จะคำนวณได้ทั้งหมดในครั้งเดียว
  • หากแทนอินพุตเอมเบดดิงทั้งหมดด้วยเมทริกซ์ X เมทริกซ์นี้จะมีขนาด n×d
  • เมทริกซ์ key คำนวณได้รวดเดียวเป็น K=XWk
    • ผลลัพธ์ K มีขนาด n×c
    • แต่ละแถวคือเวกเตอร์ที่ได้จากการฉายอินพุตเอมเบดดิงนั้นไปยังสเปซ key
  • เมทริกซ์ query ก็คำนวณแบบเดียวกันเป็น Q=XWq
  • dot product ระหว่างทุก query กับทุก key ได้จาก QKᵀ
    • Q มีขนาด n×c
    • Kᵀ มีขนาด c×n
    • ผลลัพธ์ Ω มีขนาด n×n
  • Ωm,p คือ คะแนน attention ที่บอกว่าขณะสร้างเวกเตอร์บริบทของ xm ควรให้ความสนใจกับ xp มากเพียงใด

การสเกลและการทำให้เป็นมาตรฐานด้วย softmax

  • คะแนน attention จะผ่าน softmax เช่นเดียวกับตัวอย่างก่อนหน้า เพื่อเปลี่ยนเป็นค่าน้ำหนักที่มีผลรวมเท่ากับ 1
  • softmax จะขยายค่าที่มากและลดค่าที่น้อยลง พร้อมทั้งปรับให้ผลรวมของทั้งรายการเท่ากับ 1
  • ใน LLM จริง d และ c อาจมีขนาดระดับหลักพัน ดังนั้นหากใช้ softmax ตรง ๆ อาจทำให้เกิด gradient ที่เล็กมาก
  • ในกรณีนี้ softmax อาจทำงานคล้าย “step function”
    • ตีความได้ว่า ค่าที่มากที่สุดจะครอบงำ ส่วนค่าอื่น ๆ จะเล็กมาก
  • เพื่อบรรเทาปัญหานี้ จึงนำคะแนน attention มาหารด้วยรากที่สองของมิติของสเปซฉาย c ก่อนค่อยใช้ softmax
  • ในรูปเมทริกซ์คือ
    • A=softmax(Ω/√c, axis=1)
  • axis=1 เป็นสัญกรณ์แบบ PyTorch หมายถึงใช้ softmax แบบรายแถว
  • ผลลัพธ์ A คือคะแนน attention ที่ผ่านการทำให้เป็นมาตรฐานแล้ว หรือก็คือ เมทริกซ์ค่าน้ำหนัก attention

การสร้างเวกเตอร์บริบท

  • การฉายไปยังสเปซ value คำนวณได้เป็น V=XWv
  • A คือเมทริกซ์ค่าน้ำหนัก attention ขนาด n×n
    • Am,p คือค่าน้ำหนัก attention ที่ใช้กับอินพุต p ขณะสร้างเวกเตอร์บริบทของ xm
  • V มีขนาด n×c และแต่ละแถวคือเวกเตอร์ที่ได้จากการฉายอินพุตเอมเบดดิงไปยังสเปซ value
  • เมทริกซ์เวกเตอร์บริบทคำนวณได้เป็น C=AV
    • ผลลัพธ์ C มีขนาด n×c
    • แถวที่ m ของ C คือเวกเตอร์บริบทสำหรับอินพุต xm
  • การคำนวณนี้ทำงานเทียบเท่ากับการนำเวกเตอร์ value ของแต่ละโทเค็นมาคูณด้วยค่าน้ำหนัก attention แล้วบวกเข้าด้วยกันสำหรับแต่ละโทเค็น แต่ทำได้ด้วยการคูณเมทริกซ์เพียงครั้งเดียว

สรุปการคำนวณทั้งหมด

  • เมทริกซ์อินพุต X เก็บอินพุตเอมเบดดิงของลำดับโทเค็น และมีขนาด n×d
  • ใช้เมทริกซ์ที่เรียนรู้ได้ 3 ตัวเพื่อฉายอินพุตไปยังสเปซ query, key และ value ตามลำดับ
    • Q=XWq
    • K=XWk
    • V=XWv
  • คำนวณคะแนน attention จาก dot product ของ query และ key
    • Ω=QKᵀ
  • สเกลคะแนนแล้วใช้ softmax แบบรายแถวเพื่อสร้างค่าน้ำหนัก attention
    • A=softmax(Ω/√c, axis=1)
  • คูณการฉายแบบ value กับค่าน้ำหนัก attention เพื่อสร้างเวกเตอร์บริบท
    • C=AV
  • กลไก self-attention ทั้งหมดสามารถสร้างเวกเตอร์บริบทของโทเค็นอินพุตทั้งหมดได้ด้วย การคูณเมทริกซ์ 5 ครั้ง และการ transpose 1 ครั้ง

การใช้งานด้วย PyTorch และขั้นตอนถัดไป

  • หัวข้อ 3.4 ของหนังสือได้นำการคำนวณข้างต้นไปเขียนเป็นโค้ด PyTorch และสร้างซับคลาส nn.Module แบบง่ายที่ทำการคำนวณเมทริกซ์ชุดเดียวกัน
  • เวอร์ชันแรกใช้วัตถุ nn.Parameter ทั่วไปสำหรับเมทริกซ์น้ำหนักทั้งสาม
  • เวอร์ชันที่สองใช้ nn.Linear เพื่อให้การฝึกมีประสิทธิภาพมากขึ้น
  • หลังจากนี้จะมี 2 หัวข้อถัดไป
    • causal self-attention: วิธีที่เมื่อดูโทเค็นหนึ่ง จะไม่สนใจโทเค็นที่อยู่ถัดไปในอนาคต
    • multi-head attention: เกริ่นว่าเป็นหัวข้อที่ไม่ได้ซับซ้อนอย่างที่คิดไว้ตอนแรก
  • การประมวลผลแบบแบตช์ยังเป็นประเด็นที่ต้องพิจารณาแยกต่างหาก
    • แม้จะเป็นลำดับอินพุตเดี่ยว ก็ยังใช้เมทริกซ์คะแนน attention อยู่แล้ว
    • หากต้องประมวลผลหลายลำดับอินพุตแบบขนาน อาจต้องใช้เทนเซอร์ที่มีอันดับสูงกว่าเมทริกซ์
  • บทความถัดไปจะต่อด้วย Writing an LLM from scratch, part 9 -- causal attention

ยังไม่มีความคิดเห็น

ยังไม่มีความคิดเห็น