• สำหรับ งานเฉพาะทาง ที่ LLM อเนกประสงค์อาจเกินความจำเป็น การ fine-tune Llama-2 เองสามารถปรับปรุงคุณภาพ ต้นทุน และ latency พร้อมกันได้ด้วยโมเดลที่เล็กและถูกกว่า
  • หลัง fine-tuning แล้ว Llama-2 13B มีความแม่นยำด้าน การแทนค่าฟังก์ชันของ ViGGO เพิ่มจาก 58%→98%, การสร้าง SQL จาก 42%→89%, และ GSM8k จาก 28%→47%
  • ในงานที่รูปแบบผลลัพธ์สำคัญ เช่น ViGGO และการสร้าง SQL โมเดล Llama-2 ขนาดเล็กให้ผลลัพธ์ดีกว่า GPT-4 แต่ในด้านการให้เหตุผลทางคณิตศาสตร์ยังไม่ถึงระดับ GPT-4
  • การทดลองใช้สคริปต์บน Ray Train, Ray Data, DeepSpeed และ Accelerate โดยฝึก 7B·13B บน 16xA10G และ 70B บน 32xA10G
  • หัวใจของการเพิ่มประสิทธิภาพไม่ใช่ขนาดโมเดล แต่เป็น คุณภาพข้อมูลและ pipeline การประเมินผล และต้องเปรียบเทียบ trade-off ด้านต้นทุน·คุณภาพระหว่าง prompt engineering กับ fine-tuning ตามแต่ละงาน

ผลของ fine-tuning ที่เห็นจากสามงาน

  • โมเดลอเนกประสงค์ขนาดใหญ่ เช่น GPT-4 และ Claude-2 มีประโยชน์สำหรับการทำ prototype อย่างรวดเร็ว แต่สำหรับความต้องการที่มีขอบเขตแคบ เช่น การสรุปและจัดหมวดหมู่ support ticket อาจเกินความจำเป็นทั้งในแง่ต้นทุนและประสิทธิภาพ
  • การทดลองเปรียบเทียบระดับการปรับปรุงเมื่อทำ full-parameter fine-tuning โมเดล Llama-2 ให้เหมาะกับงานแบบใช้งานจริงสามประเภท
    • ViGGO: ดึงการแทนค่าแบบฟังก์ชันจากข้อความไม่มีโครงสร้าง
    • SQL-create-context: สร้าง SQL จากภาษาธรรมชาติและบริบท CREATE TABLE
    • GSM8k: แก้โจทย์คณิตศาสตร์ระดับประถม
  • การเปลี่ยนแปลงความแม่นยำสำหรับ Llama-2 13B เป็นดังนี้
    • การแทนค่าฟังก์ชันของ ViGGO: 58% → 98%
    • การสร้าง SQL: 42% → 89%
    • GSM8k: 28% → 47%
  • ใน ViGGO และการสร้าง SQL โมเดล Llama-2 ขนาดเล็กให้ผลลัพธ์ดีกว่า GPT-4 ส่วนงานให้เหตุผลทางคณิตศาสตร์อย่าง GSM8k แม้หลัง fine-tuning แล้วก็ยังไม่ถึงประสิทธิภาพของ GPT-4

วิธี fine-tuning และโครงสร้างพื้นฐานการฝึก

  • ทั้งสามงานใช้ full-parameter fine-tuning แบบมาตรฐาน
    • ฝึกด้วยวิธีทำนาย token ถัดไป
    • พารามิเตอร์ทั้งหมดของโมเดลเป็นเป้าหมายของการอัปเดต gradient
    • LoRA หรือวิธีตรึง transformer block บางส่วนถูกตัดออกจากขอบเขตการทดลอง
  • สคริปต์การทดลองสร้างบน Ray Train, Ray Data, DeepSpeed, Accelerate
    • รองรับการรัน Llama-2 7B, 13B, 70B
    • TorchTrainer ของ Ray Train กระจาย training loop ไปยังหลาย worker process และทรัพยากร GPU
    • Ray Train จัดการ data sharding และแต่ละ worker เข้าถึงชิ้นข้อมูลที่ถูกจัดสรรผ่าน session.get_dataset_shard("train"), session.get_dataset_shard("valid")
  • การ shard โมเดลจัดการด้วย DeepSpeed ZeRO stage 3 และ optimizer state offloading
    • เนื่องจากชิ้นส่วนโมเดลถูกแบ่งไปยังหลาย worker เมื่อต้องเข้าถึงโมเดลทั้งหมด เช่น การบันทึก checkpoint ต้องคลี่โมเดลด้วย accelerator.unwrap_model(model)
  • ทรัพยากรคำนวณมีดังนี้
    • 7B·13B: 16xA10G
    • 70B: 32xA10G, อินสแตนซ์ g5.48xlarge จำนวน 4 เครื่อง
    • เมื่อใช้ Ray การทำ full-parameter fine-tuning ไม่จำเป็นต้องใช้ A100 เสมอไป
  • ฝึกสูงสุด 10 epoch และเลือก checkpoint ที่มี perplexity ต่ำที่สุดบน validation set

