- โมเดล Diffusion ถูกใช้เกินกว่า การสร้างภาพ ไปยังปัญหาที่ต้องสุ่มตัวอย่างจากการแจกแจงหลายโหมด เช่น ออดิโอ วิดีโอ 3D การออกแบบโปรตีน และการวางแผนเส้นทางหุ่นยนต์ โดยบทเรียนนี้เชื่อมโยงการฝึกและการสุ่มตัวอย่างจากมุมมองของการเพิ่มประสิทธิภาพ
- กระบวนการฝึกสร้าง (x_\sigma=x_0+\sigma\epsilon) โดยผสม noise เข้ากับข้อมูล และลดค่า mean squared error เพื่อให้โครงข่ายประสาท (\epsilon_\theta(x,\sigma)) ทำนาย ทิศทางของ noise
- denoiser ที่ฝึกแล้วถูกตีความได้ว่าเป็น การฉายโดยประมาณ ไปยังชุดข้อมูล (\mathcal{K}) และ denoiser อุดมคติเชื่อมโยงกับ gradient ของฟังก์ชันระยะทางกำลังสองที่ถูก (\sigma)-smoothing
- การสุ่มตัวอย่างแบบ DDIM มองได้ว่าเป็น gradient descent โดยประมาณ สำหรับ (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) และ schedule ของ (\sigma_t) เป็นตัวกำหนดจำนวนรอบทำซ้ำกับต้นทุนการประเมิน denoiser
- เมื่อรวมการอัปเดตการประมาณ gradient กับการเพิ่ม noise จะสามารถจัดการ DDIM, DDPM และ sampler ที่ผู้เขียนปรับปรุงขึ้น ได้ร่วมกันด้วย พารามิเตอร์
gamและmuและต่อยอดไปยังตัวอย่าง toy model กับ latent diffusion
โมเดล Diffusion จากมุมมองการเพิ่มประสิทธิภาพ
- โมเดล Diffusion โดดเด่นในการสร้างตัวอย่างจากการแจกแจงหลายโหมด และถูกนำไปใช้ไม่เพียงกับเครื่องมือสร้างภาพจากข้อความอย่าง Stable Diffusion แต่ยังรวมถึงออดิโอ วิดีโอ การสร้าง 3D การออกแบบโปรตีน และการวางแผนเส้นทางหุ่นยนต์
- พื้นฐานเชิงทฤษฎีของบทเรียนนี้คือ การตีความเชิงการเพิ่มประสิทธิภาพ จาก บทความ ICML 2024 และ บทความที่เกี่ยวข้อง
- การใช้งานอ้างอิงจาก
smalldiffusionเป็นหลัก โดยโค้ดในบทความถูกทำให้ง่ายกว่าคลังต้นฉบับเพื่อใช้ในการสอน
การฝึก: ทำนายทิศทางของ noise
- โมเดล Diffusion มีเป้าหมายเพื่อเรียนรู้ชุดข้อมูล (\mathcal{K}) จากตัวอย่างฝึก และสร้างตัวอย่างจากชุดนั้น
- หากเป็นภาพ (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) คือชุดค่าพิกเซลที่สอดคล้องกับภาพที่สมจริง
- กรอบเดียวกันนี้ใช้ได้กับโดเมนไม่ต่อเนื่อง เช่น ออดิโอ วิดีโอ วิถีหุ่นยนต์ และข้อความ
- ขั้นตอนการฝึกมองได้เป็นสามขั้น
- สุ่มตัวอย่าง (x_0 \sim \mathcal{K}), (\sigma), และ (\epsilon \sim N(0,I))
- สร้าง ข้อมูลที่ผสม noise ด้วย (x_\sigma=x_0+\sigma\epsilon)
- ลด squared loss เพื่อให้ (\epsilon_\theta(x_\sigma,\sigma)) ทำนาย (\epsilon)
- ในโค้ด
training_loopสร้างsigmaและepsด้วยgenerate_train_sampleสำหรับแต่ละ batchx0และ optimize ค่า MSE ระหว่างเอาต์พุตของmodel(x0 + sigma * eps, sigma)กับeps - แทนที่จะสุ่ม (\sigma) แบบ uniform จากช่วงต่อเนื่อง จะสุ่มจาก schedule ของ (\sigma) ที่ discretize เป็นค่า (N) ค่า
- คลาส
Scheduleห่อรายการsigmasที่เป็นไปได้ และสุ่มค่าเป็นราย batch ระหว่างการฝึก - ตัวอย่างในบทความใช้
ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10) ScheduleDDPMเป็น schedule สำหรับโมเดล diffusion ใน pixel space ส่วนScheduleLDMเป็น schedule สำหรับโมเดล latent diffusion เช่น Stable Diffusion
- คลาส
ตัวอย่าง toy แบบ Swissroll
- toy dataset เป็นชุดจุดรูปเกลียวที่ใช้ในหนึ่งในบทความ diffusion ยุคแรก คือ Sohl-Dickstein et al. 2015 โดยมี (\mathcal{K}\subset\mathbb{R}^2)
- สำหรับ dataset แบบง่าย จะ implement denoiser ด้วย MLP
- อินพุตคือค่าที่นำ (x\in\mathbb{R}^2) มาต่อกับ embedding 2 มิติของ (\sigma)
- เอาต์พุตคือค่าทำนาย noise (\epsilon\in\mathbb{R}^2)
- โมเดล diffusion จำนวนมากใช้ sinusoidal positional embedding สำหรับ (\sigma) แต่ในตัวอย่างนี้ embedding 2 มิติแบบง่ายก็ทำงานได้ดี
- การตั้งค่าฝึกของตัวอย่างใช้
ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)และepochs=15000 - denoiser ที่ฝึกแล้วสามารถแสดงภาพเป็น vector field ได้ด้วยการวาด (x-\sigma\epsilon_\theta(x,\sigma))
- เมื่อ (\sigma) มีค่ามาก denoiser มีแนวโน้มทำนายค่าเฉลี่ยของข้อมูล
- เมื่อ (\sigma) ต่ำและอินพุต (x) อยู่ใกล้ข้อมูล จะทำนาย data point จริง
ตีความ Denoising เป็นการฉาย
- ฟังก์ชันระยะทางสำหรับชุดข้อมูล (\mathcal{K}) นิยามเป็น (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}})
- การฉายของ (x) หรือ (\mathrm{proj}_{\mathcal{K}}(x)) คือชุดจุดใน (\mathcal{K}) ที่ทำให้ได้ระยะทางนี้
- หาก (\mathcal{K}) เป็นเซตปิด, (x\notin\mathcal{K}) และการฉายมีเพียงหนึ่งเดียว gradient ของฟังก์ชันระยะทางกำลังสองจะเป็น (x-\mathrm{proj}_{\mathcal{K}}(x))
- เนื่องจากฟังก์ชันระยะทาง (\mathrm{dist}_{\mathcal{K}}) ไม่ได้ differentiable ทุกจุด จึงใช้ softmin แทน
minเพื่อนำฟังก์ชันระยะทางกำลังสองที่ smoothing ด้วย (\sigma) เข้ามา - gradient ของฟังก์ชันระยะทางที่ smoothing แล้วจะชี้ไปทาง ค่าเฉลี่ยถ่วงน้ำหนัก ของจุดต่าง ๆ ใน (\mathcal{K}) ตามค่าน้ำหนักที่กำหนดโดย (x)
Denoiser อุดมคติและโมเดล relative error
- denoiser อุดมคติ (\epsilon^*) คือ denoiser ที่ minimize loss การฝึกได้อย่างแม่นยำที่ (\sigma) หนึ่ง ๆ
- หากข้อมูลเป็นการแจกแจงแบบ uniform ไม่ต่อเนื่องบนเซตจำกัด (\mathcal{K}) denoiser อุดมคติสามารถเขียนเป็นรูปแบบปิดได้
- น้ำหนักของแต่ละ data point ถูกกำหนดตามระยะห่างระหว่าง (x_\sigma) กับจุดนั้น
- สำหรับ dataset ขนาดเล็กสามารถคำนวณโดยตรงด้วย
IdealDenoiser
- สำหรับ toy data denoiser อุดมคติจะมุ่งไปยังค่าเฉลี่ยของข้อมูลเมื่อ (\sigma) ใหญ่ และมุ่งไปยัง data point ที่ใกล้ที่สุดเมื่อ (\sigma) เล็ก
- ทฤษฎีบทหลักให้ความสัมพันธ์ว่า สำหรับทุก (\sigma>0), (x\in\mathbb{R}^n), (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma))
- โมเดล relative error ใช้เงื่อนไขว่า (x-\sigma\epsilon_\theta(x,\sigma)) ประมาณค่า (\mathrm{proj}_{\mathcal{K}}(x)) ได้ดี
- ใช้ได้เมื่อ (\sqrt{n}\sigma) ประมาณค่า (\mathrm{dist}_{\mathcal{K}}(x)) ได้ดีภายในค่าคงที่เท่าตัว
- สมมติว่า error ถูกจำกัดไว้ไม่เกิน (\eta\mathrm{dist}_{\mathcal{K}}(x))
- ในระดับ noise ต่ำ ภายใต้ manifold hypothesis noise ที่เพิ่มเข้ามาส่วนใหญ่จะตั้งฉากกับ data manifold ดังนั้น denoising จึงประมาณการฉาย
- ในระดับ noise สูง หาก (\sigma) ใหญ่กว่าเส้นผ่านศูนย์กลางของ (\mathcal{K}) แม้ denoiser ที่ทำนายค่าเฉลี่ยถ่วงน้ำหนักของข้อมูลก็มี relative error เล็ก
- CIFAR-10 มีขนาดที่สามารถคำนวณ denoiser อุดมคติได้ และในการทดลอง relative error ระหว่างการฉายที่แม่นยำบน trajectory การสุ่มตัวอย่างกับเอาต์พุตของ denoiser อุดมคติมีค่าน้อย
การสุ่มตัวอย่าง: denoising แบบทำซ้ำและ DDIM
- เมื่อมี denoiser ที่ฝึกแล้ว จาก (x_t) ที่ผสม noise และระดับ noise (\sigma_t) จะทำนาย (x_0) ด้วย (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t))
- จุดเริ่มต้นกำหนด (\sigma_T) ให้ใหญ่เมื่อเทียบกับเส้นผ่านศูนย์กลางของ (\mathcal{K}) และสุ่ม (x_T) อย่างอิสระจาก (N(0,\sigma_T)) เพื่อให้อยู่ไกลจาก (\mathcal{K})
- ที่ noise สูง การเรียก denoiser เพียงครั้งเดียวแม้จะมี relative error เล็ก แต่อาจมี absolute error ใหญ่ และการทำนายของ denoiser อุดมคติจะใกล้กับค่าเฉลี่ยของข้อมูล
- ดังนั้นการสุ่มตัวอย่างจึงเรียก denoiser ซ้ำตาม schedule ของ (\sigma_t) เพื่อสร้างลำดับ (x_T,\ldots,x_0)
- การอัปเดต (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) เทียบเท่ากับ อัลกอริทึมสุ่มตัวอย่าง DDIM แบบ deterministic หลังการแปลงพิกัด
- หลักฐานความเทียบเท่ากับ DDIM อยู่ใน Appendix A ของบทความ
มอง DDIM เป็นการ minimize ระยะทาง
- DDIM ถูกตีความได้ว่าเป็น gradient descent โดยประมาณ สำหรับ (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2)
- step size คือ (1-\sigma_{t-1}/\sigma_t)
- (\nabla f(x_t)) ถูกประมาณด้วย (\epsilon_\theta(x_t,\sigma_t))
- schedule ของ (\sigma_t) กำหนดจำนวนและขนาดของ gradient step ระหว่างการสุ่มตัวอย่าง
- หาก step น้อยเกินไป (\mathrm{dist}_{\mathcal{K}}(x_t)) อาจไม่ลดลงและอาจไม่ลู่เข้า
- หากใช้ step เล็กจำนวนมาก จำนวนครั้งที่ประเมิน denoiser จะเพิ่มขึ้น ทำให้ต้นทุนคำนวณสูงขึ้น
- admissible schedule คือ schedule ที่ทำให้ (\sqrt{n}\sigma_t) สอดคล้องกับ (\mathrm{dist}_{\mathcal{K}}(x_t)) ภายในค่าคงที่เท่าตัวในแต่ละรอบทำซ้ำ
- ลำดับ (\sigma_t) แบบ log-linear ที่ลดลงเชิงเรขาคณิตเป็น admissible schedule
- ตามทฤษฎีบท หาก (\nabla\mathrm{dist}{\mathcal{K}}(x)) มีอยู่ที่ (x_t) ที่สร้างด้วย DDIM และ (\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T), (x_t) จะถูกสร้างขึ้นด้วย gradient descent ของฟังก์ชันระยะทางกำลังสอง และรักษา (\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t)
- ใน toy example มีการ implement DDIM sampler 20 steps โดย subsample จาก schedule log-linear เดิม ซึ่งตัวอย่างส่วนใหญ่ใกล้กับข้อมูลต้นฉบับ แต่ยังมีพื้นที่ให้ปรับปรุง
Sampler ที่ปรับปรุงโดยอิงการประมาณ gradient
- ใช้ข้อเท็จจริงที่ว่า (\nabla\mathrm{dist}{\mathcal{K}}(x)) ไม่เปลี่ยนแปลงระหว่าง (x) กับ (\mathrm{proj}{\mathcal{K}}(x)) โดยใช้อัปเดตที่ผสมการประมาณปัจจุบันกับการประมาณก่อนหน้า
- การอัปเดต (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) เป็นวิธีแก้ error ของ step ก่อนหน้าด้วยการประมาณปัจจุบัน
- ในตัวอย่าง toy model วิธีนี้ลู่เข้าเร็วกว่า DDIM และตัวอย่างใกล้กับข้อมูลต้นฉบับมากขึ้น
- เมื่อเทียบกับ DDIM sampler นี้ตีความได้ว่าเพิ่ม momentum และแม้ trajectory อาจ overshoot ได้ แต่สามารถลู่เข้าได้เร็วกว่า
- การเพิ่ม noise ระหว่างกระบวนการสร้างช่วยปรับปรุงคุณภาพการสุ่มตัวอย่างในเชิงประจักษ์
- เพื่อรักษา schedule ของ (\sigma_t) เดิม จะ denoise ลงไปถึง (\sigma_{t'}) ที่เล็กกว่า แล้วเพิ่ม noise (w_t\sim N(0,I)) กลับเข้าไป
- เมื่อ (\mu=\frac{1}{2}) จะกู้คืน DDPM sampler ได้อย่างแม่นยำ
- การอัปเดตรวม (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) ทำให้ครอบคลุม sampler ทั้งสามแบบ
- DDIM:
gam=1, mu=0 - DDPM:
gam=1, mu=0.5 - gradient estimation sampler:
gam=2, mu=0
- DDIM:
โมเดลที่ใหญ่ขึ้นและแหล่งอ้างอิง
- โค้ดฝึกข้างต้นใช้ได้ไม่เพียงกับ toy data แต่ยังใช้ฝึกโมเดล image diffusion ตั้งแต่ศูนย์ได้ด้วย
- ตัวอย่าง FashionMNIST เป็นตัวอย่างที่ฝึกบน dataset FashionMNIST และทำคะแนน FID ได้อันดับ 2 บน ลีดเดอร์บอร์ด Papers with Code
- โค้ดสุ่มตัวอย่างสามารถใช้กับโมเดล latent diffusion ที่ฝึกไว้ล่วงหน้าได้โดยไม่ต้องแก้ไข
- ตัวอย่างใช้
ScheduleLDM(1000)และModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base') - ตั้ง text condition เป็น
An astronaut riding a horseสุ่มตัวอย่างด้วย (\sigma) 50 steps แล้ว decode latent
- ตัวอย่างใช้
- ผลของเทอม momentum (\gamma) ถูกแสดงเป็นภาพเปรียบเทียบในการสร้างภาพจากข้อความความละเอียดสูง
- แหล่งข้อมูลเพิ่มเติมที่น่าอ่าน
- What are diffusion models: แนะนำโมเดล Diffusion จากมุมมองเวลาไม่ต่อเนื่องที่ย้อนกลับ Markov process
- Generative modeling by estimating gradients of the data distribution: แนะนำโมเดล Diffusion จากมุมมองเวลาต่อเนื่องที่ย้อนกลับสมการเชิงอนุพันธ์สุ่ม
- The annotated diffusion model: อธิบายรายละเอียดการ implement โมเดล Diffusion ด้วย PyTorch
1 ความคิดเห็น
ความคิดเห็นใน Hacker News
ถ้ามีคำถามก็ยินดีตอบ
โดยเฉพาะการอภิปรายเรื่อง trajectory ดีมาก เพราะช่วยสร้างแรงจูงใจในการทำความเข้าใจส่วนที่หลายคนมักลำบากในหัวข้ออย่าง scheduler แม้จะไม่สมบูรณ์เท่าบทความของ Song หรือ Lilian แต่เข้าถึงง่ายกว่ามาก จึงคิดว่าจะแนะนำให้คนอื่นอ่าน
อนึ่ง เพื่อนเคยเขียน implementation แบบ minimal ของ diffusion ไว้ก่อนหน้านี้ ซึ่งในมุมมอง DDPM ถือว่า “สมบูรณ์” กว่าเล็กน้อยและมีประโยชน์: https://github.com/VSehwag/minimal-diffusion/
จากมุมมองที่เคยทดลองกระบวนการ sampling ใน Stable Diffusion มาบ้าง ผม/ฉันก็อยากเห็นการเปรียบเทียบ เวลา convergence และจำนวน step เทียบกับ DDIM ด้วย อยากรู้ว่ามีความสัมพันธ์ระหว่าง momentum, convergence และ error หรือไม่ เช่น ถ้ามีการเปรียบเทียบว่า momentum sampler 16 step แทบเทียบเท่ากับ DDIM 20 step ± error term อะไรทำนองนี้ก็น่าจะดี
get_sigma_embeds(batches, sigma)ดูเหมือนจะไม่ได้ใช้ input ตัวแรก สงสัยว่าตั้งใจจะ broadcastsigmaให้เป็นรูป(batches, 1)หรือเปล่าลงลึกในรายละเอียดทางคณิตศาสตร์มากกว่าเยอะ แต่ก็มาพร้อม implementation แบบ minimal ที่เข้าใจง่ายมากและมีโค้ดน้อยกว่า 500 บรรทัด
ถ้าขยายไปเป็นเวอร์ชัน diffusion transformer ที่ขับเคลื่อน Sora และโมเดลสร้างวิดีโออื่น ๆ ได้ก็คงดี น่าจะรวมบทความนี้กับ https://jaykmody.com/blog/gpt-from-scratch/ แล้วทำเป็นบทความปูพื้น “Diffusion Transformer from Scratch” ได้
ในทางกลับกัน ถ้าอยากเจาะลึกจริง ๆ แนะนำให้อ่านงานของ Kingma, Gao, Ricky Tian Qi Chen และลูกศิษย์ของ Max Welling (Tomczak เป็น postdoc, Hoogeboom เป็นต้น) รวมถึงงานของ Aapo Hyvärinen ผู้มีผลงานสำคัญแต่ไม่ค่อยถูกพูดถึง ตัวอย่างงานของ Kingma & Gao ที่ค่อนข้างเบาและเกี่ยวข้องกับเปเปอร์ SD3 อยู่ที่นี่: https://arxiv.org/abs/2303.00848
จุดที่น่าเสียดายคือมันเข้าถึงยากเพราะต้องรู้และเข้าใจงานก่อนหน้าค่อนข้างมาก แต่ก็ยากจะเรียกสิ่งนี้ว่าเป็นคำวิจารณ์ที่มีน้ำหนัก เพราะนี่คืองานวิจัย ไม่ใช่สื่อการสอนสำหรับคนทั่วไป
n_embdส่วนกระบวนการ diffusion เองยังคงเดิมได้[1] https://yang-song.net/blog/2021/score/
[2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
ในมุมมองของเรา เหตุผลที่โมเดล diffusion ฝึกได้ง่ายคือมันใช้เป้าหมายการเรียนรู้ที่ทำนาย gradient ของฟังก์ชันระยะทางที่ถูกทำให้เรียบ แทนที่จะทำนาย gradient ของฟังก์ชันระยะทางที่แม่นยำ การ sampling ของโมเดล diffusion คล้ายกับการเดินหลาย step ตาม gradient โดยประมาณ
ถ้าอยากเข้าใจโมเดล diffusion ให้ลึกขึ้น แนะนำให้อ่านบล็อกเหล่านี้ทั้งหมดและเรียนรู้การตีความที่แตกต่างกัน
อย่างไรก็ตาม แนวทางของบทความนี้ดูเหมือนจะเปิดทางให้ทำการทดลองที่น่าสนใจกว่า เช่น การวิเคราะห์ error ของ denoiser
[1] https://arxiv.org/pdf/2305.03486.pdf
ตัวอย่างเช่น ทำไม image generator ถึงสร้างคีย์เปียโนได้ยาก? ดูเหมือนว่าการสร้างโครงสร้างที่คีย์สีดำสลับกันเป็นกลุ่มสองปุ่มและสามปุ่ม ต้องแสดงข้อจำกัดเรื่องระยะห่างระดับกลางให้ดีขึ้น