• AlphaFold3 มีเป้าหมายที่จะ ทำนายจากลำดับเท่านั้น สำหรับคอมเพล็กซ์ที่มีโปรตีน กรดนิวคลีอิก และโมเลกุลขนาดเล็กอยู่ร่วมกัน โดยไปไกลกว่าการทำนายโปรตีนเดี่ยว ทำให้การแทนอินพุตและการทำ tokenization ซับซ้อนกว่า AF2 มาก
  • อินพุตแบ่งเป็น single/pair representation ระดับโทเคน, representation ระดับอะตอม, MSA และเทมเพลต โดยกรดอะมิโนมาตรฐานและนิวคลีโอไทด์มาตรฐานถือเป็น 1 โทเคน ส่วน residue ที่ไม่เป็นมาตรฐานและโมเลกุลอื่น ๆ จัดการเป็น 1 โทเคนต่ออะตอม
  • trunk สำหรับเรียนรู้ representation ปรับปรุง single representation s และ pair representation z ซ้ำ ๆ ผ่านโมดูลเทมเพลต, โมดูล MSA และ Pairformer ด้วย pair-bias attention, triangle operation และ recycling
  • การทำนายโครงสร้างใช้ conditional diffusion model กับพิกัดอะตอม แทน Invariant Point Attention ของ AF2 และสร้างการอัปเดตพิกัดของทุกอะตอมผ่าน rotation/translation augmentation และ denoising
  • การฝึกผสาน distogram, diffusion และ confidence loss และเรียนรู้ซ้ำแม้แต่ representation แบบ unfolded ของบริเวณความเชื่อมั่นต่ำผ่าน cross-distillation ที่ใช้ผลลัพธ์จาก AF2 และ AF-Multimer

ขอบเขตอินพุตและ pipeline โดยรวมของ AlphaFold3

  • เป้าหมายของ AlphaFold3 ไม่ได้หยุดอยู่แค่การทำนายลำดับโปรตีนเดี่ยวแบบ AF2 หรือจัดการเฉพาะโปรตีนคอมเพล็กซ์แบบ AF-Multimer แต่คือการ ทำนายจากลำดับเท่านั้น สำหรับโครงสร้างที่โปรตีนจับกับโปรตีนอื่น กรดนิวคลีอิก หรือโมเลกุลขนาดเล็กได้ตามตัวเลือก
  • ความหมายของ “โทเคน” เปลี่ยนไปตามชนิดของอินพุต
    • โปรตีน: กรดอะมิโนมาตรฐาน 1 ตัวเป็น 1 โทเคน
    • DNA/RNA: นิวคลีโอไทด์มาตรฐาน 1 ตัวเป็น 1 โทเคน
    • กรดอะมิโน/นิวคลีโอไทด์ที่ไม่เป็นมาตรฐาน: อะตอม 1 อะตอมเป็น 1 โทเคน
    • โมเลกุลอื่น ๆ: อะตอม 1 อะตอมเป็น 1 โทเคน
  • โปรตีนที่ประกอบด้วยกรดอะมิโนมาตรฐาน 35 ตัวอาจมีอะตอมจริงมากกว่า 600 อะตอม แต่ถูกแทนด้วย 35 โทเคน ขณะที่ ligand ที่มี 35 อะตอมจะถูกแทนด้วย 35 โทเคน
  • โมเดลประกอบด้วยสามขั้นตอนหลัก
    • Input Preparation: แปลงลำดับที่ผู้ใช้ป้อน รวมถึงลำดับและโครงสร้างที่เกี่ยวข้องซึ่งค้นพบ ให้เป็น tensor เชิงตัวเลข
    • Representation Learning: อัปเดต single representation และ pair representation ด้วย attention หลายรูปแบบ
    • Structure Prediction: ทำนายโครงสร้างด้วย conditional diffusion
  • โปรตีนคอมเพล็กซ์ถูกเก็บไว้เป็นหลักใน representation สองแบบ
    • single representation: แทนตัวโทเคนทั้งหมดในคอมเพล็กซ์
    • pair representation: แทนความสัมพันธ์ เช่น ระยะทางและปฏิสัมพันธ์แฝง ระหว่างคู่โทเคนทั้งหมด
  • มิติช่องหลักคือ c_z=128, c_m=64, c_atom=128, c_atompair=16, c_token=768, c_s=384

