当前位置:网站首页>論文閱讀23 - Mixture Density Networks(MDN)混合密度網絡理論分析

論文閱讀23 - Mixture Density Networks(MDN)混合密度網絡理論分析

2021-09-15 05:07:51 程序員大本營

Mixture Density Networks

最近看論文經常會看到在模型中引入不確定性(Uncertainty)。尤其是MDN(Mixture Density Networks)在World Model這篇文章多次提到。之前只是了解了個大概。翻了翻原版論文和一些相關資料進行了整理。

1. 直觀理解:

混合密度網絡通常作為神經網絡的最後處理部分。將某種分布(通常是高斯分布)按照一定的權重進行疊加,從而擬合最終的分布。

如果選擇高斯分布的MDN,那麼它和GMM(高斯混合模型 Gaussian Mixture Model)有著相同的效果。但是他們有著很明顯的區別:

  • MDN的均值方差每個模型的權重是通過神經網絡產生的,利用最大似然估計作為Loss函數進行反向傳播從而確定網絡的權重(也就是確定一個較好的高斯分布參數)

  • GMM的均值方差每個模型的權重是通過估計出來的,通常使用EM算法來通過不斷迭代確定。

    GMM的詳解以及為什麼要用EM而不是極大似然估計來優化參數,請見這個博客

總之,MDN的思想與GMM一樣,將模型混合的思想與神經網絡相結合。在回歸問題上通常都有很好的錶現。例如,論文中提到的一個翻轉的x,t翻轉的例子:

  1. 如果x是訓練數據,t是我們的label:

    [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-LWJi1O1O-1605340386538)(Untitled.assets/image-20201114103332416.png)]

    普通的神經網絡,使用sum-of-squares error作為loss可以得到一個較好的擬合效果。

  2. 同樣的數據,將x和t的數據翻轉(原來x的數據作為標簽,原來t的數據作為訓練集, tmp = x, x = t, t = tmp):

    [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-8d9pbQRS-1605340386540)(Untitled.assets/image-20201114103606112.png)]

    使用sum-of-squares error作為loss似乎並沒有捕捉到我們的走勢。

  3. MDN效果如何呢

    先上效果圖(來自原版論文)。下圖繪制的是可能性最大的點(分布的均值)。可見基本上可以捕捉到這個趨勢。

    [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-oPgn4RpM-1605340386543)(Untitled.assets/image-20201114140657278.png)]

    在輸出的分布內進行采樣獲取預測,圖片來自

    png

2. 算法細節

2.1. 結構

參數化錶示:

image-20201114142501747

CCC :要混合的分布個數。是用戶需要制定的參數。例如我們需要混合5個高斯分布作為最終結果,那麼C = 5;

α\alphaα :每個分布的權重參數。網絡輸出的參數

DDD: 某一種被混合的分布, 如果是高斯分布,那麼KaTeX parse error: Undefined control sequence: \cal at position 1: \̲c̲a̲l̲ ̲D 就應該用 NNN錶示。

λ\lambdaλ:分布的一些參數,高斯分布則包括μ\muμσ\sigmaσ網絡輸出的參數

需要注意的是:混合的分布可以是任意的。

以高斯分布為例,網絡結構如下:

image-20201114144011352

  • α\alphaα (alpha)的和應該等於1,即∑cCαc=1\sum^{C}_{c} \alpha_c = 1cCαc=1。 所以我們可以在使用softmax**函數來解决。
  • σ\sigmaσ(sigma)>0。 可以保證這個的方法有很多,在Mixture Density Networks中使用指數**:σ=exp(z)\sigma = exp(z)σ=exp(z)。指數可能會引起數值不穩定,出現無窮大。可以使用變種的ELU [3],即σ=ELU(σ)+1\sigma = ELU(\sigma)+1σ=ELU(σ)+1
  • μ\muμ 的範圍是否要確定區間,可以根據實際問題。例如價格預測,不可能出現負的,就可以選擇相關的**函數來固定區間大於0.

2.2 Loss設計:

損失函數使用的極大似然估計。極大似然估計認為我們采樣出來的都是那些出現概率最大的數。所以我們希望我們需要最大化的似然函數為(這裏使用了平均值,即每個分布的似然函數大小):

極大似然估計公式:L(θ)=L(x1,x2...xn;θ)=∏i=1np(xi;θ)L(\theta) = L(x_1,x_2...x_n ; \theta) = \prod_{i = 1 } ^n p(x_i; \theta)L(θ)=L(x1,x2...xn;θ)=i=1np(xi;θ)。用多個分布混合,則p(xi;θ)=∑kKakpk(xi;θ)p(x_i;\theta) = \sum_k ^K a_k p_k(x_i ; \theta)p(xi;θ)=kKakpk(xi;θ)。 下式中 xix_ixiyn∣xny_n|x_nynxn

L(θ)=1N∏nN∑kKakpk(yn∣xn)ln(L(θ))=1N∑nNlog⁡{∑kKαkpk(yn∣xn)}L(\theta) = \frac{1}{N} \prod_n ^N \sum_k ^K a_k p_k(y_n|x_n) \\ ln(L(\theta)) =\frac{1}{N} \sum_n ^N \log \{ \sum_k ^K \alpha_k p_k(y_n|x_n)\}L(θ)=N1nNkKakpk(ynxn)ln(L(θ))=N1nNlog{kKαkpk(ynxn)}

N 樣本總數

K 分布的數量

aka_kak 是當前分布的權重

pkp_kpk 是當前分布的概率

$ \sum_k ^K a_k p_k(y_n|x_n)$ 就是xnx_nxn樣本出現的概率。對應似然函數中的p(xi;θ)p(x_i; \theta)p(xi;θ)。 是k個分布按照權重α\alphaα累加的結果。

優化器一般都是梯度下降,用來最小化目標函數,所以我們要在上式加一個負號,作為優化函數,這樣就是梯度上昇最大化上式。
Loss(θ)=−ln(L(θ))Loss(\theta) = -ln(L(\theta))Loss(θ)=ln(L(θ))
如果是N個高斯分布,那麼我們的損失函數:
Loss(θ)=−1N∑1Nlog⁡{∑kαkN(yn∣μk,σk2)}Loss(\theta) = -\frac{1}{N} \sum_1 ^N \log \{\sum_k \alpha_k N(y_n|\mu_k,\sigma^2_k)\}Loss(θ)=N11Nlog{kαkN(ynμk,σk2)}

N(y∣μ,σ2)=12πσ2e−(x−μ)22σ2N(y|\mu,\sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}} e^{\frac{-(x-\mu)^2}{2\sigma^2}}N(yμ,σ2)=2πσ2 1e2σ2(xμ)2

3. 總結

MDN實現簡單,而且可以直接模塊化的連接到神經網絡的後端。他的結果可以得到一個概率範圍,相對有deterministic類只輸出一個結果,往往有更好的健壯性。[3][4]中有相關代碼實現。

4. reference:

[1]. Christopher M. Bishop, Mixture Density Networks (1994)

[2]. Blog-詳解EM算法與混合高斯模型(Gaussian mixture model, GMM)

[3]. Blog-A Hitchhiker’s Guide to Mixture Density Networks

[4]. Blog-Mixture Density Networks

版权声明
本文为[程序員大本營]所创,转载请带上原文链接,感谢
https://chowdera.com/2021/09/20210915050047077t.html

随机推荐