文章目录
1 要点
代码:https://github.com/Josie-xufan/LaDM3IL
方法:
- 利用实例级多示例学习 (MIL) 框架来减少大规模MIL场景下的计算开销,并处理免疫库分类问题下的大存储容量挑战;
- 模型包括:
- 特征提取器: 使用门控注意力机制和张量融合的多模态融合模块,整合氨基酸序列和VDJ基因片段的信息;
- 标签消歧模块,用于降低错误监督的影响;
- 聚合模块:整合各个受体的预测及其相应频率来生成免疫库水平的预测;
- 实验:
- 数据集:巨细胞病毒 (Cytomegalovirus, CMV) 和癌症数据集;
- 地址:https://clients.adaptivebiotech.com/pub/Emerson-2017-NatGen
背景:
注:该方法与免疫库分类这一背景很相关,可以重点关注其如何利用多模态MIL处理问题这一机制上
2 方法
2.1 问题定义
一个适应性免疫受体库 (adaptive immune receptor repertoires, AIRR) 包含大量的自适应免疫受体 (AIR),也就是通常所说的包与实例。令 { I R 1 , I R 2 , … , I R N } \{ IR_1,IR_2,\dots,IR_N \} {
IR1,IR2,…,IRN}表示 N N N个包 (AIRR),每个包包含 M M M个实例 I R i 1 , I R i 2 , … , I R i M IR_i^1,IR_i^2,\dots,IR_i^M IRi1,IRi2,…,IRiM。注意不同包中的实例数量差异悬殊。每个包的标签记为 Y i ∈ { 0 , … , C } Y_i\in\{0,\dots,C\} Yi∈{
0,…,C},其中 C C C为类别数。
与传统MIL不同之处在于,每个AIR对应一个频率值,对于包而言,记为 { f r e i 1 , f r e i 2 , … , f r e i M } \{fre_i^1,fre_i^2,\dots,fre_i^M\} {
frei1,frei2,…,freiM},表示对某种抗原的免疫反应的强度。
本文的目的是建立一个映射函数 Y i = F ( I R i ) Y_i=F(IR_i) Yi=F(IRi),即获得每个包 I R i IR_i IRi的免疫状态。与传统的实例级方法类似,实例的标签被初始化为包的标签,不同之处在于,这些伪标签将通过标签消歧模块逐步更新。
2.2 模型架构
LaDM 3 ^3 3IL的整体架构如图2,接下来将依次介绍特征提取、标签消歧、聚合模块,并在最后说明损失函数。
图2:LaDM 3 ^3 3IL总体架构,其包含特征提取器、标签消歧模块,以及一个汇聚模块。a) 特征提取器:一个预训练SC-AIR-BERT作为序列编码器;一个可训练嵌入层作为基因编码器,以转换VDJ基因片段。门控注意力机制和一个张量融合模块用于整合学习的基因与氨基酸 (AA) 序列特征;b) 标签去混淆 (以二分类为例):首先获取原型,其用于表示每个类的特征嵌入。在训练期间,top-K个实例被选择用于更新原型中相应类的嵌入,并根据实例特征嵌入与原型特征嵌入之间的相似性来调整每个实例的标签;以及c) 聚合模块:通过多层感知机获得每个实例的预测后,聚合模块通过将这些预测与其相应的频率相乘来整合这些预测,然后对结果进行归一化以生成包级别的预测。
2.2.1 特征提取
为了获取每个实例AIR的全面表示,需要整合来自AA序列和VDJ基因片段的信息:
- 基因编码器:利用一个可训练嵌入层将VDJ基因片段转换为数值表示 h g h_g hg,其包含两个独立的片段,分别编码了来自V基因片段和J基因片段的信息,维度则分别为16和8。注意D基因被排除,因为大部分实例种不包含;
- 预训练序列编码SC-AIR-BERT:生成实例所对应的AA序列的表示 h s h_s hs,其维度为512;
- 门控注意力机制:输出两个模态的信息 o g o_g og和 o s o_s os:
o g = σ ( z g ) ⋅ h g ′ (1) \tag{1} o_g=\sigma(z_g)\cdot h_g' og=σ(zg)⋅hg′(1)其中 h g ′ = ReLU ( W g h g + b g ) h_g'=\text{ReLU}(W_gh_g+b_g) hg′=ReLU(Wghg+bg)是一个线性变换、 z g = h g T W g s h s + b g s z_g=h_g^TW_{gs}h_s+b_{gs} zg=hgTWgshs+bgs是一个双线性变换,以及 σ \sigma σ是sigmoid函数。同理可以计算得到:
o s = σ ( z s ) ⋅ h s ′ (2) \tag{2} o_s=\sigma(z_s)\cdot h_s' os=σ(zs)⋅hs′(2)其中 h s ′ = ReLU ( W s h s + b s ) h_s'=\text{ReLU}(W_sh_s+b_s) hs′=ReLU(Wshs+bs), z s = h s T W s g h g + b s g z_s=h_s^TW_{sg}h_g+b_{sg} zs=hsTWsghg+bsg; - 张量融合模块:整合 o g o_g og和 o s o_s os为最终的表示 h h h:
h = ReLU ( W fusion ⋅ ( o g ⊗ o s ) + b fusion ) , (3) \tag{3} h=\text{ReLU}(W_\text{fusion}\cdot(o_g\otimes o_s)+b_\text{fusion}), h=ReLU(Wfusion⋅(og⊗os)+bfusion),(3)其中 ⊗ \otimes ⊗表示Kronecker积。
2.2.2 标签消歧
标签消歧模块用于处理实例的不准确监督问题:
- 每个实例通过特征提取器后,可以获得一个实例级预测:
p i j = softmax ( F C receptor ( h i j ) ) , (4) \tag{4} p_i^j=\text{softmax}(FC_\text{receptor}(h_i^j)), pij=softmax(FCreceptor(hij)),(4)其中 F C receptor FC_\text{receptor} FCreceptor表示分类器; - 在每一个训练轮次 e e e,将从每个类别种选取实例级预测大于给定阈值 θ \theta θ的实例:
k e c − r e c e p t o r = { h i k , e , c ∣ p i k , e , c > θ , c ∈ { 0 , … , C } , k ∈ { 0 , K } , i ∈ { 0 , N } } . (5) \tag{5} \begin{aligned} kec-receptor=&\{ h_i^{k,e,c} | p_i^{k,e,c}>\theta,c\in\{0,\dots,C\},\\ &k\in\{0,K\},i\in\{0,N\}\}. \end{aligned} kec−receptor={
hik,e,c∣pik,e,c>θ,c∈{
0,…,C},k∈{
0,K},i∈{
0,N}}.(5) - 基于动量法生成原型,即原型中每个类别 c c c将在下一轮次更新:
E prototype = Normalize ( λ ⋅ E prototype c , e + ( 1 − λ ) ⋅ h i k , e , c ) , h i k , e , c ∈ k e c − r e c e p t o r , c ∈ { 0 , C } , (6) \tag{6} \begin{aligned} E_\text{prototype}=&\text{Normalize}(\lambda\cdot E_\text{prototype}^{c,e}+(1-\lambda)\cdot h_i^{k,e,c}),\\ &h_i^{k,e,c}\in kec-receptor,c\in\{0,C\}, \end{aligned} Eprototype=Normalize(λ⋅Eprototypec,e+(1−λ)⋅hik,e,c),hik,e,c∈kec−receptor,c∈{
0,C},(6)其中 λ ∈ [ 0 , 1 ] \lambda\in[0,1] λ∈[0,1]是动量系数; - 每个实例的标签 Y i j Y_i^j Yij被初始化为包的标签 Y i Y_i Yi,并通过实例与原型相似性的Onehot编码来更新:
Y i j , e + 1 = γ ⋅ Y i j , e + ( 1 − γ ) ⋅ Onehot ( s i m i j , e ) , γ = e ( E p o c h e n d − E p o c h s t a r t ) E p o c h + E p o c h s t a r t , (7) \tag{7} \begin{aligned} &Y_i^{j,e+1}=\gamma\cdot Y_i^{j,e}+(1-\gamma)\cdot \text{Onehot}(sim_i^{j,e}),\\ &\gamma=\frac{e(Epoch_{end}-Epoch_{start})}{Epoch}+Epoch_{start}, \end{aligned} Yij,e+1=γ⋅Yij,e+(1−γ)⋅Onehot(simij,e),γ=Epoche(Epochend−Epochstart)+Epochstart,(7)其中
s i m i j , e = arg max c ( E prototype e ⋅ ( h i j , e ) T ) . (8) \tag{8} sim_i^{j,e}=\argmax_c(E_\text{prototype}^e\cdot(h_i^{j,e})^T). simij,e=cargmax(Eprototypee⋅(hij,e)T).(8)
2.2.3 聚合模块
为了生成每个免疫库,也就是包的预测 p i p_i pi,我们结合实例预测 p i j p_i^j pij及其频率 f r e i j fre_i^j freij:
p i = ∑ j = 1 M ( p i j ⋅ f r e i j ) . (9) \tag{9} p_i=\sum_{j=1}^M(p_i^j\cdot fre_i^j). pi=j=1∑M(pij⋅freij).(9)最后通过最大最小标准化获取包预测。
2.2.4 损失函数
训练阶段分为热身阶段和标签消歧阶段:
- 热身阶段:原型更新和标签消歧被暂停,只需计算实例预测和初始标签之间的交叉熵损失:
L receptor = − 1 N M ∑ i = 1 N ∑ j = 1 M ∑ c = 0 C Y i j , c ⋅ log ( p i j , c ) . (10) \tag{10} L_\text{receptor}=-\frac{1}{NM}\sum_{i=1}^N\sum_{j=1}^M\sum_{c=0}^CY_i^{j,c}\cdot\log(p_i^{j,c}). Lreceptor=−NM1i=1∑Nj=1∑Mc=0∑CYij,c⋅log(pij,c).(10) - 标签消歧阶段:使用标签消歧损失:
L disambiguation = − 1 N M ∑ i = 1 N ∑ j = 1 M ∑ c = 0 C Y i j , c , e ⋅ log ( p i j , c , e ) . (11) \tag{11} L_\text{disambiguation}=-\frac{1}{NM}\sum_{i=1}^N\sum_{j=1}^M\sum_{c=0}^CY_i^{j,c,e}\cdot\log(p_i^{j,c,e}). Ldisambiguation=−NM1i=1∑Nj=1∑Mc=0∑CYij,c,e⋅log(pij,c,e).(11)
3 实验
3.1 数据集
- CMV:包含785个受体库,平均每个243960个受体。通过排除缺失信息,最终选择684个受体库,其中正312,负372。该数据集同时完成受体库分类和相关受体鉴定任务;
- Cancer:训练集30000个肿瘤受体和40000个对照受体,测试集10000个肿瘤受体和19851个对照受体。该数据集重点关注肿瘤受体的鉴定;
3.2 实现细节
- 对于CMV数据集,依据Widrich的方法来划分数据集;
- 对于Cancer数据集,依据Beshnova的方法来划分数据集;
- 对于自己的方法,CMV的训练集比例从10%逐渐增加到60%,Cancer则从20%逐渐增加到60%;
- 优化器选用Adam,基因编码器的学习率设置为 1 e − 3 1e^{-3} 1e−3、序列编码器的学习率设置为 1 e − 4 1e^{-4} 1e−4,以及主模型的学习率设置为 1 e − 3 1e^{-3} 1e−3;
- 训练阶段基于网格搜索和验证集来确定超参数,最终如下:
- 评价指标选用AUC和ACC;
3.3 性能对比
- 肿瘤相关受体鉴定:
- 序列长度对cancer数据集分类的影响:
文章评论