谁做过 EM算法 java实现

发布网友 发布时间:2022-04-22 09:11

我来回答

1个回答

热心网友 时间:2023-09-21 13:07

参考:

package nlp;
/**
 * @author Orisun
 * date 2011-10-22
 */
import java.util.ArrayList;

public class BaumWelch {

    int M; // 隐藏状态的种数
    int N; // 输出活动的种数
    double[] PI; // 初始状态概率矩阵
    double[][] A; // 状态转移矩阵
    double[][] B; // 混淆矩阵

    ArrayList<Integer> observation = new ArrayList<Integer>(); // 观察到的集合
    ArrayList<Integer> state = new ArrayList<Integer>(); // 中间状态集合
    int[] out_seq = { 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1,
            1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1 }; // 测试用的观察序列
    int[] hidden_seq = { 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1,
            1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1 }; // 测试用的隐藏状态序列
    int T = 32; // 序列长度为32

    double[][] alpha = new double[T][]; // 向前变量
    double PO;
    double[][] beta = new double[T][]; // 向后变量
    double[][] gamma = new double[T][];
    double[][][] xi = new double[T - 1][][];

    // 初始化参数。Baum-Welch得到的是局部最优解,所以初始参数直接影响解的好坏
    public void initParameters() {
        M = 2;
        N = 2;
        PI = new double[M];
        PI[0] = 0.5;
        PI[1] = 0.5;
        A = new double[M][];
        B = new double[M][];
        for (int i = 0; i < M; i++) {
            A[i] = new double[M];
            B[i] = new double[N];
        }
        A[0][0] = 0.8125;
        A[0][1] = 0.1875;
        A[1][0] = 0.2;
        A[1][1] = 0.8;
        B[0][0] = 0.875;
        B[0][1] = 0.125;
        B[1][0] = 0.25;
        B[1][1] = 0.75;

        observation.add(1);
        observation.add(2);
        state.add(1);
        state.add(2);

        for (int t = 0; t < T; t++) {
            alpha[t] = new double[M];
            beta[t] = new double[M];
            gamma[t] = new double[M];
        }
        for (int t = 0; t < T - 1; t++) {
            xi[t] = new double[M][];
            for (int i = 0; i < M; i++)
                xi[t][i] = new double[M];
        }
    }

    // 更新向前变量
    public void updateAlpha() {
        for (int i = 0; i < M; i++) {
            alpha[0][i] = PI[i] * B[i][observation.indexOf(out_seq[0])];
        }
        for (int t = 1; t < T; t++) {
            for (int i = 0; i < M; i++) {
                alpha[t][i] = 0;
                for (int j = 0; j < M; j++) {
                    alpha[t][i] += alpha[t - 1][j] * A[j][i];
                }
                alpha[t][i] *= B[i][observation.indexOf(out_seq[t])];
            }
        }
    }

    // 更新观察序列出现的概率,它在一些公式中当分母
    public void updatePO() {
        for (int i = 0; i < M; i++)
            PO += alpha[T - 1][i];
    }

    // 更新向后变量
    public void updateBeta() {
        for (int i = 0; i < M; i++) {
            beta[T - 1][i] = 1;
        }
        for (int t = T - 2; t >= 0; t--) {
            for (int i = 0; i < M; i++) {
                for (int j = 0; j < M; j++) {
                    beta[t][i] += A[i][j]
                            * B[j][observation.indexOf(out_seq[t + 1])]
                            * beta[t + 1][j];
                }
            }
        }
    }

    // 更新xi
    public void updateXi() {
        for (int t = 0; t < T - 1; t++) {
            double frac = 0.0;
            for (int i = 0; i < M; i++) {
                for (int j = 0; j < M; j++) {
                    frac += alpha[t][i] * A[i][j]
                            * B[j][observation.indexOf(out_seq[t + 1])]
                            * beta[t + 1][j];
                }
            }
            for (int i = 0; i < M; i++) {
                for (int j = 0; j < M; j++) {
                    xi[t][i][j] = alpha[t][i] * A[i][j]
                            * B[j][observation.indexOf(out_seq[t + 1])]
                            * beta[t + 1][j] / frac;
                }
            }
        }
    }

    // 更新gamma
    public void updateGamma() {
        for (int t = 0; t < T - 1; t++) {
            double frac = 0.0;
            for (int i = 0; i < M; i++) {
                frac += alpha[t][i] * beta[t][i];
            }
            // double frac = PO;
            for (int i = 0; i < M; i++) {
                gamma[t][i] = alpha[t][i] * beta[t][i] / frac;
            }
            // for(int i=0;i<M;i++){
            // gamma[t][i]=0;
            // for(int j=0;j<M;j++)
            // gamma[t][i]+=xi[t][i][j];
            // }
        }
    }

    // 更新状态概率矩阵
    public void updatePI() {
        for (int i = 0; i < M; i++)
            PI[i] = gamma[0][i];
    }

    // 更新状态转移矩阵
    public void updateA() {
        for (int i = 0; i < M; i++) {
            double frac = 0.0;
            for (int t = 0; t < T - 1; t++) {
                frac += gamma[t][i];
            }
            for (int j = 0; j < M; j++) {
                double dem = 0.0;
                // for (int t = 0; t < T - 1; t++) {
                // dem += xi[t][i][j];
                // for (int k = 0; k < M; k++)
                // frac += xi[t][i][k];
                // }
                for (int t = 0; t < T - 1; t++) {
                    dem += xi[t][i][j];
                }
                A[i][j] = dem / frac;
            }
        }
    }

    // 更新混淆矩阵
    public void updateB() {
        for (int i = 0; i < M; i++) {
            double frac = 0.0;
            for (int t = 0; t < T; t++)
                frac += gamma[t][i];
            for (int j = 0; j < N; j++) {
                double dem = 0.0;
                for (int t = 0; t < T; t++) {
                    if (out_seq[t] == observation.get(j))
                        dem += gamma[t][i];
                }
                B[i][j] = dem / frac;
            }
        }
    }

    // 运行Baum-Welch算法
    public void run() {
        initParameters();
        int iter = 22; // 迭代次数
        while (iter-- > 0) {
            // E-Step
            updateAlpha();
            // updatePO();
            updateBeta();
            updateGamma();
            updatePI();
            updateXi();
            // M-Step
            updateA();
            updateB();
        }
    }

    public static void main(String[] args) {
        BaumWelch bw = new BaumWelch();
        bw.run();
        System.out.println("训练后的初始状态概率矩阵:");
        for (int i = 0; i < bw.M; i++)
            System.out.print(bw.PI[i] + "\t");
        System.out.println();
        System.out.println("训练后的状态转移矩阵:");
        for (int i = 0; i < bw.M; i++) {
            for (int j = 0; j < bw.M; j++) {
                System.out.print(bw.A[i][j] + "\t");
            }
            System.out.println();
        }
        System.out.println("训练后的混淆矩阵:");
        for (int i = 0; i < bw.M; i++) {
            for (int j = 0; j < bw.N; j++) {
                System.out.print(bw.B[i][j] + "\t");
            }
            System.out.println();
        }
    }
}

声明声明:本网页内容为用户发布,旨在传播知识,不代表本网认同其观点,若有侵权等问题请及时与本网联系,我们将在第一时间删除处理。E-MAIL:11247931@qq.com