PyTorch ตายแล้ว จงทรงพระเจริญ JAX
(neel04.github.io)- เหตุผลที่ PyTorch ก่อให้เกิดการสูญเสียผลิตภาพและทำให้เสียเวลาในการพัฒนา ไม่ใช่เพราะ "ตัวเฟรมเวิร์กแย่" แต่เป็นเพราะมันไม่ได้ถูกออกแบบมาให้เหมาะกับยูสเคสที่กำลังถูกนำไปใช้ในปัจจุบัน
ปรัชญาของ PyTorch
- ปรัชญาของ PyTorch คือความเป็นไดนามิก ดีบักง่าย และเป็น Pythonic
- ขณะที่ TensorFlow 1.x ตั้งใจจะเป็นเฟรมเวิร์กแบบสแตติกแต่มีประสิทธิภาพสูง โดยพึ่งพาคอมไพเลอร์ XLA อย่างหนัก
- นักพัฒนา TensorFlow ตระหนักว่าชุมชนไม่ชอบ API ของ 1.x จึงตัดสินใจใช้ Keras เป็นอินเทอร์เฟซหลัก และลดบทบาทของคอมไพเลอร์ XLA ลง
- PyTorch ยังคงยึดรากเดิมของตนไว้ และต่างจากแนวทางแบบสแตติกและประเมินผลล่าช้าของ TensorFlow โดยใช้แนวทาง "eager execution" ที่มีความไดนามิกมากกว่า ซึ่ง
torch.Tensorจะถูกประเมินผลทันที - เมื่อแนวทางนี้ประสบความสำเร็จ งานวิจัยจำนวนมากก็ย้ายมาใช้ PyTorch
- เมื่อ GPT-3 ปรากฏตัวในปี 2021 ประสิทธิภาพและความสามารถในการสเกลกลายเป็นประเด็นสำคัญ
- PyTorch ตอบสนองความต้องการเหล่านี้ได้ดีในระดับหนึ่ง แต่เพราะมันไม่ได้ถูกออกแบบโดยคำนึงถึงปรัชญานี้ หนี้ทางเทคนิคจึงค่อย ๆ สะสมและฐานรากเริ่มสั่นคลอน
- นักพัฒนา PyTorch ไม่ต้องการจุดประนีประนอมใด ๆ และเลือกจะเดินสองเส้นทางพร้อมกัน
- ใช้คอมไพเลอร์ XLA เป็นแบ็กเอนด์พื้นฐานที่มีประสิทธิภาพและเสถียรภาพสูง
- สร้างสแต็ก
torch.compileเพื่อให้ผู้ใช้มีอิสระในการเรียกใช้คอมไพเลอร์เมื่อจำเป็น
- การไม่มีแผนกลยุทธ์ระยะยาวเป็นปัญหาร้ายแรง
- PyTorch ไม่อยากยึดมั่นกับปรัชญาแบบเน้นคอมไพเลอร์เป็นศูนย์กลาง (เช่น JAX) แต่ก็มองไม่เห็นทางเลือกที่ดี
- แล้วคู่แข่งแก้ปัญหานี้อย่างไร?
การพัฒนาแบบอิงคอมไพเลอร์ของ JAX
- JAX ใช้ประโยชน์จาก XLA ซึ่งเป็นสแต็กคอมไพเลอร์อันทรงพลังของ TensorFlow
- XLA เป็นคอมไพเลอร์ที่ทรงพลัง แต่ทั้งหมดถูกนามธรรมซ่อนไว้จากผู้ใช้ปลายทาง
- หากฟังก์ชันเป็น pure ก็สามารถใช้ดีคอเรเตอร์
@jax.jitเพื่อทำ JIT compile และทำให้มันใช้งานบน XLA ได้ - XLA จะจัดการทุกอย่างเบื้องหลัง ทั้งการตรวจสอบว่ากราฟที่สร้างขึ้นถูกต้องหรือไม่, GSPMD partitioner ที่จัดการการขนานอัตโนมัติด้วย sharding ใน JAX, การเพิ่มประสิทธิภาพกราฟ, การ fusion ของ operator และ kernel, การจัดตารางเพื่อซ่อน latency, การ overlap การสื่อสารแบบอะซิงโครนัส, การสร้างโค้ดไปยังแบ็กเอนด์อื่นอย่าง triton เป็นต้น
- ตราบใดที่ปฏิบัติตามข้อจำกัดของ JAX, XLA จะจัดการให้โดยอัตโนมัติ
- ตัวอย่างเช่น เวลา parallelize ไม่จำเป็นต้องมี communication primitive อย่าง
torch.distributed.barrier() - การรองรับ DDP ทำได้ด้วยโค้ดที่เรียบง่าย
- แนวทางของ XLA คือการคำนวณจะเป็นไปตาม sharding ดังนั้นหากอาร์เรย์อินพุตถูก shard ตามแกนใดแกนหนึ่ง XLA ก็จะจัดการการคำนวณย่อยต่อจากนั้นให้โดยอัตโนมัติ
- แนวคิดเรื่อง "การพัฒนาแบบอิงคอมไพเลอร์" คล้ายกับวิธีการทำงานของคอมไพเลอร์ Rust
- ข้อจำกัดของ PyTorch
- ไม่พอใจกับการที่นักพัฒนา PyTorch เลือกผสานและพึ่งพาสแต็กคอมไพเลอร์สำหรับฟีเจอร์ใหม่ แทนที่จะรักษาปรัชญาแกนหลักเรื่องความยืดหยุ่นและอิสระไว้
- ตามโรดแมปอย่างเป็นทางการของ PyTorch 2.x มีการระบุแผนระยะยาวอย่างชัดเจนว่าจะผสาน XLA เข้ากับ Torch อย่างสมบูรณ์
- นี่เป็นไอเดียที่แย่มาก มันเหมือนกับการบอกว่าการยัดโค้ด C++ เข้าไปในคอมไพเลอร์ Rust จะให้ประสบการณ์ที่ดีกว่าการใช้ Rust เอง
- Torch ไม่ได้ถูกออกแบบโดยมี XLA เป็นศูนย์กลาง ต่างจาก JAX
- หาก PyTorch ตัดสินใจใช้สแต็กคอมไพเลอร์ที่อิง XLA แล้ว เฟรมเวิร์กในอุดมคติไม่ควรเป็นสิ่งที่ถูกออกแบบและสร้างขึ้นมาโดยเฉพาะรอบสิ่งนั้นหรือ?
- ต่อให้ PyTorch จะเดินตามแนวทาง "multi-backend" ที่เลือกคอมไพเลอร์แบ็กเอนด์ได้ตามต้องการ ก็จะยิ่งทำให้ปัญหาความแตกเป็นเสี่ยงแย่ลง และอาจทำให้ API พังยับขณะพยายามเคารพข้อจำกัดของสแต็กคอมไพเลอร์ทุกตัวหรือไม่?
- ใครก็ตามที่เคยใช้ Torch/XLA บน TPU ต่างก็มีอาการ PTSD หนักกันทั้งนั้น
Multi-Backend ล้มเหลวแล้ว
- PyTorch พยายามทำทุกอย่างพร้อมกันและล้มเหลวอย่างน่าเวทนา
- การตัดสินใจออกแบบแบบ "multi-backend" ทำให้ปัญหานี้แย่ลงแบบทวีคูณ
- ในทางทฤษฎีมันฟังดูเหมือนจะเลือกสแต็กที่ต้องการได้ แต่ในทางปฏิบัติคือความยุ่งเหยิงของ traceback ที่เข้าใจยากและปัญหาความไม่เข้ากันที่พันกันยุ่ง
- ข้อจำกัดข้ามแบ็กเอนด์และการปะทะกับ API ของ PyTorch
- สิ่งที่ยากไม่ใช่แค่การทำให้แบ็กเอนด์เหล่านี้ทำงาน แต่ข้อจำกัดที่แบ็กเอนด์เหล่านี้คาดหวังกลับไม่เข้ากับ API แบบยืดหยุ่นและ Pythonic ของ PyTorch
- มี trade-off ระหว่างการรักษาความสม่ำเสมอของ API กับการทำตามข้อจำกัดของแบ็กเอนด์
- ผลลัพธ์คือเหล่านักพัฒนาพยายามพึ่งการสร้างโค้ดมากขึ้น แทนที่จะผสาน/ยึดกับแบ็กเอนด์ใดแบ็กเอนด์หนึ่งจริง ๆ
- การไร้กลยุทธ์ของ PyTorch
- เพราะ PyTorch ปฏิเสธที่จะยอมรับ trade-off ที่มีความหมาย การตัดสินใจทุกอย่างจึงให้ความรู้สึกเหมือนการประนีประนอม
- ไม่มีทั้งความสม่ำเสมอและไม่มีกลยุทธ์โดยรวม
- ท้ายที่สุดสิ่งนี้สร้างความหงุดหงิดให้ผู้ใช้มาก และทำให้มันดูเหมือนกองรวมของฟีเจอร์ที่ไม่เข้ากัน
- ไม่มีวิธีใดฆ่า ecosystem ได้เร็วไปกว่านี้อีกแล้ว
- ทำไมไม่ควรเดินตามแนวทางของ JAX
- PyTorch ไม่ควรเดินตามแนวทาง "คอมไพเลอร์และแบ็กเอนด์แบบรวมศูนย์" ของ JAX
- เพราะ JAX ถูกออกแบบมาอย่างชัดเจนให้ทำงานร่วมกับ XLA
- การแทนที่ฟรอนต์เอนด์ของ PyTorch ด้วยของ JAX ไม่อาจเป็นกลยุทธ์ได้
- แทบเป็นไปไม่ได้ที่จะออกแบบ API บน XLA ที่ดีกว่า JAX
- ไม่ได้ตำหนินักพัฒนาที่อยากลองไอเดียใหม่และแตกต่าง
- แต่ถ้า PyTorch อยากอยู่รอดในระยะยาว ก็ควรให้ความสำคัญกับการเสริมฐานรากมากกว่าการปล่อยฟีเจอร์ใหม่สุดเท่ที่พังทันทีเมื่อออกนอกเงื่อนไขสอนใช้งานในอุดมคติ
ความแตกเป็นเสี่ยงของ PyTorch และการเขียนโปรแกรมเชิงฟังก์ชันของ JAX
- API เชิงฟังก์ชันของ JAX
- ฟังก์ชันใน JAX ต้องเป็น pure กล่าวคือห้ามมีผลข้างเคียงระดับโกลบอล
- เช่นเดียวกับฟังก์ชันทางคณิตศาสตร์ เมื่อได้รับข้อมูลชุดเดิมก็ต้องคืนผลลัพธ์เดิมเสมอ ไม่ว่าคอนเท็กซ์ของการรันจะเป็นอย่างไร
- ด้วยปรัชญาการออกแบบนี้ ฟังก์ชันของ JAX จึงประกอบรวมกันได้และทำงานร่วมกันได้ดี
- ความซับซ้อนในการพัฒนาลดลง และฟังก์ชันถูกนิยามด้วย signature เฉพาะและงานเชิงรูปธรรมที่กำหนดไว้อย่างชัดเจน
- หาก type ถูกต้อง ก็รับประกันได้ว่าฟังก์ชันจะทำงานได้ทันที
- สิ่งนี้เหมาะกับประเภทงานที่จำเป็นในการคำนวณเชิงวิทยาศาสตร์ โดยเฉพาะ deep learning
- ตัวอย่าง API ของ optax
- ด้วยแนวทางเชิงฟังก์ชัน optax จึงมีสิ่งที่เรียกว่า "chain"
- ซึ่งประกอบด้วยหลายฟังก์ชันที่ถูกนำไปใช้กับ gradient ตามลำดับ
- องค์ประกอบพื้นฐานคือ
GradientTransformation - ทำให้ได้ API ที่ทรงพลังแต่ยังแสดงออกได้ดี
- ตัวอย่างเช่น การ clip gradient, การทำ EMA ของ gradient หรือการรวม optimizer หลายตัวเข้าด้วยกัน กลายเป็นเรื่องง่ายมาก
- ข้อดีของการออกแบบเชิงฟังก์ชัน
- ผลลัพธ์ที่ยอดเยี่ยมอีกอย่างของการออกแบบเชิงฟังก์ชันคือ
vmap - มันย่อมาจาก vectorized map และอธิบายหน้าที่ของมันได้ตรงตัว
- สามารถ map ได้กับทุกอย่าง และตราบใดที่เป็น
vmapXLA ก็จะ fusion และ optimize ให้โดยอัตโนมัติ - เวลาเขียนฟังก์ชันไม่จำเป็นต้องคิดถึง batch dimension
- แค่
vmapโค้ดทั้งหมดก็พอ - นั่นหมายความว่าต้องทำงานพวก ein-* น้อยลง
- การทำความเข้าใจการจัดการเทนเซอร์ 2D/3D จึงเป็นธรรมชาติกว่าและอ่านง่ายกว่ามาก
- เพราะแค่แยกองค์ประกอบแต่ละส่วนออกมาเพื่อใช้เหตุผล ก็ทำให้เขียนโค้ดซับซ้อนที่ทำงานถูกต้องได้ง่ายขึ้น
- ตราบใดที่เคารพข้อจำกัดเรื่องความเป็น pure และมี signature ที่ถูกต้อง ก็จะได้ประโยชน์อื่นทั้งหมดอย่างความสามารถในการประกอบรวมตามมา
- ผลลัพธ์ที่ยอดเยี่ยมอีกอย่างของการออกแบบเชิงฟังก์ชันคือ
- ปัญหาของ ecosystem ของ PyTorch
- ใน torch ไม่ว่าคุณจะใช้สแต็กแบบไหน (
FSDP + multi-node + torch.compileเป็นต้น) ก็มีโอกาสที่บางอย่างจะพังเสมอ - หลายสิ่งต้องทำงานร่วมกันอย่างถูกต้อง และหากองค์ประกอบใดองค์ประกอบหนึ่งล้มเหลว คุณก็ต้องดีบักกันถึงตี 3
- เนื่องจากไม่สามารถทดสอบทุกชุดผสมของฟีเจอร์หลายสิบอย่างที่ PyTorch มีได้ จึงย่อมมีบั๊กที่ไม่ถูกค้นพบระหว่างการพัฒนาอยู่เสมอ
- แทบเป็นไปไม่ได้ที่จะเขียนโค้ดที่ทำงานได้ดีโดยไม่ทุ่มแรงอย่างมาก
- ecosystem ของ torch จึงใหญ่เทอะทะมากและเต็มไปด้วยบั๊ก
- เพราะไม่มี abstraction ร่วมกัน ไลบรารีและเฟรมเวิร์กใหม่ ๆ จึงเกิดขึ้นโดยไม่ได้ถูกออกแบบมาให้เชื่อมต่อกับ "โซลูชัน" อื่น
- จากนั้นก็เสื่อมสภาพอย่างรวดเร็วกลายเป็นความโกลาหลของ dependency และ
requirements.txt - 70-80% ของ GitHub issues หรือการถกเถียงในฟอรัม เกิดขึ้นเพียงเพราะมี error ระหว่างไลบรารีที่ต่างกัน
- แทบไม่มีทางแก้ปัญหานี้ได้
- ใน torch ไม่ว่าคุณจะใช้สแต็กแบบไหน (
- การไม่มีทางออก
- นี่คือปัญหาของ OOP และการออกแบบ
- มีความคิดว่าอ็อบเจ็กต์พื้นฐานที่มีความเป็น PyTorch อย่าง PyTree อาจช่วยสร้างฐานร่วมสำหรับ abstraction ได้
- ก็ไม่อาจหันไปใช้พาราไดม์การเขียนโปรแกรมเชิงฟังก์ชันได้เช่นกัน
- เพราะถ้าทำเช่นนั้น มันจะค่อย ๆ กลายเป็น JAX เวอร์ชันที่ประสิทธิภาพแย่กว่า พร้อมกับทำลาย backward compatibility ของโค้ดเบส torch ที่มีอยู่ทั้งหมด
- PyTorch ดูเหมือนจะพังยับในประเด็นนี้อย่างสิ้นเชิง
ความได้เปรียบด้านการทำซ้ำผลลัพธ์ของ JAX
- การจัดการ seed
- การจัดการ seed ของ PyTorch ไม่ได้ดีในอุดมคติ
- โดยทั่วไปต้องรันโค้ดหลายบรรทัด
- ลืมหรือตั้งค่าผิดได้ง่าย
- JAX บังคับให้สร้าง key อย่างชัดเจนแล้วส่งเข้าไปยังทุกฟังก์ชันที่ต้องใช้ความสุ่ม
- แนวทางนี้ขจัดปัญหาได้หมดจด เพราะ RNG จะถูก seed แบบสแตติกเสมอ
- JAX มี NumPy เวอร์ชันของตัวเอง (
jax.numpy) จึงไม่จำเป็นต้องตั้ง seed แยกต่างหาก - การตัดสินใจด้าน QoL เล็ก ๆ แบบนี้สามารถทำให้ประสบการณ์ผู้ใช้ของทั้งเฟรมเวิร์กดีขึ้นอย่างมาก
- ความสามารถในการพกพา
- หนึ่งในปัญหาใหญ่ที่สุดเวลาทำงานกับโค้ดเบส PyTorch คือการขาดความสามารถในการพกพา
- โค้ดเบสที่เขียนมาสำหรับ CUDA/GPU มักทำงานได้ไม่ดีเมื่อรันบนฮาร์ดแวร์ที่ไม่ใช่ Nvidia เช่น TPU, NPU, AMD GPU เป็นต้น
- การพอร์ตโค้ด PyTorch ที่เขียนมาสำหรับ 1 โหนดไปสู่ multi-node ทำได้ยาก
- multi-node มักต้องใช้เวลาพัฒนาหลายสิบชั่วโมงและต้องแก้โค้ดอย่างมาก
- แนวทางแบบเน้นคอมไพเลอร์ของ JAX มีข้อได้เปรียบในจุดนี้
- XLA จัดการการสลับข้าม device backend และทำให้โค้ดทำงานได้ดีบน GPU/TPU/multi-node/multi-slice โดยแก้โค้ดน้อยมาก
- ทำให้ผู้ผลิตฮาร์ดแวร์รองรับอุปกรณ์ได้ง่ายขึ้น และทำให้การสลับข้ามอุปกรณ์ง่ายขึ้น
- ไม่ใช่ทุกคนจะเข้าถึงฮาร์ดแวร์แบบเดียวกันได้ ดังนั้นโค้ดเบสที่พกพาได้ข้ามฮาร์ดแวร์หลายประเภทจึงอาจเป็นก้าวเล็ก ๆ ที่ทำให้ deep learning เข้าถึงผู้เริ่มต้น/ระดับกลางได้ง่ายขึ้น
- การสเกลอัตโนมัติ
- โค้ดเบสที่สามารถ auto-scale ได้ดีด้วยตัวเองช่วยเรื่องการทำซ้ำผลลัพธ์อย่างมาก
- ในกรณีอุดมคติ มันควรเกิดขึ้นโดยอัตโนมัติด้วยการแก้โค้ดให้น้อยที่สุด โดยไม่ขึ้นกับขอบเขตของเครือข่าย
- JAX ทำสิ่งนี้ได้ดี
- เวลาเขียนโค้ด JAX ไม่จำเป็นต้องระบุ communication primitive หรือวาง
torch.distributed.barrier()ไว้ทุกที่ - XLA จะใส่สิ่งเหล่านี้ให้อัตโนมัติโดยพิจารณาจากฮาร์ดแวร์ที่มีอยู่
- อุปกรณ์ทั้งหมดที่ JAX ตรวจพบจะถูกใช้งานโดยอัตโนมัติ ไม่ว่าจะเป็นเรื่องเครือข่าย topology การตั้งค่า หรืออื่น ๆ
- มันจะซิงก์และเตรียมการคำนวณโดยอัตโนมัติ รวมทั้งใช้ optimization pass เพื่อเพิ่มการรันแบบอะซิงโครนัสของ kernel ให้สูงสุดและลด latency ให้น้อยที่สุด
- สิ่งเดียวที่มนุษย์ต้องทำคือระบุ sharding ของเทนเซอร์ที่ต้องการกระจายไปยังอุปกรณ์ เช่น batch dimension ของอาร์เรย์อินพุต
- ด้วยแนวทางของ XLA ที่ว่า "การคำนวณจะเป็นไปตาม sharding" มันจึงจัดการส่วนที่เหลือให้โดยอัตโนมัติ
- สิ่งนี้ทำให้สามารถรันการทดลองที่ผ่านการตรวจสอบในสเกลที่เหมาะสมได้ง่ายแม้ในฐานะงานอดิเรก เพื่อทดลองและอาจทำซ้ำผลลัพธ์ได้
- มันอาจช่วยให้ค้นพบไอเดียที่ถูกลืมได้ง่ายขึ้น และส่งเสริมการทดลองแบบนั้น เพราะสามารถทดสอบเป็นฟังก์ชันในสเกลที่ใหญ่ขึ้นได้ด้วยความพยายามน้อยมาก
ข้อเสียของ JAX
- โครงสร้างธรรมาภิบาล
- ปัจจุบัน XLA อยู่ภายใต้ธรรมาภิบาลของ TensorFlow
- เคยมีการพูดคุยเรื่องการจัดตั้งองค์กรแยกต่างหากแบบเดียวกับ PyTorch แต่ยังไม่มีความพยายามที่เป็นรูปธรรมมากนัก
- ความเชื่อมั่นต่อ Google ไม่ได้สูงนัก เพราะมีชื่อเสียงเรื่องยกเลิกผลิตภัณฑ์ที่ไม่เป็นที่นิยม
- แม้ JAX จะเป็นโปรเจกต์ของ DeepMind ในเชิงเทคนิค และมีความสำคัญอย่างยิ่งต่อการผลักดัน AI โดยรวมของ Google แต่ดูแล้วการมีโครงสร้างระยะยาวจะเป็นประโยชน์อย่างมากต่อ ecosystem ทั้งหมด
- องค์กรกำกับดูแลแยกต่างหากจะช่วยกำหนดทิศทางการพัฒนาโปรเจกต์
- สิ่งนี้จะให้โครงสร้างที่ชัดเจนและแยกออกจากระบบราชการอันเลื่องชื่อของ Google ช่วยหลีกเลี่ยงปัญหาหลายอย่างในคราวเดียว
- ไม่ได้หมายความว่า JAX จำเป็นต้องมีโครงสร้างทางการเช่นนี้เสมอไป แต่ก็คงดีหากมีหลักประกันว่าการพัฒนา JAX จะดำเนินต่อไปอีกนาน ไม่ว่าผู้บริหารระดับสูงของ Google จะตัดสินใจอย่างไร
- สิ่งนี้จะช่วยเรื่องการยอมรับโดยบริษัทและสถาบันวิจัยขนาดใหญ่ได้อย่างชัดเจน เพราะคนเหล่านี้มักลังเลที่จะทุ่มทรัพยากรเพื่อผสานเครื่องมือที่วันหนึ่งอาจไม่มีผู้ดูแลอีกต่อไป
- การเปลี่ยนผ่านของ XLA ไปสู่โอเพนซอร์ส
- เป็นเวลานาน XLA เป็นโปรเจกต์แบบปิดซอร์ส
- แต่ก็มีความพยายามทำให้มันเป็นโอเพนซอร์ส และตอนนี้ OpenXLA แสดงประสิทธิภาพที่ดีกว่า XLA build ภายในอย่างมาก
- อย่างไรก็ตาม เอกสารเกี่ยวกับภายในของ XLA ยังขาดแคลน
- ทรัพยากรส่วนใหญ่เป็นเพียงไลฟ์ทอล์กและบางครั้งก็เป็นงานวิจัย ซึ่งมักล้าสมัย
- หากมีโรดแมปที่เข้าถึงได้สาธารณะสำหรับฟีเจอร์ที่วางแผนไว้ ผู้คนก็จะติดตามความคืบหน้าและมีส่วนร่วมกับสิ่งที่น่าสนใจเป็นพิเศษได้ง่ายขึ้น
- การมีมินิบล็อกโพสต์สไตล์ Edward Yang ที่วิเคราะห์แต่ละขั้นของสแต็กคอมไพเลอร์ XLA และอธิบายรายละเอียด จะช่วยให้ผู้ปฏิบัติงานประเมินได้ดีขึ้นว่า XLA ทำอะไรได้และทำอะไรไม่ได้
- เข้าใจดีว่าสิ่งนี้ใช้ทรัพยากรมากและอาจสื่อสารผ่านช่องทางอื่นได้ดีกว่า แต่ผู้คนมักเชื่อถือเครื่องมือมากขึ้นเมื่อเข้าใจมัน และเชื่อว่าสิ่งนี้จะสร้างผลกระทบเชิงบวกเป็นลูกโซ่ต่อ ecosystem ทั้งหมด เป็นประโยชน์ต่อทุกฝ่าย
- การบูรณาการ ecosystem
flaxคือจุดน่าปวดหัวของ ecosystem JAX- มันมี API ที่ไม่ intuitive ใช้ไวยากรณ์ที่กระชับเกินไป และสำหรับผู้เริ่มต้นที่ย้ายมาจาก PyTorch มันคือขุมนรกอย่างแท้จริง
- ขอแนะนำให้ใช้
equinox - แม้ทีมพัฒนาจะพยายามแก้ข้อเสียของ
flaxแต่ท้ายที่สุดแล้วก็เป็นการเสียเวลา - ถ้าต้องการ API สไตล์ equinox ก็ควรใช้
equinoxไปเลย - ไม่มีหลายอย่างที่
flaxทำได้ดีกว่าเป็นพิเศษ และก็ไม่ยากที่จะทำซ้ำด้วยequinox - ปัจจุบัน ecosystem ของ JAX จำนวนมากถูกออกแบบโดยมี
flaxเป็นศูนย์กลาง equinoxสามารถทำงานร่วมกับทุกไลบรารีได้ เพราะมันเชื่อมต่อกับ PyTree โดยพื้นฐาน แม้อาจต้องใช้eqx.partitionและ filter อยู่บ้าง- อยากเปลี่ยนสภาพที่เป็นอยู่
equinoxควรได้รับการรองรับระดับ first-class ในทุกที่ - นี่อาจเป็นความเห็นที่ถกเถียงกันได้ แต่ทั้งหมดนี้คือความผิดพลาดแบบ sunk cost fallacy คลาสสิก
equinoxทำงานได้ดีกว่าในแบบที่เฟรมเวิร์ก JAX ควรจะเป็นมาตลอด- ถ้าเทียบ
equinoxกับflaxตามที่สรุปไว้ในเอกสารของ equinox แล้วequinoxดีกว่า - เป็นเรื่องดีที่ผู้ดูแล ecosystem ของ JAX เริ่มตระหนักถึงความนิยมของ
equinoxและปรับตัวตาม แต่ก็หวังว่า Google และทีมflaxจะให้การสนับสนุนอย่างเป็นทางการมากกว่านี้ - หากอยากลองใช้ JAX ขอแนะนำให้ใช้
equinox
- ขอบคมที่ต้องระวัง
- ด้วยเหตุผลด้านการออกแบบ API และข้อจำกัดของ XLA ทำให้ JAX มี "ขอบคม" ที่ต้องระวัง
- มีการอธิบายเรื่องนี้ไว้อย่างกระชับมากในเอกสารที่เขียนได้ดี
- แนะนำให้อ่านอย่างน้อยหนึ่งครั้งก่อนใช้ JAX
- การ RTFM ย่อมช่วยประหยัดเวลาและพลังงานได้มากเสมอ
บทสรุป
- โพสต์บล็อกนี้มีไว้เพื่อแก้ความเชื่อที่ถูกพูดซ้ำ ๆ ว่า PyTorch เหมาะที่สุดกับเวิร์กโหลดงานวิจัยจริง โดยเฉพาะบน GPU ซึ่งเป็น Myth ที่พบได้บ่อย มันไม่จริงอีกต่อไปแล้ว
- อันที่จริง ผู้เขียนถึงขั้นยืนยันอย่างสุดโต่งว่าการพอร์ตโค้ด PyTorch ทั้งหมดไปเป็น JAX จะเป็นประโยชน์อย่างมหาศาลต่อทั้งวงการ
- การขนานอัตโนมัติ ความสามารถในการทำซ้ำผลลัพธ์ API เชิงฟังก์ชันที่สะอาด ฯลฯ ไม่ใช่ฟีเจอร์เล็กน้อย และจะช่วยโค้ดเบสงานวิจัยจำนวนมากได้มาก
- หากคุณอยากทำให้วงการนี้ดีขึ้นอีกสักนิด ลองพิจารณาเขียนโค้ดเบสของคุณใหม่ด้วย JAX
8 ความคิดเห็น
โลกยังคงหมุนต่อไป 555
เปรียบเทียบ PyTorch กับ TensorFlow ในปี 2022
ผมจะยึด
torchกับonnxต่อไปครับบทความที่นักศึกษาปริญญาตรีเขียน.. โห
ถ้าไม่มี Huggingface นะ PyTorch นี่จบจริง 555
JAX จงเจริญ! เพิ่งได้ลองใช้เมื่อไม่นานมานี้ และชอบ NNX API มากครับ
ปัญหาใหญ่ที่สุดของ JAX คือมันมาจาก Google นี่แหละ Google ขึ้นชื่อมากเรื่องการทิ้งโอเพนซอร์ส (Tflite, android things, dart, angular, bazel ฯลฯ) แม้แต่ tensorflow เองก็เริ่มอัปเดตไม่ค่อยดีตั้งแต่ช่วงหนึ่งเป็นต้นมา ในทางกลับกัน torch เริ่มต้นจาก Facebook ที่ดูแลโอเพนซอร์สขนาดใหญ่อยู่แล้ว จึงบริหารจัดการได้ดีมาก และตอนนี้ก็อยู่ภายใต้การดูแลของมูลนิธิ torch แล้ว จุดอ่อนของ torch มีส่วนที่จริงแน่นอน แต่ในแง่ของว่าใครจะเป็นคนดูแลโอเพนซอร์สนี้อย่างยั่งยืน JAX ดูเหมือนจะเริ่มต้นมาพร้อมความเสี่ยงก้อนใหญ่แล้ว
อย่างน้อย Dart ก็ดูเหมือนจะยังไปได้ดีกับ Flutter ไปอีกสักพักหนึ่ง
Facebook ดูเหมือนจะยังคงมีความรับผิดชอบต่อเทคโนโลยีสแต็กที่ตัวเองใช้งานอยู่ เช่น React, Django และยังคงมีส่วนร่วมอย่างต่อเนื่อง แต่ Google ดูเหมือนว่าถ้าอะไรเริ่มตกยุคขึ้นมานิดหน่อยก็จะทิ้งเหมือนรองเท้าเก่า...