ใช้ special token เพื่อกำหนดโครงสร้าง input·output

  • ข้อมูล fine-tuning แสดงโครงสร้างงานด้วย special token แทน prompt แบบคำสั่ง
    • ตัวอย่าง: <START_Q>{question}<END_Q><START_A>{answer}<END_A>
  • special token ช่วยให้โมเดลแยกช่วง input กับ output และเรียนรู้จุดหยุดของ output ได้ชัดเจน
    • ในตัวอย่างกำหนด <END_A> เป็น stopping token เพื่อให้หยุดสร้าง output เมื่องานเสร็จ
  • โดยค่าเริ่มต้น Llama tokenizer ส่งออก token ID จำนวน 32,000 ตัว
    • เมื่อเพิ่ม special token สี่ตัว จะส่งออก ID 32,004 ตัว
    • ID ใหม่ถูกกำหนดในลักษณะเช่น <START_Q> เป็น 32000, <END_Q> เป็น 32001
  • สคริปต์เพิ่ม special token ด้วย tokenizer.add_tokens(special_tokens, special_tokens=True) และสร้างพารามิเตอร์ฝึกใหม่ด้วย model.resize_token_embeddings(len(tokenizer))

ViGGO: แปลงข้อความไม่มีโครงสร้างเป็นการแทนค่าแบบฟังก์ชัน

  • เดิม ViGGO เป็นชุดข้อมูลภาษาอังกฤษที่แปลงการแทนค่าแบบฟังก์ชันตาม attribute-value เป็นข้อความภาษาธรรมชาติ แต่ในการทดลองนี้กลับทิศทางเพื่อแปลงข้อความไม่มีโครงสร้างเป็น การแทนค่าแบบฟังก์ชันที่มีโครงสร้าง
    • โดเมนคือความคิดเห็นเกี่ยวกับวิดีโอเกม
    • การแทนค่าผลลัพธ์สามารถใช้สำหรับการทำ indexing และแอปพลิเคชันต่อเนื่องได้
  • โมเดลต้องสร้างฟังก์ชันและค่า attribute ที่ตรงกับประโยค
    • ตัวเลือกฟังก์ชันประกอบด้วย inform, request, give_opinion, confirm, verify_attribute, suggest, request_explanation, recommend, request_attribute
    • ตัวเลือก attribute ประกอบด้วย name, release_year, esrb, genres, platforms, available_on_steam, has_linux_release, has_mac_release, specifier, rating, player_perspective, has_multiplayer, developer, exp_release_date เป็นต้น
  • input ตัวอย่าง What's a really fast-paced game with multiplayer that you like to play? มี output ที่คาดหวังคือ request(has_multiplayer[yes], specifier[fast-paced])
  • โมเดลทั่วไปทำตามรูปแบบ output ที่ต้องการได้ไม่ดี และมีปัญหาที่เวลาในการประมวลผล input มากกว่าเวลาสร้าง output เพราะ input context ยาว
  • งานนี้เน้น การรู้จำรูปแบบ และความเข้าใจภาษาพื้นฐานมากกว่าการให้เหตุผลเชิงตรรกะที่ซับซ้อน
    • เป็น grounded task ที่ข้อเท็จจริงที่จำเป็นทั้งหมดอยู่ใน input
    • การที่ few-shot prompt ช่วยได้ถูกใช้เป็นสัญญาณว่าโมเดล Llama-2 ขนาดเล็กก็สามารถปรับปรุงได้ด้วย fine-tuning

