14 คะแนน โดย xguru 2024-08-19 | 8 ความคิดเห็น | แชร์ทาง WhatsApp
  • เหตุผลที่ PyTorch ก่อให้เกิดการสูญเสียผลิตภาพและทำให้เสียเวลาในการพัฒนา ไม่ใช่เพราะ "ตัวเฟรมเวิร์กแย่" แต่เป็นเพราะมันไม่ได้ถูกออกแบบมาให้เหมาะกับยูสเคสที่กำลังถูกนำไปใช้ในปัจจุบัน

ปรัชญาของ PyTorch

  • ปรัชญาของ PyTorch คือความเป็นไดนามิก ดีบักง่าย และเป็น Pythonic
  • ขณะที่ TensorFlow 1.x ตั้งใจจะเป็นเฟรมเวิร์กแบบสแตติกแต่มีประสิทธิภาพสูง โดยพึ่งพาคอมไพเลอร์ XLA อย่างหนัก
  • นักพัฒนา TensorFlow ตระหนักว่าชุมชนไม่ชอบ API ของ 1.x จึงตัดสินใจใช้ Keras เป็นอินเทอร์เฟซหลัก และลดบทบาทของคอมไพเลอร์ XLA ลง
  • PyTorch ยังคงยึดรากเดิมของตนไว้ และต่างจากแนวทางแบบสแตติกและประเมินผลล่าช้าของ TensorFlow โดยใช้แนวทาง "eager execution" ที่มีความไดนามิกมากกว่า ซึ่ง torch.Tensor จะถูกประเมินผลทันที
  • เมื่อแนวทางนี้ประสบความสำเร็จ งานวิจัยจำนวนมากก็ย้ายมาใช้ PyTorch
  • เมื่อ GPT-3 ปรากฏตัวในปี 2021 ประสิทธิภาพและความสามารถในการสเกลกลายเป็นประเด็นสำคัญ
  • PyTorch ตอบสนองความต้องการเหล่านี้ได้ดีในระดับหนึ่ง แต่เพราะมันไม่ได้ถูกออกแบบโดยคำนึงถึงปรัชญานี้ หนี้ทางเทคนิคจึงค่อย ๆ สะสมและฐานรากเริ่มสั่นคลอน
  • นักพัฒนา PyTorch ไม่ต้องการจุดประนีประนอมใด ๆ และเลือกจะเดินสองเส้นทางพร้อมกัน
    • ใช้คอมไพเลอร์ XLA เป็นแบ็กเอนด์พื้นฐานที่มีประสิทธิภาพและเสถียรภาพสูง
    • สร้างสแต็ก torch.compile เพื่อให้ผู้ใช้มีอิสระในการเรียกใช้คอมไพเลอร์เมื่อจำเป็น
  • การไม่มีแผนกลยุทธ์ระยะยาวเป็นปัญหาร้ายแรง
  • PyTorch ไม่อยากยึดมั่นกับปรัชญาแบบเน้นคอมไพเลอร์เป็นศูนย์กลาง (เช่น JAX) แต่ก็มองไม่เห็นทางเลือกที่ดี
  • แล้วคู่แข่งแก้ปัญหานี้อย่างไร?