การเตรียมอินพุต: กระบวนการแปลงลำดับเป็น tensor 6 ตัว

  • อินพุตที่ผู้ใช้ให้มาจะถูกแปลงเป็น tensor 6 ตัวสำหรับป้อนเข้า model trunk
    • s: token-level single representation
    • z: token-level pair representation
    • q: atom-level single representation
    • p: atom-level pair representation
    • m: MSA representation
    • t: template representation
  • การค้นหา MSA และเทมเพลต

    • AF3 ค้นหาลำดับที่คล้ายกันสำหรับลำดับโปรตีนและ RNA แล้วจัดเป็น MSA และรวมโครงสร้างที่เกี่ยวข้องไว้เป็น template
    • MSA จัดเรียงลำดับโปรตีนที่คล้ายกันซึ่งพบในหลายสปีชีส์ เพื่อให้โมเดลเห็นรูปแบบการอนุรักษ์ของตำแหน่งหนึ่ง ๆ และสหสัมพันธ์ของการเปลี่ยนแปลงระหว่างตำแหน่งต่าง ๆ
    • โครงสร้างที่ทราบของโปรตีนที่คล้ายกันถูกใช้เพื่อประมาณโครงสร้างของโปรตีน query เช่นเดียวกับ homology modeling
    • การค้นหาไม่รวมการเรียนรู้ และใช้วิธีที่อิง HMM
    • ใช้ jackhmmer, HHBlits, nhmmer เพื่อค้นหาฐานข้อมูลโปรตีนและ RNA หลายชุด และใช้ hmmsearch เพื่อค้นหาลำดับที่คล้ายกันใน Protein Data Bank
    • ขนาด MSA ถูกจำกัดไว้ที่ N_MSA < 2^14 เนื่องจากความซับซ้อนในการคำนวณ
    • ในแต่ละ protein chain จะเลือกโครงสร้างคุณภาพสูง และสุ่มตัวอย่างได้สูงสุด 4 รายการเป็น template
    • องค์ประกอบการค้นหาที่เพิ่มใหม่เมื่อเทียบกับ AF-Multimer คือ รวมลำดับ RNA เป็นเป้าหมายการค้นหาด้วย
  • วิธีแทนเทมเพลต

    • คำนวณระยะทางแบบ Euclidean ระหว่างคู่โทเคนแต่ละคู่จากโครงสร้าง 3D ของ template
    • โทเคนที่มีหลายอะตอมใช้ “center atom” เป็นตัวแทน
      • กรดอะมิโน: อะตอม
      • นิวคลีโอไทด์มาตรฐาน: อะตอม C1'
    • ค่าระยะทางไม่ได้เป็นค่าต่อเนื่อง แต่ถูกทำให้เป็นแบบไม่ต่อเนื่องด้วย distogram
      • 38 bin ตั้งแต่ 3.15Å ถึง 50.75Å
      • bin เพิ่มเติมอีก 1 bin สำหรับระยะทางที่มากกว่านั้น
    • ใน distogram มีการเพิ่มข้อมูล chain, สถานะว่าโทเคนนั้นถูก resolved ใน crystal structure หรือไม่ และข้อมูล local distance ภายในกรดอะมิโนแต่ละตัว
    • template matrix ถูก masking ให้ดูเฉพาะระยะทางภายใน chain เดียวกัน และไม่ได้พยายามได้ข้อมูล inter-chain interaction จากการเลือก template