การประเมินและผลลัพธ์ของ ViGGO

  • การประเมินไม่ได้ใช้เพียงการจับคู่ตัวอักษรแบบสมบูรณ์
    • ตรวจสอบว่าฟังก์ชันใน output ถูกต้องหรือไม่
    • ตรวจสอบว่า type ของ attribute ถูกต้องหรือไม่
    • ตรวจสอบว่า attribute ในฟังก์ชันเป็นไปตาม ลำดับความสำคัญ ที่กำหนดไว้หรือไม่
  • สำหรับโมเดล instruction-following เช่น GPT และ Llama-2-chat เนื่องจากระบุ rule เรื่องลำดับ attribute ไว้ใน prompt จึงประเมินภายใต้เงื่อนไขว่าต้องทำตาม rule นั้น
  • เพื่อเพิ่มความเร็วในการประเมิน ใช้ batch inference API ของ Ray ร่วมกับ Aviary ของ Anyscale
    • เชื่อมการสร้างผลลัพธ์ของ LLM กับ post-processing และกระจายไปยังหลายเครื่อง
  • โมเดล 7B และ 13B มีความแม่นยำเพิ่มขึ้นอย่างมากหลัง fine-tuning
    • GPT-4 มีความแม่นยำลดลงมากเมื่อรวมลำดับความสำคัญของ attribute ในการประเมิน
    • โมเดลที่ fine-tune แล้วทำตามลำดับความสำคัญเสมอ และความแม่นยำไม่เปลี่ยนเมื่อเพิ่มข้อจำกัดนี้
  • ผลลัพธ์ของ ViGGO แสดงว่า fine-tuning อาจเป็นวิธีที่เสถียรและมีประสิทธิภาพสำหรับงานที่ต้องการ รูปแบบมีโครงสร้าง
    • ไม่ใช่แค่การจับรูปแบบด้วย regex ง่าย ๆ หรือทำให้ตรงรูปแบบ JSON แต่ต้องตัดสินใจว่าจะรวม argument ใดและต้องรักษาลำดับของ argument ที่รวมเข้ามาด้วย
    • เนื่องจากเป็นผลลัพธ์จากโมเดล 7B·13B ต้นทุนการให้บริการอาจต่ำกว่าการเรียก endpoint ของ GPT-4

การสร้าง SQL: สร้าง query จากภาษาธรรมชาติและบริบทตาราง

  • งานสร้าง SQL คือการรับคำถามภาษาธรรมชาติและคำสั่ง SQL CREATE TABLE เป็น input แล้วสร้าง SQL query ที่รันได้
  • ชุดข้อมูลที่ใช้ b-mc2/sql-create-context เป็นชุดข้อมูลบน Hugging Face ที่รวม WikiSQL กับ Spider
    • แต่ละ data point ประกอบด้วยคำถามภาษาธรรมชาติ, คำสั่ง SQL CREATE TABLE, และ SQL query ที่สอดคล้องกัน
    • ทั้งหมดมี 78,577 data point
  • ชุดข้อมูลมีปัญหาใน SQL คำตอบที่ถูกต้อง
    • ใน CREATE TABLE attribute แบบจำนวนเต็มถูกระบุเป็น VARCHAR แต่ใน SQL query มักถูกประมวลผลเหมือนจำนวนเต็ม
    • จึงลบ SQL query ทั้งหมดที่ตั้งสมมติฐานว่า attribute เป็นจำนวนเต็ม ลดชุดข้อมูลจากประมาณ 70k เหลือ 45k
  • งานนี้ก็เหมาะกับ fine-tuning เพราะเป็นปัญหาการแปลงภาษาธรรมชาติเป็น representation มีโครงสร้างอย่าง SQL
    • ต่างจาก ViGGO ตรงที่อาจมี SQL ได้หลายแบบที่ให้ผลการรันถูกต้อง จึงคลุมเครือมากกว่า