การพัฒนาแบบอิงคอมไพเลอร์ของ JAX

  • JAX ใช้ประโยชน์จาก XLA ซึ่งเป็นสแต็กคอมไพเลอร์อันทรงพลังของ TensorFlow
  • XLA เป็นคอมไพเลอร์ที่ทรงพลัง แต่ทั้งหมดถูกนามธรรมซ่อนไว้จากผู้ใช้ปลายทาง
  • หากฟังก์ชันเป็น pure ก็สามารถใช้ดีคอเรเตอร์ @jax.jit เพื่อทำ JIT compile และทำให้มันใช้งานบน XLA ได้
  • XLA จะจัดการทุกอย่างเบื้องหลัง ทั้งการตรวจสอบว่ากราฟที่สร้างขึ้นถูกต้องหรือไม่, GSPMD partitioner ที่จัดการการขนานอัตโนมัติด้วย sharding ใน JAX, การเพิ่มประสิทธิภาพกราฟ, การ fusion ของ operator และ kernel, การจัดตารางเพื่อซ่อน latency, การ overlap การสื่อสารแบบอะซิงโครนัส, การสร้างโค้ดไปยังแบ็กเอนด์อื่นอย่าง triton เป็นต้น
  • ตราบใดที่ปฏิบัติตามข้อจำกัดของ JAX, XLA จะจัดการให้โดยอัตโนมัติ
  • ตัวอย่างเช่น เวลา parallelize ไม่จำเป็นต้องมี communication primitive อย่าง torch.distributed.barrier()
  • การรองรับ DDP ทำได้ด้วยโค้ดที่เรียบง่าย
  • แนวทางของ XLA คือการคำนวณจะเป็นไปตาม sharding ดังนั้นหากอาร์เรย์อินพุตถูก shard ตามแกนใดแกนหนึ่ง XLA ก็จะจัดการการคำนวณย่อยต่อจากนั้นให้โดยอัตโนมัติ
  • แนวคิดเรื่อง "การพัฒนาแบบอิงคอมไพเลอร์" คล้ายกับวิธีการทำงานของคอมไพเลอร์ Rust
  • ข้อจำกัดของ PyTorch
    • ไม่พอใจกับการที่นักพัฒนา PyTorch เลือกผสานและพึ่งพาสแต็กคอมไพเลอร์สำหรับฟีเจอร์ใหม่ แทนที่จะรักษาปรัชญาแกนหลักเรื่องความยืดหยุ่นและอิสระไว้
    • ตามโรดแมปอย่างเป็นทางการของ PyTorch 2.x มีการระบุแผนระยะยาวอย่างชัดเจนว่าจะผสาน XLA เข้ากับ Torch อย่างสมบูรณ์
    • นี่เป็นไอเดียที่แย่มาก มันเหมือนกับการบอกว่าการยัดโค้ด C++ เข้าไปในคอมไพเลอร์ Rust จะให้ประสบการณ์ที่ดีกว่าการใช้ Rust เอง
    • Torch ไม่ได้ถูกออกแบบโดยมี XLA เป็นศูนย์กลาง ต่างจาก JAX
    • หาก PyTorch ตัดสินใจใช้สแต็กคอมไพเลอร์ที่อิง XLA แล้ว เฟรมเวิร์กในอุดมคติไม่ควรเป็นสิ่งที่ถูกออกแบบและสร้างขึ้นมาโดยเฉพาะรอบสิ่งนั้นหรือ?
    • ต่อให้ PyTorch จะเดินตามแนวทาง "multi-backend" ที่เลือกคอมไพเลอร์แบ็กเอนด์ได้ตามต้องการ ก็จะยิ่งทำให้ปัญหาความแตกเป็นเสี่ยงแย่ลง และอาจทำให้ API พังยับขณะพยายามเคารพข้อจำกัดของสแต็กคอมไพเลอร์ทุกตัวหรือไม่?
    • ใครก็ตามที่เคยใช้ Torch/XLA บน TPU ต่างก็มีอาการ PTSD หนักกันทั้งนั้น

