AlphaFold ฉบับมีภาพประกอบ
(elanapearl.github.io)- AlphaFold3 มีเป้าหมายที่จะ ทำนายจากลำดับเท่านั้น สำหรับคอมเพล็กซ์ที่มีโปรตีน กรดนิวคลีอิก และโมเลกุลขนาดเล็กอยู่ร่วมกัน โดยไปไกลกว่าการทำนายโปรตีนเดี่ยว ทำให้การแทนอินพุตและการทำ tokenization ซับซ้อนกว่า AF2 มาก
- อินพุตแบ่งเป็น single/pair representation ระดับโทเคน, representation ระดับอะตอม, MSA และเทมเพลต โดยกรดอะมิโนมาตรฐานและนิวคลีโอไทด์มาตรฐานถือเป็น 1 โทเคน ส่วน residue ที่ไม่เป็นมาตรฐานและโมเลกุลอื่น ๆ จัดการเป็น 1 โทเคนต่ออะตอม
- trunk สำหรับเรียนรู้ representation ปรับปรุง single representation s และ pair representation z ซ้ำ ๆ ผ่านโมดูลเทมเพลต, โมดูล MSA และ Pairformer ด้วย pair-bias attention, triangle operation และ recycling
- การทำนายโครงสร้างใช้ conditional diffusion model กับพิกัดอะตอม แทน Invariant Point Attention ของ AF2 และสร้างการอัปเดตพิกัดของทุกอะตอมผ่าน rotation/translation augmentation และ denoising
- การฝึกผสาน distogram, diffusion และ confidence loss และเรียนรู้ซ้ำแม้แต่ representation แบบ unfolded ของบริเวณความเชื่อมั่นต่ำผ่าน cross-distillation ที่ใช้ผลลัพธ์จาก AF2 และ AF-Multimer
ขอบเขตอินพุตและ pipeline โดยรวมของ AlphaFold3
- เป้าหมายของ AlphaFold3 ไม่ได้หยุดอยู่แค่การทำนายลำดับโปรตีนเดี่ยวแบบ AF2 หรือจัดการเฉพาะโปรตีนคอมเพล็กซ์แบบ AF-Multimer แต่คือการ ทำนายจากลำดับเท่านั้น สำหรับโครงสร้างที่โปรตีนจับกับโปรตีนอื่น กรดนิวคลีอิก หรือโมเลกุลขนาดเล็กได้ตามตัวเลือก
- ความหมายของ “โทเคน” เปลี่ยนไปตามชนิดของอินพุต
- โปรตีน: กรดอะมิโนมาตรฐาน 1 ตัวเป็น 1 โทเคน
- DNA/RNA: นิวคลีโอไทด์มาตรฐาน 1 ตัวเป็น 1 โทเคน
- กรดอะมิโน/นิวคลีโอไทด์ที่ไม่เป็นมาตรฐาน: อะตอม 1 อะตอมเป็น 1 โทเคน
- โมเลกุลอื่น ๆ: อะตอม 1 อะตอมเป็น 1 โทเคน
- โปรตีนที่ประกอบด้วยกรดอะมิโนมาตรฐาน 35 ตัวอาจมีอะตอมจริงมากกว่า 600 อะตอม แต่ถูกแทนด้วย 35 โทเคน ขณะที่ ligand ที่มี 35 อะตอมจะถูกแทนด้วย 35 โทเคน
- โมเดลประกอบด้วยสามขั้นตอนหลัก
- Input Preparation: แปลงลำดับที่ผู้ใช้ป้อน รวมถึงลำดับและโครงสร้างที่เกี่ยวข้องซึ่งค้นพบ ให้เป็น tensor เชิงตัวเลข
- Representation Learning: อัปเดต single representation และ pair representation ด้วย attention หลายรูปแบบ
- Structure Prediction: ทำนายโครงสร้างด้วย conditional diffusion
- โปรตีนคอมเพล็กซ์ถูกเก็บไว้เป็นหลักใน representation สองแบบ
- single representation: แทนตัวโทเคนทั้งหมดในคอมเพล็กซ์
- pair representation: แทนความสัมพันธ์ เช่น ระยะทางและปฏิสัมพันธ์แฝง ระหว่างคู่โทเคนทั้งหมด
- มิติช่องหลักคือ
c_z=128,c_m=64,c_atom=128,c_atompair=16,c_token=768,c_s=384
การเตรียมอินพุต: กระบวนการแปลงลำดับเป็น tensor 6 ตัว
- อินพุตที่ผู้ใช้ให้มาจะถูกแปลงเป็น tensor 6 ตัวสำหรับป้อนเข้า model trunk
- s: token-level single representation
- z: token-level pair representation
- q: atom-level single representation
- p: atom-level pair representation
- m: MSA representation
- t: template representation
-
การค้นหา MSA และเทมเพลต
- AF3 ค้นหาลำดับที่คล้ายกันสำหรับลำดับโปรตีนและ RNA แล้วจัดเป็น MSA และรวมโครงสร้างที่เกี่ยวข้องไว้เป็น template
- MSA จัดเรียงลำดับโปรตีนที่คล้ายกันซึ่งพบในหลายสปีชีส์ เพื่อให้โมเดลเห็นรูปแบบการอนุรักษ์ของตำแหน่งหนึ่ง ๆ และสหสัมพันธ์ของการเปลี่ยนแปลงระหว่างตำแหน่งต่าง ๆ
- โครงสร้างที่ทราบของโปรตีนที่คล้ายกันถูกใช้เพื่อประมาณโครงสร้างของโปรตีน query เช่นเดียวกับ homology modeling
- การค้นหาไม่รวมการเรียนรู้ และใช้วิธีที่อิง HMM
- ใช้
jackhmmer,HHBlits,nhmmerเพื่อค้นหาฐานข้อมูลโปรตีนและ RNA หลายชุด และใช้hmmsearchเพื่อค้นหาลำดับที่คล้ายกันใน Protein Data Bank - ขนาด MSA ถูกจำกัดไว้ที่
N_MSA < 2^14เนื่องจากความซับซ้อนในการคำนวณ - ในแต่ละ protein chain จะเลือกโครงสร้างคุณภาพสูง และสุ่มตัวอย่างได้สูงสุด 4 รายการเป็น template
- องค์ประกอบการค้นหาที่เพิ่มใหม่เมื่อเทียบกับ AF-Multimer คือ รวมลำดับ RNA เป็นเป้าหมายการค้นหาด้วย
-
วิธีแทนเทมเพลต
- คำนวณระยะทางแบบ Euclidean ระหว่างคู่โทเคนแต่ละคู่จากโครงสร้าง 3D ของ template
- โทเคนที่มีหลายอะตอมใช้ “center atom” เป็นตัวแทน
- กรดอะมิโน: อะตอม
Cα - นิวคลีโอไทด์มาตรฐาน: อะตอม
C1'
- กรดอะมิโน: อะตอม
- ค่าระยะทางไม่ได้เป็นค่าต่อเนื่อง แต่ถูกทำให้เป็นแบบไม่ต่อเนื่องด้วย distogram
- 38 bin ตั้งแต่ 3.15Å ถึง 50.75Å
- bin เพิ่มเติมอีก 1 bin สำหรับระยะทางที่มากกว่านั้น
- ใน distogram มีการเพิ่มข้อมูล chain, สถานะว่าโทเคนนั้นถูก resolved ใน crystal structure หรือไม่ และข้อมูล local distance ภายในกรดอะมิโนแต่ละตัว
- template matrix ถูก masking ให้ดูเฉพาะระยะทางภายใน chain เดียวกัน และไม่ได้พยายามได้ข้อมูล inter-chain interaction จากการเลือก template
การแทนค่าระดับอะตอมและ Atom Transformer
-
reference conformer และการแทนค่าระดับอะตอม
- เพื่อสร้าง single representation ระดับอะตอม q จะคำนวณ reference conformer สำหรับกรดอะมิโน นิวคลีโอไทด์ และ ligand แต่ละตัว
- conformer คือการจัดเรียงอะตอมแบบ 3D ของโมเลกุลที่สร้างขึ้นโดยการสุ่มตัวอย่างการหมุนรอบพันธะเดี่ยว
- กรดอะมิโนมาตรฐานใช้ conformer พลังงานต่ำที่หาได้จาก lookup ส่วนโมเลกุลขนาดเล็กสร้าง 3D conformer ด้วย RDKit’s ETKDGv3
- นำตำแหน่งสัมพัทธ์ของ conformer, ประจุอะตอม, เลขอะตอม, identifier ฯลฯ มารวมกันเพื่อสร้าง atom-level single representation c
- ใช้ c เพื่อ initialize atom-level pair representation p และใช้ mask v เพื่อให้มีเฉพาะระยะห่างระหว่างอะตอมที่คำนวณจาก reference conformer
- q เริ่มจากสำเนาของ c แล้วจึงถูกอัปเดตใน Atom Transformer
-
บทบาทของ Atom Transformer
- Atom Transformer เป็นโมดูลที่ทำ attention ระดับอะตอม โดยใช้ p และ representation เดิม c เพื่ออัปเดต q
- c จะไม่ถูกอัปเดต และถูกใช้คล้าย residual connection ที่ชี้กลับไปยัง representation ตั้งต้น
- โครงสร้างพื้นฐานคล้าย transformer โดยมี LayerNorm, attention และ MLP transition แต่แต่ละขั้นตอนถูกปรับด้วยอินพุตเพิ่มเติมอย่าง c และ p
-
Adaptive LayerNorm
- Adaptive LayerNorm ไม่ได้เรียนรู้
gamma,betaแบบคงที่ แต่สร้างgamma,betaจากอินพุตเสริม - ใน Atom Transformer สิ่งที่ถูก rescale คือ q และพารามิเตอร์สำหรับ rescale ถูกทำนายจากอินพุตเสริม c
- Adaptive LayerNorm ไม่ได้เรียนรู้
-
Attention with Pair Bias
- Atom-level attention with pair bias เป็นส่วนขยายของ self-attention
- query, key, value ทั้งหมดมาจาก single representation q แต่หลังจากทำ query-key dot product แล้ว จะบวก linear projection ของ pair representation p เข้าไปเป็น bias
- ข้อมูลไหลจาก pair representation ไปยัง q แต่ในขั้นตอนนี้จะไม่ใช้ข้อมูลของ q เพื่ออัปเดต p
- gate ที่สร้างโดยนำ projection เพิ่มเติมผ่าน sigmoid จะถูกคูณกับผลลัพธ์ของ attention เพื่อควบคุมว่าจะเหลือข้อมูลใดไว้ใน residual stream
- จำนวนอะตอมอาจมากกว่าจำนวนโทเคนมาก จึงใช้ Sequence-local atom attention แทน full attention
- local group ขนาด 32 อะตอมสามารถ attend ไปยังอะตอมอื่นได้ 128 อะตอม
-
Conditioned Gating และ Transition
- Conditioned Gating ใช้ gate ที่สร้างจาก atom-level single matrix เดิม c กับข้อมูล
- Conditioned Transition เทียบได้กับ MLP ของ transformer และถูกเรียกว่า conditioned เพราะ Adaptive LayerNorm และ Conditional Gating ขึ้นกับ c
- AF3 ใช้ SwiGLU ใน transition block แทน ReLU
- transition แบบ ReLU ของ AF2 มีโครงสร้าง up-projection 4 เท่า, ReLU, down-projection
- SwiGLU ของ AF3 จะใช้ nonlinearity แบบ swish กับหนึ่งใน up-projection สองตัว จากนั้นนำมาคูณกันแล้ว down-project
การรวมการแทนค่าอะตอมเป็นการแทนค่าโทเคน
- ขั้นตอน representation learning หลังจากนี้ทำงานในระดับ token-level จึงรวม atom-level representation ให้เป็น token-level representation
- หลังจาก projection atom-level representation ไปยังมิติที่ใหญ่ขึ้นแล้ว จะหาค่าเฉลี่ยของอะตอมที่อยู่ในโทเคนเดียวกัน
- การรวมแบบค่าเฉลี่ยนี้ใช้เมื่ออะตอมหลายตัวเชื่อมกับโทเคนเดียว เช่น กรดอะมิโนและนิวคลีโอไทด์มาตรฐาน ส่วนอินพุตที่เป็น 1 โทเคนต่อ 1 อะตอมจะคงไว้ตามเดิม
- token-level single input ยังรวมสถิติที่ได้จาก MSA ด้วย
- ชนิดกรดอะมิโน
- การกระจายของกรดอะมิโนใน MSA ณ ตำแหน่งนั้น
- deletion mean ของโทเคนนั้น
- สำหรับโทเคนที่ไม่มี MSA เช่น อะตอมของ ligand ค่าเหล่านี้จะเป็น 0
- s_inputs ที่สร้างขึ้นเช่นนี้จะผ่าน projection กลายเป็น s_init และถูกอัปเดตในขั้นตอน representation learning
- pair representation z_init เป็นเทนเซอร์ 3 มิติที่เก็บความสัมพันธ์ของ token pair แต่ละคู่ โดย z_i,j แต่ละตัวเป็นเวกเตอร์มิติ
c_z=128 - ในการ initialize z_i,j จะบวก projection ของ s_i, s_j, relative positional encoding และข้อมูล bond ระหว่างโทเคนที่ผู้ใช้ระบุ
การเรียนรู้ representation: Template, MSA, Pairformer
- representation learning คือ trunk ซึ่งกินการคำนวณส่วนใหญ่ของโมเดล และมีเป้าหมายเพื่อปรับปรุง token-level single representation s และ pair representation z
- single sequence representation ไม่ได้หมายถึงลำดับโปรตีนเดียวเท่านั้น แต่หมายถึง sequence ที่นำอะตอมหรือโทเคนทั้งหมดในโครงสร้างมาต่อกัน
-
Template Module
- template แต่ละตัวผ่าน linear projection แล้วนำไปบวกกับ linear projection ของ pair representation z
- matrix ที่รวมแล้วจะผ่าน Pairformer Stack
- ผลลัพธ์จาก template หลายตัวจะถูกนำมาเฉลี่ย แล้วผ่าน linear layer อีกครั้ง
- linear layer สุดท้ายใช้ ReLU และเป็นหนึ่งในไม่กี่ตำแหน่งใน AF3 ที่ใช้ ReLU เป็น nonlinearity
-
MSA Module
- MSA Module คล้ายกับ Evoformer ของ AF2 มาก และปรับปรุง MSA representation m กับ pair representation z พร้อมกัน
- ไม่ได้ใช้ MSA row ทั้งหมด แต่ทำ subsampling แล้วบวก projection ของ single representation เข้าไปใน MSA
- Outer Product Mean เป็นการดำเนินการที่ใส่ข้อมูล MSA เข้าไปใน pair representation
- สำหรับ token index
i,jแต่ละคู่ จะคำนวณ outer product ของ m_s,i และ m_s,j สำหรับ evolutionary sequence ทั้งหมด - นำไปเฉลี่ยตลอด sequence ทั้งหมด flatten แล้ว projection จากนั้นบวกเข้าไปใน z_i,j
- เป็นจุดเดียวในโมเดลที่ข้อมูลถูกแชร์กันระหว่าง evolutionary sequence
- สำหรับ token index
- Row-wise gated self-attention using only pair bias ใช้ pair representation เพื่ออัปเดต MSA
- แทนที่จะสร้าง attention score ด้วย query และ key จะ projection pair representation z เป็น matrix แล้วใช้เป็น attention score ระหว่างโทเคน
- เนื่องจากใช้แยกกันกับ MSA row แต่ละแถว ในขั้นตอนนี้จึงไม่มีการแชร์ข้อมูลระหว่าง evolutionary sequence
- ตอนท้ายของ MSA module จะอัปเดต pair representation อีกครั้งด้วย triangle update และ triangle attention
Pairformer และการดำเนินการแบบ triangle
- หลังจากอัปเดต z ด้วย Template และ MSA แล้ว จะไม่ใช้ template และ MSA อีกต่อไป โดยจะป้อนเฉพาะ s และ z เข้าไปยัง Pairformer
- Pairformer สร้าง s_trunk และ z_trunk ขั้นสุดท้ายผ่านการทำซ้ำ block 48 ชุด
-
สัญชาตญาณของการดำเนินการแบบ triangle
- triangle update และ triangle attention เป็นโครงสร้างที่พยายามสะท้อนสัญชาตญาณของ อสมการสามเหลี่ยม เข้าไปในโมเดล
- แม้ z_i,j ของ pair tensor จะไม่ใช่ระยะทางทางกายภาพโดยตรง แต่เนื่องจากมันบรรจุความสัมพันธ์ระหว่าง token
iและjจึงอัปเดตให้ความสัมพันธ์ทั้งสามของi-j,j-k,i-kสอดคล้องกัน - อสมการสามเหลี่ยมไม่ได้ถูกบังคับโดยตรงภายในโมเดล แต่ถูกชักนำด้วยวิธีดู triplet
(i,j,k)ทั้งหมดแล้วอัปเดต z_i,j - z สามารถมองได้คล้าย directed adjacency matrix จึงแยกประมวลผลทิศทางของ outgoing edge และ incoming edge
-
Triangle Updates
- ใน outgoing update จะอัปเดต z_i,j แต่ละตัวโดยใช้องค์ประกอบอื่นใน row เดียวกันคือ z_i,k และ edge ที่สาม z_j,k
- ในเชิง implementation จะสร้าง projection สามตัวของ z ได้แก่
a,b,gจากนั้นนำ rowiและ rowjมาทำ element-wise multiplication แล้วรวมตามkก่อนใช้ gateg - incoming update เป็นรูปแบบที่สลับ row กับ column โดยอัปเดต z_i,j ผ่านองค์ประกอบอื่นใน column เดียวกันคือ z_k,j และ z_k,i
-
Triangle Attention
- triangle attention เป็นรูปแบบที่เพิ่มหลักการ triangle เข้าไปใน axial attention ซึ่งใช้ attention แยกกันบน row และ column ของ 2D matrix
- ในกรณี “starting node” จะเพิ่ม z_j,k เป็น bias ให้กับการเปรียบเทียบ query-key ของ z_i,j และ z_i,k
- ในกรณี “ending node” จะทำงานตาม column และใช้ z_k,j เป็น bias ให้กับ attention score ของ z_i,j และ z_k,i
-
Single Attention with Pair Bias
- หลังจาก triangle step และ transition block แล้ว single representation s จะถูกอัปเดตด้วย single attention with pair bias ที่ใช้ updated pair representation z
- เนื่องจากทำงานในระดับ token จึงใช้ full attention ไม่ใช่ block-wise sparse attention ที่ใช้ในระดับ atom
การทำนายโครงสร้าง: denoising พิกัดอะตอมด้วย diffusion
-
วิธีพื้นฐานของโมเดล diffusion
- AF3 ทำการทำนายโครงสร้างขั้นสุดท้ายด้วย atom-level diffusion
- diffusion model เรียนรู้โดยเพิ่ม random noise เข้าไปในข้อมูลจริงทีละขั้น และให้โมเดลทำนายว่า noise แบบใดถูกเพิ่มเข้าไป
- ในช่วง inference จะเริ่มจาก random noise อย่างสมบูรณ์ แล้วลบ noise ที่โมเดลทำนายในแต่ละ step เพื่อสร้าง datapoint ที่ผ่านการ denoise แล้ว
- conditional diffusion รับ current noisy generation, การแทนค่า timestep ปัจจุบัน และ condition vector เป็นอินพุต เพื่อสร้างผลลัพธ์ที่ตรงตามเงื่อนไข
- ใน AF3 เป้าหมายของ denoising คือ matrix x ที่บรรจุพิกัด
x,y,zของอะตอมทั้งหมด
-
การเสริมข้อมูลด้วยการหมุนและเลื่อนแทน IPA ของ AF2
- AF3 ไม่ใช้ Invariant Point Attention ของ AF2 แต่จะสุ่มหมุนและเลื่อนคอมเพล็กซ์ทั้งหมดที่กำลังทำนายในแต่ละ timestep
- การเสริมข้อมูลนี้ทำให้โมเดลเรียนรู้ว่าไม่ว่าจะหมุนหรือเลื่อนแบบใดก็ยังเป็นโครงสร้างเดียวกันที่ใช้ได้ และเป็นแนวทางที่ง่ายกว่า IPA ของ AF2
- การหมุนจะถูกใช้โดยมีค่าเฉลี่ยของพิกัดอะตอมทั้งหมดใน generation ปัจจุบันเป็นศูนย์กลาง และ translation จะถูกสุ่มจาก Gaussian
N(0,1)ในแต่ละมิติ - ยังมีการเพิ่ม noise เล็กน้อยให้กับพิกัด เพื่อชักนำให้เกิด generation ที่หลากหลายขึ้น
- ในช่วง inference สามารถให้คะแนน generation หลายชุดด้วย confidence head แล้วส่งคืน generation ที่ได้คะแนนสูงสุดได้
-
สี่ขั้นตอนของ Diffusion Module
- แต่ละ denoising step ใช้ conditioning representation หลายชุด
- เอาต์พุต trunk s_trunk, z_trunk
- representation เริ่มต้น s_inputs, c_inputs ที่สร้างจาก input embedder
- กระบวนการ diffusion ประกอบด้วยสี่ขั้นตอนที่สลับไปมาระหว่างพื้นที่ token และ atom
-
- เตรียม token-level conditioning tensor
-
- เตรียม atom-level conditioning tensor, ใช้ Atom Transformer แล้วรวมกลับเป็น token-level
-
- ใช้ token-level attention
-
- ทำนาย noise update รายอะตอมด้วย atom-level attention
-
- ใน token-level conditioning จะรวม z_trunk กับ relative positional encoding แล้วส่งผ่าน transition block
- ใน single representation จะรวม s_inputs กับ s_trunk และเพิ่ม Fourier embedding ตาม diffusion timestep
- ในขั้น atom-level จะอัปเดต c, p เริ่มต้นด้วย token-level representation ปัจจุบัน และสเกลพิกัดปัจจุบัน x ด้วย data variance เพื่อสร้าง dimensionless coordinate r
- ในขั้น atom-level สุดท้าย linear layer จะ map q ไปยัง
R^3เพื่อสร้าง coordinate update r_update ของอะตอมทั้งหมด - update จะถูก rescale เป็น x_update โดยคำนึงถึง data variance และ noise schedule แล้วนำไปใช้กับพิกัดปัจจุบัน x_l
- แต่ละ denoising step ใช้ conditioning representation หลายชุด
ฟังก์ชัน loss และ confidence head
- loss ทั้งหมดคือผลรวมถ่วงน้ำหนักของสามพจน์
L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence
-
L_distogram
- L_distogram ประเมินความแม่นยำของ distogram ที่คาดการณ์ในระดับ token
- เมื่อสร้างพิกัด token จากพิกัดอะตอม จะใช้พิกัดของ center atom ของแต่ละ token
- ระยะทางใน distogram ถูกจัดการเป็น categorical value และเปรียบเทียบ distogram ที่คาดการณ์กับ distogram จริงด้วย cross entropy
-
L_diffusion
- L_diffusion เป็นผลรวมถ่วงน้ำหนักของหลายพจน์ที่ใช้กับตำแหน่ง atom
- L_MSE คำนวณ mean squared error ระหว่าง position สำหรับทุกอะตอม ไม่ใช่เฉพาะ center atom และอะตอมของ DNA, RNA, ligand จะถูก upweight
- L_bond เป็นพจน์ MSE เพิ่มเติมเพื่อเพิ่มความแม่นยำของ bond length ของ atom pair ที่อยู่ใน protein-ligand bond
- ในช่วง training stage แรก
α_bond=0จึงถูกนำมาใช้ในภายหลัง - L_smooth_LDDT เป็น loss ที่ทำให้ local distance accuracy นุ่มนวลและหาอนุพันธ์ได้
- ใช้ threshold สี่ค่า ได้แก่ 4Å, 2Å, 1Å, 0.5Å
- atom pair ของ nucleotide จะถูกละเว้นหากอยู่ไกลกว่า 30Å
- atom pair ของ protein หรือ ligand จะถูกละเว้นหากอยู่ไกลกว่า 15Å
-
L_confidence
- L_confidence ฝึกให้โมเดลประมาณความแม่นยำของการคาดการณ์ของตัวเอง แทนที่จะเพิ่มความแม่นยำของโครงสร้างโดยตรง
- ประกอบด้วย loss ที่สอดคล้องกับ confidence metric สี่ประเภท
- pLDDT: local distance accuracy สำหรับอะตอมที่อยู่ใกล้กัน
- PAE: predicted alignment error ของ token pair
- PDE: predicted distance error ระหว่าง token pair
- experimentally resolved prediction: คาดการณ์ว่าแต่ละอะตอมถูก resolved ในโครงสร้างจากการทดลองหรือไม่
- แม้โครงสร้างที่คาดการณ์จะไม่แม่นยำจนทำให้ PAE สูง แต่หากโมเดลคาดการณ์ว่า PAE สูงเช่นกัน PAE loss ดังกล่าวก็อาจต่ำได้
- confidence prediction ถูกสร้างขึ้นในขั้นกลางของ diffusion
- gradient ของ confidence loss จะอัปเดตเฉพาะ confidence prediction head และไม่ส่งผลต่อส่วนอื่นของโมเดล
เทคนิคการเรียนรู้เพิ่มเติมและการเพิ่มประสิทธิภาพ
-
Recycling
- AF3 ใช้ weight recycling เช่นเดียวกับ AF2
- แทนที่จะทำให้โมเดลลึกขึ้น จะนำ weight เดิมกลับมาใช้ซ้ำหลายครั้งเพื่อปรับปรุง representation ทีละขั้น
- diffusion ก็ใช้ข้อมูล timestep ใน inference และนำ weight เดิมกลับมาใช้ซ้ำในทุก timestep จึงมี recycling อยู่ในตัว
-
Cross-distillation
- AF3 ใช้ทั้ง synthetic training data ที่สร้างขึ้นเอง และ synthetic data ที่สร้างโดย AF2 กับ AF-Multimer
- หลังเปลี่ยนมาใช้ generation แบบ diffusion-based เกิดปัญหาว่ารูปแบบ “spaghetti” ที่เคยช่วยแยกแยะบริเวณความเชื่อมั่นต่ำและไร้ระเบียบใน AF2 ทางภาพหายไป
- โดยรวม generation ของ AF2 และ AF-Multimer เข้าใน training data ของ AF3 ทำให้ AF3 เรียนรู้วิธีที่ AF2 แสดง unfolded region ในบริเวณที่ไม่มั่นใจ
- ใน distillation dataset จะลบกรดนิวคลีอิกและโมเลกุลขนาดเล็กที่ AF2 และ AF-Multimer ไม่สามารถจัดการได้ออก
- หลังจากโมเดลก่อนหน้าสร้างโครงสร้างที่คาดการณ์แล้ว เมื่อนำไป alignment กับต้นฉบับ จะเพิ่มโมเลกุลที่เคยลบออกกลับเข้ามา
- หากโมเลกุลที่เพิ่มกลับเข้ามาทำให้เกิด atom clash จะตัดโครงสร้างทั้งหมดออก เพื่อหลีกเลี่ยงการฝึกให้โมเดลยอมรับ clash
-
Cropping และ training stage
- ตัวโมเดลเองไม่มีข้อจำกัดแบบชัดเจนต่อความยาวลำดับอินพุต แต่หลายการคำนวณเพิ่มขึ้นตาม
N_tokens^3ทำให้ความต้องการ memory และ compute สูงขึ้น - เพื่อประสิทธิภาพ จึงทำ random crop กับโปรตีน
- เนื่องจากต้องโมเดล interaction ระหว่างหลาย chain crop จึงต้องรวม chain หลายเส้นไว้ด้วยกัน
- ใช้วิธี cropping สามแบบ
- contiguous cropping: เลือก sequence กรดอะมิโนที่ต่อเนื่องกันจากแต่ละ chain
- spatial cropping: เลือกกรดอะมิโนโดยอิงจากระยะถึงอะตอมอ้างอิง
- spatial interface cropping: เลือกโดยอิงจากระยะถึงอะตอมของ binding interface
- โมเดลที่ฝึกด้วย random crop 384 ก็สามารถใช้กับ sequence ที่ยาวกว่าได้ แต่เพื่อเพิ่มความสามารถในการจัดการ sequence ที่ยาวขึ้น จึงทำ fine-tuning ซ้ำด้วย sequence length ที่ใหญ่ขึ้น
- ตัวโมเดลเองไม่มีข้อจำกัดแบบชัดเจนต่อความยาวลำดับอินพุต แต่หลายการคำนวณเพิ่มขึ้นตาม
-
Clashing และ batch size
- loss ของ AF3 ไม่มี clash penalty สำหรับอะตอมที่ overlap กัน
- diffusion-based structure module ในทางทฤษฎีสามารถคาดการณ์ให้อะตอมสองตัวอยู่ตำแหน่งเดียวกันได้ แต่หลังการฝึก ปัญหานี้มีน้อย
- ในการ ranking โครงสร้างที่สร้างขึ้น จะใช้ clashing penalty
- diffusion process ดูซับซ้อน แต่มีต้นทุนการคำนวณต่ำกว่า trunk
- เพื่อเพิ่มประสิทธิภาพการฝึก จะขยาย batch size หลังผ่าน trunk
- แต่ละ input structure จะผ่าน embedding และ trunk หนึ่งครั้ง จากนั้น data-augmented structure อิสระ 48 รายการจะถูกฝึกแบบขนาน
การออกแบบ AF3 จากมุมมอง ML
-
โครงสร้างที่คล้ายกับ Retrieval-Augmented Generation
- การค้นหา MSA และ template ของ AF3 มีลักษณะคล้ายกับ RAG ของโมเดลภาษา
- ในวงการ AlphaFold วิธีใช้ structure template ถูกเรียกว่า homology modeling มานานก่อนคำว่า RAG
- AF3 ลดสัดส่วนการประมวลผล MSA ลงเมื่อเทียบกับ AF2 แต่ยังคงรวม MSA และ template ไว้
- โมเดลทำนายโปรตีนบางตัว เช่น ESMFold ตัด retrieval ออกและใช้ fully parametric inference
-
Pair-Bias Attention
- Pair-Bias Attention ซึ่งเป็นองค์ประกอบหลักของ AF2 ถูกนำมาใช้กว้างขึ้นใน AF3
- query, key และ value มาจาก source เดียวกัน แต่ใน attention map จะมีการเพิ่ม bias term ที่มาจาก source อื่น
- นี่เป็นวิธีแชร์ข้อมูลที่เบากว่า full cross-attention
- เนื่องจาก pair representation มีความคล้ายกับ attention map โดยธรรมชาติ โครงสร้างนี้จึงอาจเหมาะกับการทำโมเดลโปรตีน
-
การลดบทบาทของ self-supervised training
- โมเดลตระกูล ESM แสดงจุดแข็งในแนวทางที่ใช้ self-supervised pre-training เพื่อทดแทน MSA embedding
- ใน AF2 มี task เพิ่มเติมสำหรับทำนาย masked token ของ MSA แต่ใน AF3 ถูกนำออกไป
- AF3 ลด compute สำหรับการประมวลผล MSA และไม่ใช้ self-supervised language modeling pre-training สำหรับ MSA
- เหตุผลที่เป็นไปได้คือ massive pre-training อาจไม่มีประสิทธิภาพในแง่การใช้ compute, หรือ MSA module ขนาดเล็กอาจดีกว่า pre-trained embedding, หรือโครงสร้าง hybrid atom-token ที่ผสมกรดอะมิโน·DNA/RNA·ligand อาจไม่เข้ากับการผสานกับ pre-trained embedding
-
การผสมระหว่าง Classification และ Regression
- AF3 ใช้ MSE ร่วมกับ binned classification loss เหมือน AF2
- จุดเด่นของ classification loss คือแม้จะทาย distogram bin ผิดไปเพียงหนึ่ง bin ก็จะไม่ได้ credit เช่นเดียวกับกรณีที่ทายผิดไปไกล
- เหตุผลของการเลือกออกแบบเช่นนี้ยังไม่ชัดเจน แต่เป็นไปได้ว่า gradient มีเสถียรภาพมากกว่า MSE loss หลายตัว
-
องค์ประกอบที่คล้าย recurrent architecture
- AF3 มีองค์ประกอบหลายอย่างที่ทำให้นึกถึง recurrent network มากกว่า transformer ทั่วไป
- gating ควบคุมการไหลของข้อมูลใน residual stream และคล้ายกับ gate ของ LSTM หรือ GRU
- recycling และ diffusion ใช้ weight เดิมซ้ำ ๆ เพื่อค่อย ๆ ปรับปรุงการทำนาย
- คล้ายกับ adaptive compute time การอัปเดตซ้ำ ๆ เกี่ยวข้องกับโครงสร้างที่สามารถใช้การประมวลผลมากขึ้นกับอินพุตที่ยากกว่า
- ใน ablation ของ AF2 พบว่า recycling มีความสำคัญ แต่มีการอภิปรายเกี่ยวกับความสำคัญของ gating ไม่มากนัก
ยังไม่มีความคิดเห็น