การแทนค่าระดับอะตอมและ Atom Transformer

  • reference conformer และการแทนค่าระดับอะตอม

    • เพื่อสร้าง single representation ระดับอะตอม q จะคำนวณ reference conformer สำหรับกรดอะมิโน นิวคลีโอไทด์ และ ligand แต่ละตัว
    • conformer คือการจัดเรียงอะตอมแบบ 3D ของโมเลกุลที่สร้างขึ้นโดยการสุ่มตัวอย่างการหมุนรอบพันธะเดี่ยว
    • กรดอะมิโนมาตรฐานใช้ conformer พลังงานต่ำที่หาได้จาก lookup ส่วนโมเลกุลขนาดเล็กสร้าง 3D conformer ด้วย RDKit’s ETKDGv3
    • นำตำแหน่งสัมพัทธ์ของ conformer, ประจุอะตอม, เลขอะตอม, identifier ฯลฯ มารวมกันเพื่อสร้าง atom-level single representation c
    • ใช้ c เพื่อ initialize atom-level pair representation p และใช้ mask v เพื่อให้มีเฉพาะระยะห่างระหว่างอะตอมที่คำนวณจาก reference conformer
    • q เริ่มจากสำเนาของ c แล้วจึงถูกอัปเดตใน Atom Transformer
  • บทบาทของ Atom Transformer

    • Atom Transformer เป็นโมดูลที่ทำ attention ระดับอะตอม โดยใช้ p และ representation เดิม c เพื่ออัปเดต q
    • c จะไม่ถูกอัปเดต และถูกใช้คล้าย residual connection ที่ชี้กลับไปยัง representation ตั้งต้น
    • โครงสร้างพื้นฐานคล้าย transformer โดยมี LayerNorm, attention และ MLP transition แต่แต่ละขั้นตอนถูกปรับด้วยอินพุตเพิ่มเติมอย่าง c และ p
  • Adaptive LayerNorm

    • Adaptive LayerNorm ไม่ได้เรียนรู้ gamma, beta แบบคงที่ แต่สร้าง gamma, beta จากอินพุตเสริม
    • ใน Atom Transformer สิ่งที่ถูก rescale คือ q และพารามิเตอร์สำหรับ rescale ถูกทำนายจากอินพุตเสริม c
  • Attention with Pair Bias

    • Atom-level attention with pair bias เป็นส่วนขยายของ self-attention
    • query, key, value ทั้งหมดมาจาก single representation q แต่หลังจากทำ query-key dot product แล้ว จะบวก linear projection ของ pair representation p เข้าไปเป็น bias
    • ข้อมูลไหลจาก pair representation ไปยัง q แต่ในขั้นตอนนี้จะไม่ใช้ข้อมูลของ q เพื่ออัปเดต p
    • gate ที่สร้างโดยนำ projection เพิ่มเติมผ่าน sigmoid จะถูกคูณกับผลลัพธ์ของ attention เพื่อควบคุมว่าจะเหลือข้อมูลใดไว้ใน residual stream
    • จำนวนอะตอมอาจมากกว่าจำนวนโทเคนมาก จึงใช้ Sequence-local atom attention แทน full attention
    • local group ขนาด 32 อะตอมสามารถ attend ไปยังอะตอมอื่นได้ 128 อะตอม
  • Conditioned Gating และ Transition

    • Conditioned Gating ใช้ gate ที่สร้างจาก atom-level single matrix เดิม c กับข้อมูล
    • Conditioned Transition เทียบได้กับ MLP ของ transformer และถูกเรียกว่า conditioned เพราะ Adaptive LayerNorm และ Conditional Gating ขึ้นกับ c
    • AF3 ใช้ SwiGLU ใน transition block แทน ReLU
    • transition แบบ ReLU ของ AF2 มีโครงสร้าง up-projection 4 เท่า, ReLU, down-projection
    • SwiGLU ของ AF3 จะใช้ nonlinearity แบบ swish กับหนึ่งใน up-projection สองตัว จากนั้นนำมาคูณกันแล้ว down-project

การรวมการแทนค่าอะตอมเป็นการแทนค่าโทเคน

  • ขั้นตอน representation learning หลังจากนี้ทำงานในระดับ token-level จึงรวม atom-level representation ให้เป็น token-level representation
  • หลังจาก projection atom-level representation ไปยังมิติที่ใหญ่ขึ้นแล้ว จะหาค่าเฉลี่ยของอะตอมที่อยู่ในโทเคนเดียวกัน
  • การรวมแบบค่าเฉลี่ยนี้ใช้เมื่ออะตอมหลายตัวเชื่อมกับโทเคนเดียว เช่น กรดอะมิโนและนิวคลีโอไทด์มาตรฐาน ส่วนอินพุตที่เป็น 1 โทเคนต่อ 1 อะตอมจะคงไว้ตามเดิม
  • token-level single input ยังรวมสถิติที่ได้จาก MSA ด้วย
    • ชนิดกรดอะมิโน
    • การกระจายของกรดอะมิโนใน MSA ณ ตำแหน่งนั้น
    • deletion mean ของโทเคนนั้น
  • สำหรับโทเคนที่ไม่มี MSA เช่น อะตอมของ ligand ค่าเหล่านี้จะเป็น 0
  • s_inputs ที่สร้างขึ้นเช่นนี้จะผ่าน projection กลายเป็น s_init และถูกอัปเดตในขั้นตอน representation learning
  • pair representation z_init เป็นเทนเซอร์ 3 มิติที่เก็บความสัมพันธ์ของ token pair แต่ละคู่ โดย z_i,j แต่ละตัวเป็นเวกเตอร์มิติ c_z=128
  • ในการ initialize z_i,j จะบวก projection ของ s_i, s_j, relative positional encoding และข้อมูล bond ระหว่างโทเคนที่ผู้ใช้ระบุ

