混合ガウス分布の変分ベイズ法による推定

PRML

図示

変分混合ガウス分布.gif

初期クラス数を6に設定しても、最終的に3つのクラスに収束していることがわかる

アルゴリズム

  1. $r_{nk}$を初期化
  2. 三つの統計量を計算
    • $N_k = \sum_{n=1}^N r_{nk}$
    • ${\bar x_k} = \frac{1}{N_k} \sum_{n=1}^N r_{nk}x_n$
    • $S_k = \frac{1}{N_k} \sum_{n=1}^N r_{nk}(x_n - {\bar x_k})(x_n - {\bar x_k})^T$
  3. Mstep: $q(\pi) = Dir(\pi|\alpha), q(\mu_k, \Lambda_k) = N(\mu_k|m_k, (\beta_k \Lambda_k)^{-1})W(\Lambda_k|W_k, \nu_k)$を求める。
    • $\alpha_k = \alpha_0 + N_k$
    • $\beta_k = \beta_0 + N_k$
    • $m_k = \frac{1}{\beta_k}(\beta_0 m_0 + N_k {\bar x_k})$
    • $W_k^{-1} = W_0^{-1} + N_k S_k + \frac{\beta_0 N_k}{\beta_0 + N_k}({\bar x_k} - m_0)({\bar x_k} - m_0)^T$
    • $\nu_k = \nu_0 + N_k$
  4. Estep: $q(Z) = \prod_{n=1}^N \prod_{k=1}^K r_{nk}^{z_{nk}}$を求める
    • $r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^K \rho_{nj}}$
    • $\ln\rho_{nk} = E[\ln\pi_k] + \frac{1}{2} E[\ln |\Lambda_k|] - \frac{D}{2}\ln(2\pi) - \frac{1}{2}E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)]$
    • $E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)] = D\beta_k^{-1} + \nu_k(x_n - m_k)^TW_k(x_n - m_k)$
    • $E[\ln|\Lambda_k|] = \sum_{i=1}^D \psi(\frac{\nu_k + 1 - i}{2}) + D \ln 2 + \ln|W_k|$
    • $E[\ln \pi_k] = \psi(\alpha_k) - \psi({\hat \alpha})$
    • ${\hat \alpha} = \sum_k \alpha_k$
  5. 収束するまで2~4をまわす

実装

%matplotlib inline
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from numpy import linalg as LA
from scipy import stats
from scipy.special import digamma
plt.style.use("ggplot")

# Dimensionality of each observation
D = 2
# Total number of observations
N = 2000

# Ground-truth mixture: three isotropic Gaussians holding 30% / 50% / 20%
# of the N points respectively.
mu1, sigma1, N1 = [0, 1], 0.2 * np.eye(D), int(N*0.3)
mu2, sigma2, N2 = [-1, -1], 0.1 * np.eye(D), int(N*0.5)
mu3, sigma3, N3 = [1, -1], 0.1 * np.eye(D), int(N*0.2)

# Sample the components in a fixed order and stack them row-wise, so the
# rows of `data` are grouped by the component that generated them.
true_components = ((mu1, sigma1, N1), (mu2, sigma2, N2), (mu3, sigma3, N3))
data = np.vstack([
    np.random.multivariate_normal(mean, cov, size)
    for mean, cov, size in true_components
])

# Scatter plot of the raw sample on a fixed viewport.
plt.figure(figsize=(5, 5))
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
plt.scatter(data[:, 0], data[:, 1], s=10)
plt.show()
w8dO9h9H66IZwAAAABJRU5ErkJggg==.png
# Number of mixture components to start from
K = 6

# Starting means (one row per component) and starting covariances
mu = np.array([
    [0., -0.5],
    [0., 0.5],
    [1., 0.5],
    [-1., -0.5],
    [-1, -1.5],
    [1, 1.5],
])
S = np.tile(0.1 * np.eye(2), (K, 1, 1))

# Prior hyperparameters (alpha_0, beta_0, m_0, nu_0, W_0 in the update
# equations)
alpha_0 = 1e-3
beta_0 = 1e-3
m_0 = np.zeros((K, D))
nu_0 = 1
W_0 = np.eye(D)

# Work arrays that the update loop fills in component by component
E_mu_lam = np.zeros((N, K))
W_k = np.zeros((K, D, D))

def multi_gauss(x, y, mu, sigma):
    """Evaluate a bivariate normal density at the point(s) ``(x, y)``.

    ``x`` and ``y`` may be scalars (the original usage) or arrays of equal
    shape, e.g. the two outputs of ``np.meshgrid``; in the array case the
    density is evaluated element-wise and the result has the same shape as
    ``x``, so callers do not need ``np.vectorize``.

    Args:
        x: First coordinate(s).
        y: Second coordinate(s), same shape as ``x``.
        mu: Mean of the distribution (length 2).
        sigma: Covariance matrix of the distribution (2 x 2).

    Returns:
        The probability density at each ``(x, y)`` pair.
    """
    # Put the coordinate pair on the last axis, which is where
    # scipy expects the components of each point.
    points = np.stack([x, y], axis=-1)
    return stats.multivariate_normal(mu, sigma).pdf(points)

# 1: initialise the responsibilities r_nk from the starting mixture:
#    r_nk is proportional to pi_k * N(x_n | mu_k, S_k), normalised over k.
pi = np.ones(K) / K
g = np.zeros((N, K))
for k in range(K):
    # One vectorised pdf call per component instead of building a frozen
    # distribution for every single data point.
    g[:, k] = pi[k] * stats.multivariate_normal(mu[k], S[k]).pdf(data)