การประเมินและผลลัพธ์ของ SQL

  • การประเมินการสร้าง SQL ไม่เหมาะกับการเปรียบเทียบ string ง่าย ๆ
    • การเปรียบเทียบระดับตัวอักษรอาจสร้าง false negative จำนวนมาก
    • การเปรียบเทียบ AST ก็อาจไวต่อองค์ประกอบอย่างลำดับชื่อตัวแปร
    • วิธีที่เชื่อถือได้ที่สุดคือรันโค้ดบนชุดข้อมูลปลอมแล้วเปรียบเทียบว่า output เหมือนกันหรือไม่
  • ในการทดลองใช้ GPT-3.5 endpoint ของ OpenAI สร้างตารางปลอมสำหรับ unit test ของตัวอย่างหลายร้อยรายการ
    • GPT-3.5 ดูคำถาม, schema ตาราง และคำตอบ แล้วสร้างตารางปลอม 10 data point
    • ใช้ sqlglot.executor.execute รัน SQL คำตอบกับ SQL ของโมเดล แล้วเปรียบเทียบผลลัพธ์
  • เพื่อตรวจสอบคุณภาพตารางข้อมูลที่ GPT-3.5 สร้างขึ้น รัน SQL คำตอบก่อน
    • หากตารางผลลัพธ์ว่างหรือมีความยาวเท่ากับตารางต้นฉบับ จะทิ้งตัวอย่างนั้น
    • ในกระบวนการนี้ ตารางข้อมูลที่ GPT สร้างขึ้นประมาณ 50% ถูกกรองออก
  • Llama-2 7B และ 13B ที่ fine-tune แล้วให้ประสิทธิภาพสูงกว่า 70B-chat และ GPT-4
    • ข้อผิดพลาดที่พบบ่อยของโมเดล Llama chat คือไม่ใส่ SQL ไว้ในแท็ก <SQL> อย่างสม่ำเสมอตามคำสั่งใน prompt
    • ปัญหานี้พบในโมเดล chat 7B·13B บ่อยกว่า 70B
  • คำถามภาษาธรรมชาติบางส่วนในชุดข้อมูล SQL ไม่ใช่ภาษาอังกฤษที่สมบูรณ์ และ noise เหล่านี้อาจส่งผลต่อผลลัพธ์ของ GPT-4
    • โมเดลที่ fine-tune แล้วปรับตัวเข้ากับลักษณะเฉพาะของชุดข้อมูลได้อย่างรวดเร็ว

GSM8k: การให้เหตุผลทางคณิตศาสตร์ที่ยากกว่าการเรียนรู้โครงสร้าง

  • GSM8k เป็น benchmark เชิงวิชาการมาตรฐานสำหรับประเมินความสามารถด้านการให้เหตุผลและความเข้าใจทางคณิตศาสตร์
  • หากสองงานก่อนหน้าส่วนใหญ่เป็นการเรียนรู้โครงสร้าง GSM8k เป็นงานที่ตรวจสอบว่าสามารถปรับปรุง กระบวนการให้เหตุผล ของโมเดลเพื่อแก้โจทย์คณิตศาสตร์ได้มากเพียงใด
  • โจทย์ตัวอย่างเป็นลักษณะถามยอดขายรวมเมื่อขายได้ 48 ชิ้นในเดือนเมษายนและขายได้ครึ่งหนึ่งของจำนวนนั้นในเดือนพฤษภาคม โดยคำตอบจบด้วยรูปแบบ #### 72 พร้อมการคำนวณระหว่างทาง
  • LLM ในปัจจุบันจำเป็นต้องสร้างกระบวนการคิดเป็นส่วนหนึ่งของ output เพื่อให้การสร้าง token ถัดไปอิงกับกระบวนการเชิงตรรกะ แทนที่จะคำนวณคำตอบสุดท้ายภายในแล้วตอบออกมาทันที
  • งานนี้ต้องการ chain of thought เชิงตรรกะตั้งแต่ premise ผ่านข้อสรุประหว่างทางไปจนถึงคำตอบสุดท้าย ไม่ใช่เพียงการคำนวณง่าย ๆ

วิธีประเมินและ baseline ของ GSM8k

  • การประเมินต้องมีวิธีดึงคำตอบสุดท้ายจาก output ของโมเดลอย่างเสถียร
  • โมเดลภาษาทั่วไปอาจไม่ทำตามรูปแบบ output ที่ต้องการอย่างสม่ำเสมอ ทำให้ประเมินอัตโนมัติได้ยาก
    • เพื่อแก้ปัญหานี้ ใช้ OpenAI function calling API
    • ให้ gpt-3.5-turbo-0613 เรียกฟังก์ชัน report_answer เพื่อดึงคำตอบจำนวนเต็มสุดท้ายจากผลลัพธ์ที่โมเดลอื่นสร้าง
    • ตัวอย่างเช่น แม้โมเดลตอบว่า “The answer is four” ก็สามารถ parse เป็น 4 ได้
  • วิธีนี้ตรวจสอบความถูกต้องบนคำตอบของชุดข้อมูลแล้ว แต่มีข้อเสียคือต้องเสียค่า token ของ OpenAI ในการประเมิน
  • โมเดลที่ fine-tune แล้วเรียนรู้รูปแบบคำตอบเป้าหมายได้อย่างรวดเร็ว ทำให้โครงสร้าง output คาดเดาได้แม้คำตอบจะผิด
    • การประเมินโมเดลที่ fine-tune แล้วใช้ regular expression #### {answer} จึงหลีกเลี่ยง post-processing ผ่าน OpenAI endpoint ได้
  • baseline มีดังนี้
    • ผลลัพธ์ 8-shot prompting ของ base pre-trained model ที่เผยแพร่ใน paper
    • เทมเพลตหลายแบบที่ผ่าน prompt engineering สำหรับ Llama-2 รุ่น chat-tuned ที่ Meta ฝึกด้วย RLHF ให้เป็น assistant อเนกประสงค์