การเรียนรู้ representation: Template, MSA, Pairformer

  • representation learning คือ trunk ซึ่งกินการคำนวณส่วนใหญ่ของโมเดล และมีเป้าหมายเพื่อปรับปรุง token-level single representation s และ pair representation z
  • single sequence representation ไม่ได้หมายถึงลำดับโปรตีนเดียวเท่านั้น แต่หมายถึง sequence ที่นำอะตอมหรือโทเคนทั้งหมดในโครงสร้างมาต่อกัน
  • Template Module

    • template แต่ละตัวผ่าน linear projection แล้วนำไปบวกกับ linear projection ของ pair representation z
    • matrix ที่รวมแล้วจะผ่าน Pairformer Stack
    • ผลลัพธ์จาก template หลายตัวจะถูกนำมาเฉลี่ย แล้วผ่าน linear layer อีกครั้ง
    • linear layer สุดท้ายใช้ ReLU และเป็นหนึ่งในไม่กี่ตำแหน่งใน AF3 ที่ใช้ ReLU เป็น nonlinearity
  • MSA Module

    • MSA Module คล้ายกับ Evoformer ของ AF2 มาก และปรับปรุง MSA representation m กับ pair representation z พร้อมกัน
    • ไม่ได้ใช้ MSA row ทั้งหมด แต่ทำ subsampling แล้วบวก projection ของ single representation เข้าไปใน MSA
    • Outer Product Mean เป็นการดำเนินการที่ใส่ข้อมูล MSA เข้าไปใน pair representation
      • สำหรับ token index i,j แต่ละคู่ จะคำนวณ outer product ของ m_s,i และ m_s,j สำหรับ evolutionary sequence ทั้งหมด
      • นำไปเฉลี่ยตลอด sequence ทั้งหมด flatten แล้ว projection จากนั้นบวกเข้าไปใน z_i,j
      • เป็นจุดเดียวในโมเดลที่ข้อมูลถูกแชร์กันระหว่าง evolutionary sequence
    • Row-wise gated self-attention using only pair bias ใช้ pair representation เพื่ออัปเดต MSA
      • แทนที่จะสร้าง attention score ด้วย query และ key จะ projection pair representation z เป็น matrix แล้วใช้เป็น attention score ระหว่างโทเคน
      • เนื่องจากใช้แยกกันกับ MSA row แต่ละแถว ในขั้นตอนนี้จึงไม่มีการแชร์ข้อมูลระหว่าง evolutionary sequence
    • ตอนท้ายของ MSA module จะอัปเดต pair representation อีกครั้งด้วย triangle update และ triangle attention