# Normalise each row once so the responsibilities of a point sum to 1.
r = g / g.sum(1, keepdims=True)

# Plot the initial state: one contour set per starting component, and the
# data coloured by the component with the largest responsibility.
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
cmap_colors = [cm.spring, cm.summer, cm.autumn, cm.winter, cm.Reds_r, cm.Dark2]
colors = ["pink", "green", "orange", "blue", "red", "black"]
plt.figure(figsize=(5, 5))
for k in range(K):
    # Evaluate the density on the whole grid in one call (the coordinate
    # pair of each grid point sits on the last axis).
    Z = stats.multivariate_normal(mu[k], S[k]).pdf(np.dstack([X, Y]))
    plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)

# Build a real list of colour names: on Python 3 `map` returns a lazy
# iterator, which matplotlib cannot use as a colour sequence.
point_colors = [colors[j] for j in r.argmax(1)]
plt.scatter(data[:, 0], data[:, 1], c=point_colors, alpha=0.3, s=10)
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
init_title = "iter: 0"
plt.title(init_title)
# Uncomment to save this frame for the GIF:
# plt.savefig("data/" + init_title + ".png")
plt.show()
n9OnT9Pc3KzKdIpEGtT0IVbdmfJhbduhUGhT+oCIjkz3A8h8XxDVoLYXa+NDT09P2nyQK1AkEokEDY4MJRKJZDOQwVAikUiQwVAikUgAGQwlEokEkMFQIpFIABkMJRKJBJDBUCKRSAD4HwgVxm0OO01kAAAAAElFTkSuQmCC.png
# Variational Bayes for the Gaussian mixture: 20 sweeps over steps 2-4 of the
# algorithm, drawing one frame per sweep.

# Loop-invariant quantities, computed once.
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
grid = np.dstack([X, Y])          # coordinate pair of each grid point on the last axis
W_0_inv = LA.inv(W_0)
wishart_idx = np.arange(1, D + 1)  # i = 1..D in the digamma sum below

for i in range(20):
    # 2: responsibility-weighted statistics N_k, x_bar_k (stored in mu), S_k
    N_k = r.sum(0)
    mu = r.T.dot(data) / N_k[:, None]
    for k in range(K):
        centered = data - mu[k]
        S[k] = (r[:, k, None] * centered).T.dot(centered) / N_k[k]

    # 3: M step -- update the parameters of q(pi) and q(mu_k, Lambda_k)
    alpha = alpha_0 + N_k
    beta = beta_0 + N_k
    m_k = (beta_0 * m_0 + N_k[:, None] * mu) / beta[:, None]
    for k in range(K):
        shift = mu[k] - m_0[k]
        W_k_inv = (W_0_inv + N_k[k] * S[k]
                   + beta_0 * N_k[k] * np.outer(shift, shift) / (beta_0 + N_k[k]))
        W_k[k] = LA.inv(W_k_inv)
    nu_k = nu_0 + N_k

    # 4: E step -- recompute the responsibilities r_nk
    # E[ln|Lambda_k|] = sum_{i=1..D} psi((nu_k + 1 - i) / 2) + D ln 2 + ln|W_k|
    # |W_k| is the determinant (not a matrix norm); slogdet returns its log
    # directly for the whole stack of K matrices.
    E_ln_lam = (digamma((nu_k[:, None] + 1 - wishart_idx) / 2.).sum(1)
                + D * np.log(2)
                + LA.slogdet(W_k)[1])
    E_ln_pi = digamma(alpha) - digamma(alpha.sum())
    for k in range(K):
        diff = data - m_k[k]
        # Row-wise quadratic form (x_n - m_k)^T W_k (x_n - m_k) without
        # materialising the N x N product whose diagonal it is.
        quad = (diff.dot(W_k[k]) * diff).sum(1)
        E_mu_lam[:, k] = D / beta[k] + nu_k[k] * quad
    ln_ro = E_ln_pi + E_ln_lam / 2. - D * np.log(2 * np.pi) / 2. - E_mu_lam / 2.
    # Subtract the row maximum before exponentiating: the normalised result
    # is unchanged, but a row can no longer underflow to 0 / 0.
    ro = np.exp(ln_ro - ln_ro.max(1, keepdims=True))
    r = ro / ro.sum(1, keepdims=True)
    # Floor the responsibilities so that N_k never becomes exactly zero
    # (N_k is used as a divisor in step 2).
    r = np.maximum(r, 1e-10)

    # Draw the frame for the GIF
    plt.figure(figsize=(5, 5))
    pi = np.exp(E_ln_pi)
    for k in range(K):
        # Only draw components that still carry noticeable weight.
        if pi[k] > 0.01:
            Z = stats.multivariate_normal(mu[k], S[k]).pdf(grid)
            plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)
    # A real list is required here: on Python 3 `map` returns a lazy
    # iterator, which matplotlib cannot use as a colour sequence.
    point_colors = [colors[j] for j in r.argmax(1)]
    plt.scatter(data[:, 0], data[:, 1], c=point_colors, s=10, alpha=0.3)
    plt.xlim(-2.1, 2.1)
    plt.ylim(-2.1, 2.1)
    title = "iter: {}".format(i+1)
    plt.title(title)
    # Uncomment to save this frame for the GIF:
    # plt.savefig("data/" + title + ".png")
    plt.show()

コメント

タイトルとURLをコピーしました