Multi-Backend ล้มเหลวแล้ว

  • PyTorch พยายามทำทุกอย่างพร้อมกันและล้มเหลวอย่างน่าเวทนา
  • การตัดสินใจออกแบบแบบ "multi-backend" ทำให้ปัญหานี้แย่ลงแบบทวีคูณ
  • ในทางทฤษฎีมันฟังดูเหมือนจะเลือกสแต็กที่ต้องการได้ แต่ในทางปฏิบัติคือความยุ่งเหยิงของ traceback ที่เข้าใจยากและปัญหาความไม่เข้ากันที่พันกันยุ่ง
  • ข้อจำกัดข้ามแบ็กเอนด์และการปะทะกับ API ของ PyTorch
    • สิ่งที่ยากไม่ใช่แค่การทำให้แบ็กเอนด์เหล่านี้ทำงาน แต่ข้อจำกัดที่แบ็กเอนด์เหล่านี้คาดหวังกลับไม่เข้ากับ API แบบยืดหยุ่นและ Pythonic ของ PyTorch
    • มี trade-off ระหว่างการรักษาความสม่ำเสมอของ API กับการทำตามข้อจำกัดของแบ็กเอนด์
    • ผลลัพธ์คือเหล่านักพัฒนาพยายามพึ่งการสร้างโค้ดมากขึ้น แทนที่จะผสาน/ยึดกับแบ็กเอนด์ใดแบ็กเอนด์หนึ่งจริง ๆ
  • การไร้กลยุทธ์ของ PyTorch
    • เพราะ PyTorch ปฏิเสธที่จะยอมรับ trade-off ที่มีความหมาย การตัดสินใจทุกอย่างจึงให้ความรู้สึกเหมือนการประนีประนอม
    • ไม่มีทั้งความสม่ำเสมอและไม่มีกลยุทธ์โดยรวม
    • ท้ายที่สุดสิ่งนี้สร้างความหงุดหงิดให้ผู้ใช้มาก และทำให้มันดูเหมือนกองรวมของฟีเจอร์ที่ไม่เข้ากัน
    • ไม่มีวิธีใดฆ่า ecosystem ได้เร็วไปกว่านี้อีกแล้ว
  • ทำไมไม่ควรเดินตามแนวทางของ JAX
    • PyTorch ไม่ควรเดินตามแนวทาง "คอมไพเลอร์และแบ็กเอนด์แบบรวมศูนย์" ของ JAX
    • เพราะ JAX ถูกออกแบบมาอย่างชัดเจนให้ทำงานร่วมกับ XLA
    • การแทนที่ฟรอนต์เอนด์ของ PyTorch ด้วยของ JAX ไม่อาจเป็นกลยุทธ์ได้
    • แทบเป็นไปไม่ได้ที่จะออกแบบ API บน XLA ที่ดีกว่า JAX
    • ไม่ได้ตำหนินักพัฒนาที่อยากลองไอเดียใหม่และแตกต่าง
    • แต่ถ้า PyTorch อยากอยู่รอดในระยะยาว ก็ควรให้ความสำคัญกับการเสริมฐานรากมากกว่าการปล่อยฟีเจอร์ใหม่สุดเท่ที่พังทันทีเมื่อออกนอกเงื่อนไขสอนใช้งานในอุดมคติ