Pairformer และการดำเนินการแบบ triangle

  • หลังจากอัปเดต z ด้วย Template และ MSA แล้ว จะไม่ใช้ template และ MSA อีกต่อไป โดยจะป้อนเฉพาะ s และ z เข้าไปยัง Pairformer
  • Pairformer สร้าง s_trunk และ z_trunk ขั้นสุดท้ายผ่านการทำซ้ำ block 48 ชุด
  • สัญชาตญาณของการดำเนินการแบบ triangle

    • triangle update และ triangle attention เป็นโครงสร้างที่พยายามสะท้อนสัญชาตญาณของ อสมการสามเหลี่ยม เข้าไปในโมเดล
    • แม้ z_i,j ของ pair tensor จะไม่ใช่ระยะทางทางกายภาพโดยตรง แต่เนื่องจากมันบรรจุความสัมพันธ์ระหว่าง token i และ j จึงอัปเดตให้ความสัมพันธ์ทั้งสามของ i-j, j-k, i-k สอดคล้องกัน
    • อสมการสามเหลี่ยมไม่ได้ถูกบังคับโดยตรงภายในโมเดล แต่ถูกชักนำด้วยวิธีดู triplet (i,j,k) ทั้งหมดแล้วอัปเดต z_i,j
    • z สามารถมองได้คล้าย directed adjacency matrix จึงแยกประมวลผลทิศทางของ outgoing edge และ incoming edge
  • Triangle Updates

    • ใน outgoing update จะอัปเดต z_i,j แต่ละตัวโดยใช้องค์ประกอบอื่นใน row เดียวกันคือ z_i,k และ edge ที่สาม z_j,k
    • ในเชิง implementation จะสร้าง projection สามตัวของ z ได้แก่ a, b, g จากนั้นนำ row i และ row j มาทำ element-wise multiplication แล้วรวมตาม k ก่อนใช้ gate g
    • incoming update เป็นรูปแบบที่สลับ row กับ column โดยอัปเดต z_i,j ผ่านองค์ประกอบอื่นใน column เดียวกันคือ z_k,j และ z_k,i
  • Triangle Attention

    • triangle attention เป็นรูปแบบที่เพิ่มหลักการ triangle เข้าไปใน axial attention ซึ่งใช้ attention แยกกันบน row และ column ของ 2D matrix
    • ในกรณี “starting node” จะเพิ่ม z_j,k เป็น bias ให้กับการเปรียบเทียบ query-key ของ z_i,j และ z_i,k
    • ในกรณี “ending node” จะทำงานตาม column และใช้ z_k,j เป็น bias ให้กับ attention score ของ z_i,j และ z_k,i
  • Single Attention with Pair Bias

    • หลังจาก triangle step และ transition block แล้ว single representation s จะถูกอัปเดตด้วย single attention with pair bias ที่ใช้ updated pair representation z
    • เนื่องจากทำงานในระดับ token จึงใช้ full attention ไม่ใช่ block-wise sparse attention ที่ใช้ในระดับ atom

การทำนายโครงสร้าง: denoising พิกัดอะตอมด้วย diffusion

  • วิธีพื้นฐานของโมเดล diffusion

    • AF3 ทำการทำนายโครงสร้างขั้นสุดท้ายด้วย atom-level diffusion
    • diffusion model เรียนรู้โดยเพิ่ม random noise เข้าไปในข้อมูลจริงทีละขั้น และให้โมเดลทำนายว่า noise แบบใดถูกเพิ่มเข้าไป
    • ในช่วง inference จะเริ่มจาก random noise อย่างสมบูรณ์ แล้วลบ noise ที่โมเดลทำนายในแต่ละ step เพื่อสร้าง datapoint ที่ผ่านการ denoise แล้ว
    • conditional diffusion รับ current noisy generation, การแทนค่า timestep ปัจจุบัน และ condition vector เป็นอินพุต เพื่อสร้างผลลัพธ์ที่ตรงตามเงื่อนไข
    • ใน AF3 เป้าหมายของ denoising คือ matrix x ที่บรรจุพิกัด x,y,z ของอะตอมทั้งหมด
  • การเสริมข้อมูลด้วยการหมุนและเลื่อนแทน IPA ของ AF2

    • AF3 ไม่ใช้ Invariant Point Attention ของ AF2 แต่จะสุ่มหมุนและเลื่อนคอมเพล็กซ์ทั้งหมดที่กำลังทำนายในแต่ละ timestep
    • การเสริมข้อมูลนี้ทำให้โมเดลเรียนรู้ว่าไม่ว่าจะหมุนหรือเลื่อนแบบใดก็ยังเป็นโครงสร้างเดียวกันที่ใช้ได้ และเป็นแนวทางที่ง่ายกว่า IPA ของ AF2
    • การหมุนจะถูกใช้โดยมีค่าเฉลี่ยของพิกัดอะตอมทั้งหมดใน generation ปัจจุบันเป็นศูนย์กลาง และ translation จะถูกสุ่มจาก Gaussian N(0,1) ในแต่ละมิติ
    • ยังมีการเพิ่ม noise เล็กน้อยให้กับพิกัด เพื่อชักนำให้เกิด generation ที่หลากหลายขึ้น
    • ในช่วง inference สามารถให้คะแนน generation หลายชุดด้วย confidence head แล้วส่งคืน generation ที่ได้คะแนนสูงสุดได้
  • สี่ขั้นตอนของ Diffusion Module

    • แต่ละ denoising step ใช้ conditioning representation หลายชุด
      • เอาต์พุต trunk s_trunk, z_trunk
      • representation เริ่มต้น s_inputs, c_inputs ที่สร้างจาก input embedder
    • กระบวนการ diffusion ประกอบด้วยสี่ขั้นตอนที่สลับไปมาระหว่างพื้นที่ token และ atom
        1. เตรียม token-level conditioning tensor
        1. เตรียม atom-level conditioning tensor, ใช้ Atom Transformer แล้วรวมกลับเป็น token-level
        1. ใช้ token-level attention
        1. ทำนาย noise update รายอะตอมด้วย atom-level attention
    • ใน token-level conditioning จะรวม z_trunk กับ relative positional encoding แล้วส่งผ่าน transition block
    • ใน single representation จะรวม s_inputs กับ s_trunk และเพิ่ม Fourier embedding ตาม diffusion timestep
    • ในขั้น atom-level จะอัปเดต c, p เริ่มต้นด้วย token-level representation ปัจจุบัน และสเกลพิกัดปัจจุบัน x ด้วย data variance เพื่อสร้าง dimensionless coordinate r
    • ในขั้น atom-level สุดท้าย linear layer จะ map q ไปยัง R^3 เพื่อสร้าง coordinate update r_update ของอะตอมทั้งหมด
    • update จะถูก rescale เป็น x_update โดยคำนึงถึง data variance และ noise schedule แล้วนำไปใช้กับพิกัดปัจจุบัน x_l

