โอเพนซอร์ส Triton kernel fusion ที่เพิ่มความเร็วการอนุมานของ Qwen3-TTS ได้สูงสุด 5 เท่า
(github.com/newgrit1004)สวัสดีครับ ผมขอมาแชร์ไลบรารี Triton kernel fusion ที่สร้างขึ้นเพื่อแก้คอขวดในการอนุมานของโมเดล Qwen3-TTS 1.7B และทำให้ความเร็วเพิ่มขึ้นราว 5 เท่า
1. ทำไมถึงสร้างสิ่งนี้? (ที่มา)
คนที่เคยนำ TTS audio ไปใช้จริงในการทำงานน่าจะทราบดีว่า โมเดลเชิงความน่าจะเป็น (Stochastic) อย่าง Qwen3-TTS จะให้ผลลัพธ์ที่สร้างออกมาแตกต่างกันในทุกครั้ง (เช่น จังหวะการพูด โทนเสียง ฯลฯ)
สุดท้ายแล้ว ในการใช้งานจริงจำเป็นต้องใช้กลยุทธ์การสร้างหลายชุด โดยสร้างตัวเลือกเสียงหลายแบบอย่างรวดเร็วแล้วคัดอันที่ฟังเป็นธรรมชาติที่สุด แต่ด้วยความเร็วเดิม เวิร์กโฟลว์ทำงานค่อนข้างอึดอัดมาก จึงลงมือทำ optimization เอง ผ่านการปรับปรุงครั้งนี้ จากเดิมที่สร้างได้ 1 ชิ้นในเวลาหนึ่งช่วง ตอนนี้สามารถสร้างตัวเลือกได้ 5 แบบในเวลาเท่าเดิม
2. สร้างอย่างไร? (Claude Code + การทดสอบหนักมาก)
พูดตามตรง ผมรู้ถึงความทรงพลังของ OpenAI Triton ซึ่งเป็นไลบรารีสำหรับ kernel optimization อยู่แล้ว แต่ไม่เคยเขียนโค้ด kernel ด้วยตัวเองมาก่อน ดังนั้นโค้ด kernel ในโปรเจกต์นี้ส่วนใหญ่จึงเขียนขึ้นโดยอาศัยความช่วยเหลือจาก Claude Code
อย่างไรก็ตาม เพื่อชดเชยประสบการณ์ด้านการใช้งาน Triton ที่ยังไม่มากพอของผม และเพื่อรับประกันความน่าเชื่อถือของโมเดล 100% ผมจึงทุ่มพลังทั้งหมดไปกับ การทดสอบอย่างเข้มข้นจริงจัง แทนที่จะทุ่มไปกับการเขียนโค้ด
- เขียน unit test จำนวน 90 รายการ เพื่อรับประกันว่าผลลัพธ์ทางคณิตศาสตร์จะเหมือนกับโมเดลต้นฉบับทุกประการ
- ทำค่า Cosine Similarity > 0.997 ได้ครบทั้งที่เลเยอร์ checkpoint หลักและผลลัพธ์สุดท้ายทั้งหมด
3. ประเด็นด้านวิศวกรรมและผลงาน
ได้รับแรงบันดาลใจจาก Liger Kernel ของ LinkedIn โดยได้ทำ fusion ให้กับ 4 โอเปอเรชันที่เป็นคอขวดระหว่างการอนุมาน (RMSNorm, M-RoPE, Norm+Residual, SwiGLU) ด้วย Triton kernel
[เบนช์มาร์กประสิทธิภาพ - อ้างอิง RTX 5090]
- Base (PyTorch): 3,902 ms
- Hybrid (Faster+Triton): 919 ms (~เร็วขึ้น 4.7x)
(※ โหมด Hybrid คือผลลัพธ์จากการนำ Triton kernel fusion ครั้งนี้ไปซ้อนบน faster-qwen3-tts ที่อิง CUDA Graph)
4. ปิดท้าย
ตอนนี้ได้ทดสอบเสร็จสิ้นแล้วเฉพาะบนสภาพแวดล้อม RTX 5090 ซึ่งเป็นอุปกรณ์ส่วนตัวของผมเท่านั้น หากใครกำลังใช้งานอุปกรณ์แบบเซิร์ฟเวอร์ (A100, H100) หรืออุปกรณ์อื่นอย่าง RTX 4090 แล้วลองรันดู พร้อมส่งฟีดแบ็กผ่าน GitHub หรือคอมเมนต์มาได้ จะช่วยได้มากจริง ๆ ครับ
ขอบคุณที่อ่านมาจนจบครับ!
ยังไม่มีความคิดเห็น