ความแตกเป็นเสี่ยงของ PyTorch และการเขียนโปรแกรมเชิงฟังก์ชันของ JAX

  • API เชิงฟังก์ชันของ JAX
    • ฟังก์ชันใน JAX ต้องเป็น pure กล่าวคือห้ามมีผลข้างเคียงระดับโกลบอล
    • เช่นเดียวกับฟังก์ชันทางคณิตศาสตร์ เมื่อได้รับข้อมูลชุดเดิมก็ต้องคืนผลลัพธ์เดิมเสมอ ไม่ว่าคอนเท็กซ์ของการรันจะเป็นอย่างไร
    • ด้วยปรัชญาการออกแบบนี้ ฟังก์ชันของ JAX จึงประกอบรวมกันได้และทำงานร่วมกันได้ดี
    • ความซับซ้อนในการพัฒนาลดลง และฟังก์ชันถูกนิยามด้วย signature เฉพาะและงานเชิงรูปธรรมที่กำหนดไว้อย่างชัดเจน
    • หาก type ถูกต้อง ก็รับประกันได้ว่าฟังก์ชันจะทำงานได้ทันที
    • สิ่งนี้เหมาะกับประเภทงานที่จำเป็นในการคำนวณเชิงวิทยาศาสตร์ โดยเฉพาะ deep learning
  • ตัวอย่าง API ของ optax
    • ด้วยแนวทางเชิงฟังก์ชัน optax จึงมีสิ่งที่เรียกว่า "chain"
    • ซึ่งประกอบด้วยหลายฟังก์ชันที่ถูกนำไปใช้กับ gradient ตามลำดับ
    • องค์ประกอบพื้นฐานคือ GradientTransformation
    • ทำให้ได้ API ที่ทรงพลังแต่ยังแสดงออกได้ดี
    • ตัวอย่างเช่น การ clip gradient, การทำ EMA ของ gradient หรือการรวม optimizer หลายตัวเข้าด้วยกัน กลายเป็นเรื่องง่ายมาก
  • ข้อดีของการออกแบบเชิงฟังก์ชัน
    • ผลลัพธ์ที่ยอดเยี่ยมอีกอย่างของการออกแบบเชิงฟังก์ชันคือ vmap
    • มันย่อมาจาก vectorized map และอธิบายหน้าที่ของมันได้ตรงตัว
    • สามารถ map ได้กับทุกอย่าง และตราบใดที่เป็น vmap XLA ก็จะ fusion และ optimize ให้โดยอัตโนมัติ
    • เวลาเขียนฟังก์ชันไม่จำเป็นต้องคิดถึง batch dimension
    • แค่ vmap โค้ดทั้งหมดก็พอ
    • นั่นหมายความว่าต้องทำงานพวก ein-* น้อยลง
    • การทำความเข้าใจการจัดการเทนเซอร์ 2D/3D จึงเป็นธรรมชาติกว่าและอ่านง่ายกว่ามาก
    • เพราะแค่แยกองค์ประกอบแต่ละส่วนออกมาเพื่อใช้เหตุผล ก็ทำให้เขียนโค้ดซับซ้อนที่ทำงานถูกต้องได้ง่ายขึ้น
    • ตราบใดที่เคารพข้อจำกัดเรื่องความเป็น pure และมี signature ที่ถูกต้อง ก็จะได้ประโยชน์อื่นทั้งหมดอย่างความสามารถในการประกอบรวมตามมา
  • ปัญหาของ ecosystem ของ PyTorch
    • ใน torch ไม่ว่าคุณจะใช้สแต็กแบบไหน (FSDP + multi-node + torch.compile เป็นต้น) ก็มีโอกาสที่บางอย่างจะพังเสมอ
    • หลายสิ่งต้องทำงานร่วมกันอย่างถูกต้อง และหากองค์ประกอบใดองค์ประกอบหนึ่งล้มเหลว คุณก็ต้องดีบักกันถึงตี 3
    • เนื่องจากไม่สามารถทดสอบทุกชุดผสมของฟีเจอร์หลายสิบอย่างที่ PyTorch มีได้ จึงย่อมมีบั๊กที่ไม่ถูกค้นพบระหว่างการพัฒนาอยู่เสมอ
    • แทบเป็นไปไม่ได้ที่จะเขียนโค้ดที่ทำงานได้ดีโดยไม่ทุ่มแรงอย่างมาก
    • ecosystem ของ torch จึงใหญ่เทอะทะมากและเต็มไปด้วยบั๊ก
    • เพราะไม่มี abstraction ร่วมกัน ไลบรารีและเฟรมเวิร์กใหม่ ๆ จึงเกิดขึ้นโดยไม่ได้ถูกออกแบบมาให้เชื่อมต่อกับ "โซลูชัน" อื่น
    • จากนั้นก็เสื่อมสภาพอย่างรวดเร็วกลายเป็นความโกลาหลของ dependency และ requirements.txt
    • 70-80% ของ GitHub issues หรือการถกเถียงในฟอรัม เกิดขึ้นเพียงเพราะมี error ระหว่างไลบรารีที่ต่างกัน
    • แทบไม่มีทางแก้ปัญหานี้ได้
  • การไม่มีทางออก
    • นี่คือปัญหาของ OOP และการออกแบบ
    • มีความคิดว่าอ็อบเจ็กต์พื้นฐานที่มีความเป็น PyTorch อย่าง PyTree อาจช่วยสร้างฐานร่วมสำหรับ abstraction ได้
    • ก็ไม่อาจหันไปใช้พาราไดม์การเขียนโปรแกรมเชิงฟังก์ชันได้เช่นกัน
    • เพราะถ้าทำเช่นนั้น มันจะค่อย ๆ กลายเป็น JAX เวอร์ชันที่ประสิทธิภาพแย่กว่า พร้อมกับทำลาย backward compatibility ของโค้ดเบส torch ที่มีอยู่ทั้งหมด
    • PyTorch ดูเหมือนจะพังยับในประเด็นนี้อย่างสิ้นเชิง