ฟังก์ชัน loss และ confidence head

  • loss ทั้งหมดคือผลรวมถ่วงน้ำหนักของสามพจน์

L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence

  • L_distogram

    • L_distogram ประเมินความแม่นยำของ distogram ที่คาดการณ์ในระดับ token
    • เมื่อสร้างพิกัด token จากพิกัดอะตอม จะใช้พิกัดของ center atom ของแต่ละ token
    • ระยะทางใน distogram ถูกจัดการเป็น categorical value และเปรียบเทียบ distogram ที่คาดการณ์กับ distogram จริงด้วย cross entropy
  • L_diffusion

    • L_diffusion เป็นผลรวมถ่วงน้ำหนักของหลายพจน์ที่ใช้กับตำแหน่ง atom
    • L_MSE คำนวณ mean squared error ระหว่าง position สำหรับทุกอะตอม ไม่ใช่เฉพาะ center atom และอะตอมของ DNA, RNA, ligand จะถูก upweight
    • L_bond เป็นพจน์ MSE เพิ่มเติมเพื่อเพิ่มความแม่นยำของ bond length ของ atom pair ที่อยู่ใน protein-ligand bond
    • ในช่วง training stage แรก α_bond=0 จึงถูกนำมาใช้ในภายหลัง
    • L_smooth_LDDT เป็น loss ที่ทำให้ local distance accuracy นุ่มนวลและหาอนุพันธ์ได้
      • ใช้ threshold สี่ค่า ได้แก่ 4Å, 2Å, 1Å, 0.5Å
      • atom pair ของ nucleotide จะถูกละเว้นหากอยู่ไกลกว่า 30Å
      • atom pair ของ protein หรือ ligand จะถูกละเว้นหากอยู่ไกลกว่า 15Å
  • L_confidence

    • L_confidence ฝึกให้โมเดลประมาณความแม่นยำของการคาดการณ์ของตัวเอง แทนที่จะเพิ่มความแม่นยำของโครงสร้างโดยตรง
    • ประกอบด้วย loss ที่สอดคล้องกับ confidence metric สี่ประเภท
      • pLDDT: local distance accuracy สำหรับอะตอมที่อยู่ใกล้กัน
      • PAE: predicted alignment error ของ token pair
      • PDE: predicted distance error ระหว่าง token pair
      • experimentally resolved prediction: คาดการณ์ว่าแต่ละอะตอมถูก resolved ในโครงสร้างจากการทดลองหรือไม่
    • แม้โครงสร้างที่คาดการณ์จะไม่แม่นยำจนทำให้ PAE สูง แต่หากโมเดลคาดการณ์ว่า PAE สูงเช่นกัน PAE loss ดังกล่าวก็อาจต่ำได้
    • confidence prediction ถูกสร้างขึ้นในขั้นกลางของ diffusion
    • gradient ของ confidence loss จะอัปเดตเฉพาะ confidence prediction head และไม่ส่งผลต่อส่วนอื่นของโมเดล