ผลลัพธ์ GSM8k และ fine-tuning สองขั้นตอน

  • การ fine-tune base model เพิ่มประสิทธิภาพ GSM8k อย่างสม่ำเสมอ แต่ไม่ได้ให้ผลลัพธ์ที่ดีกว่า chat-tuned model อย่างชัดเจนเสมอไป
    • chat model มีความแม่นยำสูงกว่า base model เพราะอาจได้รับการฝึกด้วยตัวอย่างคณิตศาสตร์ในกระบวนการ chat-tuning
  • การใส่ prompt ให้โมเดลที่ fine-tune แล้วไม่ได้ให้ผลลัพธ์ดีกว่า base model เสมอไป
    • ตัวอย่างเช่น Llama-2-70B-chat อาจต่ำกว่า base model ที่ใส่ 8-shot example prompt
    • โมเดลที่ fine-tune แล้วดีกว่า 8-shot prompted base model อย่างสม่ำเสมอ
  • ในแง่ต้นทุนการให้บริการ โมเดลที่ fine-tune แล้วอาจได้เปรียบ
    • วิธีแบบ prompt มีต้นทุน token ของ prompt ในทุก request
    • โมเดลที่ fine-tune แล้วโดยพื้นฐานจะคิดต้นทุนเกือบเฉพาะจำนวน token ของคำถาม
  • ข้อมูลฝึก GSM8k มีประมาณ 8k รายการ ซึ่งค่อนข้างเล็ก จึงประเมินว่าเป็นเรื่องยากที่จะดึงศักยภาพของ Llama-13B ออกมาได้เต็มที่
  • วิธีสองขั้นตอนที่ fine-tune Llama-13B base model ด้วย MathQA ก่อน แล้ว fine-tune อีกครั้งด้วย GSM8k ให้การปรับปรุงเพิ่มเติม
    • fine-tuning ด้วย GSM8k อย่างเดียวปรับปรุงจาก base 10%p
    • fine-tuning สองขั้นตอนด้วย MathQA แล้ว GSM8k ปรับปรุงเพิ่มอีก 10%p จากผล fine-tuning แรก รวมเป็น 20%p จาก base
  • MathQA ประกอบด้วยคู่คำถาม/คำตอบ 30,000 คู่ แต่มี noise มากกว่า GSM8k และโครงสร้างต่างกัน
    • คุณภาพคำตอบต่ำ และคำตอบสุดท้ายเป็นรูปแบบ multiple choice
    • ถึงอย่างนั้น fine-tuning สองขั้นตอนก็มีประสิทธิผลในการใช้ MathQA เพื่อปรับปรุงผลลัพธ์สุดท้ายของ GSM8k

เกณฑ์ที่ควรพิจารณาในการใช้งานจริง

  • โมเดลแบบปิดอย่าง GPT-4 และ Claude-2 แข็งแกร่งในการทำ prototype และตรวจสอบคุณค่าเบื้องต้น แต่ไม่ได้เพียงพอเสมอไปสำหรับการดำเนินงานแอป LLM ใน production
  • การ fine-tune LLM สำหรับ niche task อาจมีคุณค่าไม่เฉพาะด้าน privacy แต่รวมถึง latency, ต้นทุน, คุณภาพ
    • ในตัวอย่าง ViGGO และ SQL ได้ผลลัพธ์ที่ดีกว่า GPT-4 ในด้านคุณภาพด้วย
  • จุดโฟกัสสำคัญในการ fine-tuning คือการรวบรวมข้อมูลและสร้าง pipeline การประเมินผล มากกว่ารายละเอียดการติดตั้ง infrastructure
    • pipeline การประเมินผลเป็นพื้นฐานสำหรับเปรียบเทียบ trade-off ของหลายแนวทางให้สอดคล้องกับความต้องการทางธุรกิจ
  • การทดลองดำเนินการโดยใช้แพลตฟอร์ม fine-tuning และ serving ของ Anyscale และ Anyscale Endpoints
  • กระบวนการเดียวกันถูกจัดทำเป็นโซลูชัน fine-tuning และ serving ของ Anyscale บน Ray เพื่อให้ทำซ้ำได้ด้วยข้อมูลของตนเองและคลาวด์ของตนเอง

ยังไม่มีความคิดเห็น

ยังไม่มีความคิดเห็น