ความได้เปรียบด้านการทำซ้ำผลลัพธ์ของ JAX

  • การจัดการ seed
    • การจัดการ seed ของ PyTorch ไม่ได้ดีในอุดมคติ
    • โดยทั่วไปต้องรันโค้ดหลายบรรทัด
    • ลืมหรือตั้งค่าผิดได้ง่าย
    • JAX บังคับให้สร้าง key อย่างชัดเจนแล้วส่งเข้าไปยังทุกฟังก์ชันที่ต้องใช้ความสุ่ม
    • แนวทางนี้ขจัดปัญหาได้หมดจด เพราะ RNG จะถูก seed แบบสแตติกเสมอ
    • JAX มี NumPy เวอร์ชันของตัวเอง (jax.numpy) จึงไม่จำเป็นต้องตั้ง seed แยกต่างหาก
    • การตัดสินใจด้าน QoL เล็ก ๆ แบบนี้สามารถทำให้ประสบการณ์ผู้ใช้ของทั้งเฟรมเวิร์กดีขึ้นอย่างมาก
  • ความสามารถในการพกพา
    • หนึ่งในปัญหาใหญ่ที่สุดเวลาทำงานกับโค้ดเบส PyTorch คือการขาดความสามารถในการพกพา
    • โค้ดเบสที่เขียนมาสำหรับ CUDA/GPU มักทำงานได้ไม่ดีเมื่อรันบนฮาร์ดแวร์ที่ไม่ใช่ Nvidia เช่น TPU, NPU, AMD GPU เป็นต้น
    • การพอร์ตโค้ด PyTorch ที่เขียนมาสำหรับ 1 โหนดไปสู่ multi-node ทำได้ยาก
    • multi-node มักต้องใช้เวลาพัฒนาหลายสิบชั่วโมงและต้องแก้โค้ดอย่างมาก
    • แนวทางแบบเน้นคอมไพเลอร์ของ JAX มีข้อได้เปรียบในจุดนี้
    • XLA จัดการการสลับข้าม device backend และทำให้โค้ดทำงานได้ดีบน GPU/TPU/multi-node/multi-slice โดยแก้โค้ดน้อยมาก
    • ทำให้ผู้ผลิตฮาร์ดแวร์รองรับอุปกรณ์ได้ง่ายขึ้น และทำให้การสลับข้ามอุปกรณ์ง่ายขึ้น
    • ไม่ใช่ทุกคนจะเข้าถึงฮาร์ดแวร์แบบเดียวกันได้ ดังนั้นโค้ดเบสที่พกพาได้ข้ามฮาร์ดแวร์หลายประเภทจึงอาจเป็นก้าวเล็ก ๆ ที่ทำให้ deep learning เข้าถึงผู้เริ่มต้น/ระดับกลางได้ง่ายขึ้น
  • การสเกลอัตโนมัติ
    • โค้ดเบสที่สามารถ auto-scale ได้ดีด้วยตัวเองช่วยเรื่องการทำซ้ำผลลัพธ์อย่างมาก
    • ในกรณีอุดมคติ มันควรเกิดขึ้นโดยอัตโนมัติด้วยการแก้โค้ดให้น้อยที่สุด โดยไม่ขึ้นกับขอบเขตของเครือข่าย
    • JAX ทำสิ่งนี้ได้ดี
    • เวลาเขียนโค้ด JAX ไม่จำเป็นต้องระบุ communication primitive หรือวาง torch.distributed.barrier() ไว้ทุกที่
    • XLA จะใส่สิ่งเหล่านี้ให้อัตโนมัติโดยพิจารณาจากฮาร์ดแวร์ที่มีอยู่
    • อุปกรณ์ทั้งหมดที่ JAX ตรวจพบจะถูกใช้งานโดยอัตโนมัติ ไม่ว่าจะเป็นเรื่องเครือข่าย topology การตั้งค่า หรืออื่น ๆ
    • มันจะซิงก์และเตรียมการคำนวณโดยอัตโนมัติ รวมทั้งใช้ optimization pass เพื่อเพิ่มการรันแบบอะซิงโครนัสของ kernel ให้สูงสุดและลด latency ให้น้อยที่สุด
    • สิ่งเดียวที่มนุษย์ต้องทำคือระบุ sharding ของเทนเซอร์ที่ต้องการกระจายไปยังอุปกรณ์ เช่น batch dimension ของอาร์เรย์อินพุต
    • ด้วยแนวทางของ XLA ที่ว่า "การคำนวณจะเป็นไปตาม sharding" มันจึงจัดการส่วนที่เหลือให้โดยอัตโนมัติ
    • สิ่งนี้ทำให้สามารถรันการทดลองที่ผ่านการตรวจสอบในสเกลที่เหมาะสมได้ง่ายแม้ในฐานะงานอดิเรก เพื่อทดลองและอาจทำซ้ำผลลัพธ์ได้
    • มันอาจช่วยให้ค้นพบไอเดียที่ถูกลืมได้ง่ายขึ้น และส่งเสริมการทดลองแบบนั้น เพราะสามารถทดสอบเป็นฟังก์ชันในสเกลที่ใหญ่ขึ้นได้ด้วยความพยายามน้อยมาก