เทคนิคการเรียนรู้เพิ่มเติมและการเพิ่มประสิทธิภาพ

  • Recycling

    • AF3 ใช้ weight recycling เช่นเดียวกับ AF2
    • แทนที่จะทำให้โมเดลลึกขึ้น จะนำ weight เดิมกลับมาใช้ซ้ำหลายครั้งเพื่อปรับปรุง representation ทีละขั้น
    • diffusion ก็ใช้ข้อมูล timestep ใน inference และนำ weight เดิมกลับมาใช้ซ้ำในทุก timestep จึงมี recycling อยู่ในตัว
  • Cross-distillation

    • AF3 ใช้ทั้ง synthetic training data ที่สร้างขึ้นเอง และ synthetic data ที่สร้างโดย AF2 กับ AF-Multimer
    • หลังเปลี่ยนมาใช้ generation แบบ diffusion-based เกิดปัญหาว่ารูปแบบ “spaghetti” ที่เคยช่วยแยกแยะบริเวณความเชื่อมั่นต่ำและไร้ระเบียบใน AF2 ทางภาพหายไป
    • โดยรวม generation ของ AF2 และ AF-Multimer เข้าใน training data ของ AF3 ทำให้ AF3 เรียนรู้วิธีที่ AF2 แสดง unfolded region ในบริเวณที่ไม่มั่นใจ
    • ใน distillation dataset จะลบกรดนิวคลีอิกและโมเลกุลขนาดเล็กที่ AF2 และ AF-Multimer ไม่สามารถจัดการได้ออก
    • หลังจากโมเดลก่อนหน้าสร้างโครงสร้างที่คาดการณ์แล้ว เมื่อนำไป alignment กับต้นฉบับ จะเพิ่มโมเลกุลที่เคยลบออกกลับเข้ามา
    • หากโมเลกุลที่เพิ่มกลับเข้ามาทำให้เกิด atom clash จะตัดโครงสร้างทั้งหมดออก เพื่อหลีกเลี่ยงการฝึกให้โมเดลยอมรับ clash
  • Cropping และ training stage

    • ตัวโมเดลเองไม่มีข้อจำกัดแบบชัดเจนต่อความยาวลำดับอินพุต แต่หลายการคำนวณเพิ่มขึ้นตาม N_tokens^3 ทำให้ความต้องการ memory และ compute สูงขึ้น
    • เพื่อประสิทธิภาพ จึงทำ random crop กับโปรตีน
    • เนื่องจากต้องโมเดล interaction ระหว่างหลาย chain crop จึงต้องรวม chain หลายเส้นไว้ด้วยกัน
    • ใช้วิธี cropping สามแบบ
      • contiguous cropping: เลือก sequence กรดอะมิโนที่ต่อเนื่องกันจากแต่ละ chain
      • spatial cropping: เลือกกรดอะมิโนโดยอิงจากระยะถึงอะตอมอ้างอิง
      • spatial interface cropping: เลือกโดยอิงจากระยะถึงอะตอมของ binding interface
    • โมเดลที่ฝึกด้วย random crop 384 ก็สามารถใช้กับ sequence ที่ยาวกว่าได้ แต่เพื่อเพิ่มความสามารถในการจัดการ sequence ที่ยาวขึ้น จึงทำ fine-tuning ซ้ำด้วย sequence length ที่ใหญ่ขึ้น
  • Clashing และ batch size

    • loss ของ AF3 ไม่มี clash penalty สำหรับอะตอมที่ overlap กัน
    • diffusion-based structure module ในทางทฤษฎีสามารถคาดการณ์ให้อะตอมสองตัวอยู่ตำแหน่งเดียวกันได้ แต่หลังการฝึก ปัญหานี้มีน้อย
    • ในการ ranking โครงสร้างที่สร้างขึ้น จะใช้ clashing penalty
    • diffusion process ดูซับซ้อน แต่มีต้นทุนการคำนวณต่ำกว่า trunk
    • เพื่อเพิ่มประสิทธิภาพการฝึก จะขยาย batch size หลังผ่าน trunk
    • แต่ละ input structure จะผ่าน embedding และ trunk หนึ่งครั้ง จากนั้น data-augmented structure อิสระ 48 รายการจะถูกฝึกแบบขนาน