ข้อเสียของ JAX

  • โครงสร้างธรรมาภิบาล
    • ปัจจุบัน XLA อยู่ภายใต้ธรรมาภิบาลของ TensorFlow
    • เคยมีการพูดคุยเรื่องการจัดตั้งองค์กรแยกต่างหากแบบเดียวกับ PyTorch แต่ยังไม่มีความพยายามที่เป็นรูปธรรมมากนัก
    • ความเชื่อมั่นต่อ Google ไม่ได้สูงนัก เพราะมีชื่อเสียงเรื่องยกเลิกผลิตภัณฑ์ที่ไม่เป็นที่นิยม
    • แม้ JAX จะเป็นโปรเจกต์ของ DeepMind ในเชิงเทคนิค และมีความสำคัญอย่างยิ่งต่อการผลักดัน AI โดยรวมของ Google แต่ดูแล้วการมีโครงสร้างระยะยาวจะเป็นประโยชน์อย่างมากต่อ ecosystem ทั้งหมด
    • องค์กรกำกับดูแลแยกต่างหากจะช่วยกำหนดทิศทางการพัฒนาโปรเจกต์
    • สิ่งนี้จะให้โครงสร้างที่ชัดเจนและแยกออกจากระบบราชการอันเลื่องชื่อของ Google ช่วยหลีกเลี่ยงปัญหาหลายอย่างในคราวเดียว
    • ไม่ได้หมายความว่า JAX จำเป็นต้องมีโครงสร้างทางการเช่นนี้เสมอไป แต่ก็คงดีหากมีหลักประกันว่าการพัฒนา JAX จะดำเนินต่อไปอีกนาน ไม่ว่าผู้บริหารระดับสูงของ Google จะตัดสินใจอย่างไร
    • สิ่งนี้จะช่วยเรื่องการยอมรับโดยบริษัทและสถาบันวิจัยขนาดใหญ่ได้อย่างชัดเจน เพราะคนเหล่านี้มักลังเลที่จะทุ่มทรัพยากรเพื่อผสานเครื่องมือที่วันหนึ่งอาจไม่มีผู้ดูแลอีกต่อไป
  • การเปลี่ยนผ่านของ XLA ไปสู่โอเพนซอร์ส
    • เป็นเวลานาน XLA เป็นโปรเจกต์แบบปิดซอร์ส
    • แต่ก็มีความพยายามทำให้มันเป็นโอเพนซอร์ส และตอนนี้ OpenXLA แสดงประสิทธิภาพที่ดีกว่า XLA build ภายในอย่างมาก
    • อย่างไรก็ตาม เอกสารเกี่ยวกับภายในของ XLA ยังขาดแคลน
    • ทรัพยากรส่วนใหญ่เป็นเพียงไลฟ์ทอล์กและบางครั้งก็เป็นงานวิจัย ซึ่งมักล้าสมัย
    • หากมีโรดแมปที่เข้าถึงได้สาธารณะสำหรับฟีเจอร์ที่วางแผนไว้ ผู้คนก็จะติดตามความคืบหน้าและมีส่วนร่วมกับสิ่งที่น่าสนใจเป็นพิเศษได้ง่ายขึ้น
    • การมีมินิบล็อกโพสต์สไตล์ Edward Yang ที่วิเคราะห์แต่ละขั้นของสแต็กคอมไพเลอร์ XLA และอธิบายรายละเอียด จะช่วยให้ผู้ปฏิบัติงานประเมินได้ดีขึ้นว่า XLA ทำอะไรได้และทำอะไรไม่ได้
    • เข้าใจดีว่าสิ่งนี้ใช้ทรัพยากรมากและอาจสื่อสารผ่านช่องทางอื่นได้ดีกว่า แต่ผู้คนมักเชื่อถือเครื่องมือมากขึ้นเมื่อเข้าใจมัน และเชื่อว่าสิ่งนี้จะสร้างผลกระทบเชิงบวกเป็นลูกโซ่ต่อ ecosystem ทั้งหมด เป็นประโยชน์ต่อทุกฝ่าย
  • การบูรณาการ ecosystem
    • flax คือจุดน่าปวดหัวของ ecosystem JAX
    • มันมี API ที่ไม่ intuitive ใช้ไวยากรณ์ที่กระชับเกินไป และสำหรับผู้เริ่มต้นที่ย้ายมาจาก PyTorch มันคือขุมนรกอย่างแท้จริง
    • ขอแนะนำให้ใช้ equinox
    • แม้ทีมพัฒนาจะพยายามแก้ข้อเสียของ flax แต่ท้ายที่สุดแล้วก็เป็นการเสียเวลา
    • ถ้าต้องการ API สไตล์ equinox ก็ควรใช้ equinox ไปเลย
    • ไม่มีหลายอย่างที่ flax ทำได้ดีกว่าเป็นพิเศษ และก็ไม่ยากที่จะทำซ้ำด้วย equinox
    • ปัจจุบัน ecosystem ของ JAX จำนวนมากถูกออกแบบโดยมี flax เป็นศูนย์กลาง
    • equinox สามารถทำงานร่วมกับทุกไลบรารีได้ เพราะมันเชื่อมต่อกับ PyTree โดยพื้นฐาน แม้อาจต้องใช้ eqx.partition และ filter อยู่บ้าง
    • อยากเปลี่ยนสภาพที่เป็นอยู่ equinox ควรได้รับการรองรับระดับ first-class ในทุกที่
    • นี่อาจเป็นความเห็นที่ถกเถียงกันได้ แต่ทั้งหมดนี้คือความผิดพลาดแบบ sunk cost fallacy คลาสสิก
    • equinox ทำงานได้ดีกว่าในแบบที่เฟรมเวิร์ก JAX ควรจะเป็นมาตลอด
    • ถ้าเทียบ equinox กับ flax ตามที่สรุปไว้ในเอกสารของ equinox แล้ว equinox ดีกว่า
    • เป็นเรื่องดีที่ผู้ดูแล ecosystem ของ JAX เริ่มตระหนักถึงความนิยมของ equinox และปรับตัวตาม แต่ก็หวังว่า Google และทีม flax จะให้การสนับสนุนอย่างเป็นทางการมากกว่านี้
    • หากอยากลองใช้ JAX ขอแนะนำให้ใช้ equinox
  • ขอบคมที่ต้องระวัง
    • ด้วยเหตุผลด้านการออกแบบ API และข้อจำกัดของ XLA ทำให้ JAX มี "ขอบคม" ที่ต้องระวัง
    • มีการอธิบายเรื่องนี้ไว้อย่างกระชับมากในเอกสารที่เขียนได้ดี
    • แนะนำให้อ่านอย่างน้อยหนึ่งครั้งก่อนใช้ JAX
    • การ RTFM ย่อมช่วยประหยัดเวลาและพลังงานได้มากเสมอ

บทสรุป

  • โพสต์บล็อกนี้มีไว้เพื่อแก้ความเชื่อที่ถูกพูดซ้ำ ๆ ว่า PyTorch เหมาะที่สุดกับเวิร์กโหลดงานวิจัยจริง โดยเฉพาะบน GPU ซึ่งเป็น Myth ที่พบได้บ่อย มันไม่จริงอีกต่อไปแล้ว
  • อันที่จริง ผู้เขียนถึงขั้นยืนยันอย่างสุดโต่งว่าการพอร์ตโค้ด PyTorch ทั้งหมดไปเป็น JAX จะเป็นประโยชน์อย่างมหาศาลต่อทั้งวงการ
    • การขนานอัตโนมัติ ความสามารถในการทำซ้ำผลลัพธ์ API เชิงฟังก์ชันที่สะอาด ฯลฯ ไม่ใช่ฟีเจอร์เล็กน้อย และจะช่วยโค้ดเบสงานวิจัยจำนวนมากได้มาก
  • หากคุณอยากทำให้วงการนี้ดีขึ้นอีกสักนิด ลองพิจารณาเขียนโค้ดเบสของคุณใหม่ด้วย JAX

8 ความคิดเห็น

 
xguru 2024-08-25

โลกยังคงหมุนต่อไป 555