การออกแบบ AF3 จากมุมมอง ML

  • โครงสร้างที่คล้ายกับ Retrieval-Augmented Generation

    • การค้นหา MSA และ template ของ AF3 มีลักษณะคล้ายกับ RAG ของโมเดลภาษา
    • ในวงการ AlphaFold วิธีใช้ structure template ถูกเรียกว่า homology modeling มานานก่อนคำว่า RAG
    • AF3 ลดสัดส่วนการประมวลผล MSA ลงเมื่อเทียบกับ AF2 แต่ยังคงรวม MSA และ template ไว้
    • โมเดลทำนายโปรตีนบางตัว เช่น ESMFold ตัด retrieval ออกและใช้ fully parametric inference
  • Pair-Bias Attention

    • Pair-Bias Attention ซึ่งเป็นองค์ประกอบหลักของ AF2 ถูกนำมาใช้กว้างขึ้นใน AF3
    • query, key และ value มาจาก source เดียวกัน แต่ใน attention map จะมีการเพิ่ม bias term ที่มาจาก source อื่น
    • นี่เป็นวิธีแชร์ข้อมูลที่เบากว่า full cross-attention
    • เนื่องจาก pair representation มีความคล้ายกับ attention map โดยธรรมชาติ โครงสร้างนี้จึงอาจเหมาะกับการทำโมเดลโปรตีน
  • การลดบทบาทของ self-supervised training

    • โมเดลตระกูล ESM แสดงจุดแข็งในแนวทางที่ใช้ self-supervised pre-training เพื่อทดแทน MSA embedding
    • ใน AF2 มี task เพิ่มเติมสำหรับทำนาย masked token ของ MSA แต่ใน AF3 ถูกนำออกไป
    • AF3 ลด compute สำหรับการประมวลผล MSA และไม่ใช้ self-supervised language modeling pre-training สำหรับ MSA
    • เหตุผลที่เป็นไปได้คือ massive pre-training อาจไม่มีประสิทธิภาพในแง่การใช้ compute, หรือ MSA module ขนาดเล็กอาจดีกว่า pre-trained embedding, หรือโครงสร้าง hybrid atom-token ที่ผสมกรดอะมิโน·DNA/RNA·ligand อาจไม่เข้ากับการผสานกับ pre-trained embedding
  • การผสมระหว่าง Classification และ Regression

    • AF3 ใช้ MSE ร่วมกับ binned classification loss เหมือน AF2
    • จุดเด่นของ classification loss คือแม้จะทาย distogram bin ผิดไปเพียงหนึ่ง bin ก็จะไม่ได้ credit เช่นเดียวกับกรณีที่ทายผิดไปไกล
    • เหตุผลของการเลือกออกแบบเช่นนี้ยังไม่ชัดเจน แต่เป็นไปได้ว่า gradient มีเสถียรภาพมากกว่า MSE loss หลายตัว
  • องค์ประกอบที่คล้าย recurrent architecture

    • AF3 มีองค์ประกอบหลายอย่างที่ทำให้นึกถึง recurrent network มากกว่า transformer ทั่วไป
    • gating ควบคุมการไหลของข้อมูลใน residual stream และคล้ายกับ gate ของ LSTM หรือ GRU
    • recycling และ diffusion ใช้ weight เดิมซ้ำ ๆ เพื่อค่อย ๆ ปรับปรุงการทำนาย
    • คล้ายกับ adaptive compute time การอัปเดตซ้ำ ๆ เกี่ยวข้องกับโครงสร้างที่สามารถใช้การประมวลผลมากขึ้นกับอินพุตที่ยากกว่า
    • ใน ablation ของ AF2 พบว่า recycling มีความสำคัญ แต่มีการอภิปรายเกี่ยวกับความสำคัญของ gating ไม่มากนัก

ยังไม่มีความคิดเห็น

ยังไม่มีความคิดเห็น