เปรียบเทียบ PyTorch กับ TensorFlow ในปี 2022

 
hilft 2024-08-21

ผมจะยึด torch กับ onnx ต่อไปครับ

 
flrngel 2024-08-21

บทความที่นักศึกษาปริญญาตรีเขียน.. โห

 
cosine20 2024-08-21

ถ้าไม่มี Huggingface นะ PyTorch นี่จบจริง 555

 
lemonmint 2024-08-19

JAX จงเจริญ! เพิ่งได้ลองใช้เมื่อไม่นานมานี้ และชอบ NNX API มากครับ

 
stareta1202 2024-08-19

ปัญหาใหญ่ที่สุดของ JAX คือมันมาจาก Google นี่แหละ Google ขึ้นชื่อมากเรื่องการทิ้งโอเพนซอร์ส (Tflite, android things, dart, angular, bazel ฯลฯ) แม้แต่ tensorflow เองก็เริ่มอัปเดตไม่ค่อยดีตั้งแต่ช่วงหนึ่งเป็นต้นมา ในทางกลับกัน torch เริ่มต้นจาก Facebook ที่ดูแลโอเพนซอร์สขนาดใหญ่อยู่แล้ว จึงบริหารจัดการได้ดีมาก และตอนนี้ก็อยู่ภายใต้การดูแลของมูลนิธิ torch แล้ว จุดอ่อนของ torch มีส่วนที่จริงแน่นอน แต่ในแง่ของว่าใครจะเป็นคนดูแลโอเพนซอร์สนี้อย่างยั่งยืน JAX ดูเหมือนจะเริ่มต้นมาพร้อมความเสี่ยงก้อนใหญ่แล้ว

 
dalinaum 2024-08-20

อย่างน้อย Dart ก็ดูเหมือนจะยังไปได้ดีกับ Flutter ไปอีกสักพักหนึ่ง

 
ilotoki0804 2024-08-20

Facebook ดูเหมือนจะยังคงมีความรับผิดชอบต่อเทคโนโลยีสแต็กที่ตัวเองใช้งานอยู่ เช่น React, Django และยังคงมีส่วนร่วมอย่างต่อเนื่อง แต่ Google ดูเหมือนว่าถ้าอะไรเริ่มตกยุคขึ้นมานิดหน่อยก็จะทิ้งเหมือนรองเท้าเก่า...