deeplearning.ai 总结 - C++实现 lstm forward

deeplearning.ai 总结 - C++实现 lstm forward

flyfish

// Non-owning handles to the LSTM weights. The struct never frees what it points
// at; every pointed-to matrix must outlive all uses of the struct.
//
// Each gate weight (Wf, Wi, Wc, Wo) multiplies the stacked [a_prev; xt] block,
// so it has n_a + n_x columns; each gate bias is added via matrix_add_bias.
// Members default to nullptr so a forgotten assignment is a detectable null
// rather than an indeterminate pointer.
struct parameters
{
    Eigen::MatrixXd*    Wf = nullptr;  // forget gate weights (gate scales c_prev)
    Eigen::MatrixXd*    bf = nullptr;  // forget gate bias
    Eigen::MatrixXd*    Wi = nullptr;  // update (input) gate weights (gate scales the candidate)
    Eigen::MatrixXd*    bi = nullptr;  // update (input) gate bias
    Eigen::MatrixXd*    Wc = nullptr;  // candidate cell weights (result passed through tanh)
    Eigen::MatrixXd*    bc = nullptr;  // candidate cell bias
    Eigen::MatrixXd*    Wo = nullptr;  // output gate weights (gate scales tanh(c_next))
    Eigen::MatrixXd*    bo = nullptr;  // output gate bias
    Eigen::MatrixXd*    Wy = nullptr;  // output projection weights; shape read as (n_y, n_a)
    Eigen::MatrixXd*    by = nullptr;  // output projection bias, presumably (n_y, 1) — confirm

};

// Intermediate values of one LSTM cell step, presumably kept for a later
// backward pass. Every member is a non-owning pointer: the struct never frees
// anything, and each pointee must outlive every use of the cache.
// Members default to nullptr so an unset field is a detectable null.
struct cache
{
    parameters* p = nullptr;               // weights used for the step
    Eigen::MatrixXd*    a_next = nullptr;  // hidden state produced by the step
    Eigen::MatrixXd*    c_next = nullptr;  // cell state produced by the step
    Eigen::MatrixXd*    a_prev = nullptr;  // hidden state fed into the step
    Eigen::MatrixXd*    c_prev = nullptr;  // cell state fed into the step
    Eigen::MatrixXd*    ft = nullptr;      // forget gate activation
    Eigen::MatrixXd*    it = nullptr;      // update (input) gate activation
    Eigen::MatrixXd*    cct = nullptr;     // candidate cell value
    Eigen::MatrixXd*    ot = nullptr;      // output gate activation
    Eigen::MatrixXd*    xt = nullptr;      // input of the step
};

    // Runs the LSTM over every column (time step) of x, starting from the
    // hidden state a0; batch_row is the number of batch columns per step.
    void lstm_forward(const Eigen::MatrixXd& x,
                      int batch_row,
                      const Eigen::MatrixXd& a0,
                      parameters p);

    // Computes a single LSTM cell step from xt, a_prev and c_prev and returns
    // (a_next, c_next, yt_pred, cache).
    std::tuple<Eigen::MatrixXd*, Eigen::MatrixXd*, Eigen::MatrixXd*, cache>
    lstm_cell_forward(Eigen::MatrixXd& xt,
                      Eigen::MatrixXd& a_prev,
                      Eigen::MatrixXd& c_prev,
                      parameters p);




void lstm_forward(const Eigen::MatrixXd &x,int batch_row, const Eigen::MatrixXd& a0,parameters p)
{
    //3,10,7
    // simple=3,cols=7 ,batch_row=10
    //// 解释一共有3个句子,每个句子7个单词,每个单词10维向量
    int n_x = x.rows();
    int m = batch_row;

    int T_x = x.cols(); //一个batch row * T_x = 一个simple
    int simple_count = n_x / m;

    int n_y = p.Wy->rows();
    int n_a = p.Wy->cols();

    //初始化 a c y 为0
    Eigen::MatrixXd a = Eigen::MatrixXd::Zero(n_a * m, T_x);//三维
    Eigen::MatrixXd c = Eigen::MatrixXd::Zero(n_a * m, T_x);//三维
    Eigen::MatrixXd y = Eigen::MatrixXd::Zero(n_y * m, T_x);//三维

    //假设数据
    //[[[1.62434536 - 0.61175641 - 0.52817175 - 1.07296862] (3,5,4) //3个5行4列 特征4,总体看15行4列 
    //  [0.86540763 - 2.3015387   1.74481176 - 0.7612069]
    //[0.3190391 - 0.24937038  1.46210794 - 2.06014071]
    //[-0.3224172 - 0.38405435  1.13376944 - 1.09989127]
    //[-0.17242821 - 0.87785842  0.04221375  0.58281521]]

    //  [[-1.10061918  1.14472371  0.90159072  0.50249434]
    //  [0.90085595 - 0.68372786 - 0.12289023 - 0.93576943]
    //[-0.26788808  0.53035547 - 0.69166075 - 0.39675353]
    //[-0.6871727 - 0.84520564 - 0.67124613 - 0.0126646]
    //[-1.11731035  0.2344157   1.65980218  0.74204416]]

    //[[-0.19183555 - 0.88762896 - 0.74715829  1.6924546]
    //  [0.05080775 - 0.63699565  0.19091548  2.10025514]
    //[0.12015895  0.61720311  0.30017032 - 0.35224985]
    //[-1.1425182 - 0.34934272 - 0.20889423  0.58662319]
    //[0.83898341  0.93110208  0.28558733  0.88514116]]]
    //切分数据如下
//--------------------------------------------------------------------------------
    //[[-0.61175641 - 2.3015387 - 0.24937038 - 0.38405435 - 0.87785842] //3行 5列
    //  [1.14472371 - 0.68372786  0.53035547 - 0.84520564  0.2344157]
    //[-0.88762896 - 0.63699565  0.61720311 - 0.34934272  0.93110208]]

    //[[1.62434536  0.86540763  0.3190391 - 0.3224172 - 0.17242821]
    //  [-1.10061918  0.90085595 - 0.26788808 - 0.6871727 - 1.11731035]
    //[-0.19183555  0.05080775  0.12015895 - 1.1425182   0.83898341]]



    Eigen::MatrixXd a_next = a0;
    Eigen::MatrixXd c_next = Eigen::MatrixXd::Zero(n_a, m);


    std::cout << "x:\n" << x << std::endl;
    for (int j=0;j<T_x;j++)
    {

        Eigen::MatrixXd input(simple_count, batch_row);
        input = (x.col(j));
        input.resize(batch_row, simple_count);
        input.transposeInPlace();
        std::cout << "input:\n" << input << std::endl;
        lstm_cell_forward(input, a_next, c_next, p);
    }
}

std::tuple<Eigen::MatrixXd*, Eigen::MatrixXd*, Eigen::MatrixXd*, cache>  lstm_cell_forward( Eigen::MatrixXd &xt,  Eigen::MatrixXd &a_prev,  Eigen::MatrixXd &c_prev, parameters p)
{

    int n_x = xt.rows();
    int m = xt.cols();

    int n_y = p.Wy->rows();
    int n_a = p.Wy->cols();


    Eigen::MatrixXd concat(n_a + n_x, m);
    concat << a_prev, xt;

    Eigen::MatrixXd ft = sigmond_forward(matrix_add_bias( *p.Wf * concat, *p.bf));

    std::cout << "ft:\n" << ft << std::endl;
    Eigen::MatrixXd it = sigmond_forward(matrix_add_bias(*p.Wi * concat, *p.bi));
    std::cout << "it:\n" << it << std::endl;
    Eigen::MatrixXd cct = (matrix_add_bias(*p.Wc * concat, *p.bc)).array().tanh();
    std::cout << "cct:\n" << cct << std::endl;
    Eigen::MatrixXd c_next = it.cwiseProduct(cct) + ft.cwiseProduct(c_prev);
    std::cout << "c_next:\n" << c_next << std::endl;
    Eigen::MatrixXd ot = sigmond_forward(matrix_add_bias(*p.Wo * concat, *p.bo));
    std::cout << "ot:\n" << ot << std::endl;

    Eigen::MatrixXd t = (c_next.array().tanh());
    Eigen::MatrixXd a_next = ot.cwiseProduct(t);
    std::cout << "a_next:\n" << a_next << std::endl;

    Eigen::MatrixXd yt_pred = softmax_forward(a_next);
    std::cout << "yt_pred:\n" << yt_pred << std::endl;

    cache c;
    c.p = &p;
    c.a_next= &a_next;
    c.c_next= &c_next;
    c.a_prev= &a_prev;
    c.c_prev = &c_prev;
    c.ft = &ft;
    c.it = &it;
    c.cct = &cct;
    c.ot = &ot;
    c.xt = &xt;

    return std::make_tuple(&a_next, &c_next, &yt_pred, c);
}

调用示例（以下语句需放在 main 或其他函数体内）：

    // Example driver for lstm_forward: n_a = 5 hidden units, n_x = 3 input
    // features, a batch of m = 10 columns, T_x = 7 time steps, n_y = 2 outputs.
    // These statements belong inside main() or another function body.
    const int n_a = 5;
    const int n_x = 3;
    const int m = 10;
    const int T_x = 7;
    const int n_y = 2;

    // Gate weights are (n_a, n_a + n_x) because they multiply [a_prev; xt];
    // gate biases are (n_a, 1).

    // Forget gate.
    Eigen::MatrixXd Wf(n_a, n_a + n_x);
    Wf << -0.44712856, 1.2245077, 0.40349164, 0.59357852, -1.09491185, 0.16938243, 0.74055645, -0.9537006,
          -0.26621851, 0.03261455, -1.37311732, 0.31515939, 0.84616065, -0.85951594, 0.35054598, -1.31228341,
          -0.03869551, -1.61577235, 1.12141771, 0.40890054, -0.02461696, -0.77516162, 1.27375593, 1.96710175,
          -1.85798186, 1.23616403, 1.62765075, 0.3380117, -1.19926803, 0.86334532, -0.1809203, -0.60392063,
          -1.23005814, 0.5505375, 0.79280687, -0.62353073, 0.52057634, -1.14434139, 0.80186103, 0.0465673;
    Eigen::MatrixXd bf(n_a, 1);
    bf << -0.18656977, -0.10174587, 0.86888616, 0.75041164, 0.52946532;

    // Update (input) gate.
    Eigen::MatrixXd Wi(n_a, n_a + n_x);
    Wi << 0.13770121, 0.07782113, 0.61838026, 0.23249456, 0.68255141, -0.31011677, -2.43483776, 1.0388246,
          2.18697965, 0.44136444, -0.10015523, -0.13644474, -0.11905419, 0.01740941, -1.12201873, -0.51709446,
          -0.99702683, 0.24879916, -0.29664115, 0.49521132, -0.17470316, 0.98633519, 0.2135339, 2.19069973,
          -1.89636092, -0.64691669, 0.90148689, 2.52832571, -0.24863478, 0.04366899, -0.22631424, 1.33145711,
          -0.28730786, 0.68006984, -0.3198016, -1.27255876, 0.31354772, 0.50318481, 1.29322588, -0.11044703;
    Eigen::MatrixXd bi(n_a, 1);
    bi << -0.61736206, 0.5627611, 0.24073709, 0.28066508, -0.0731127;

    // Candidate cell value.
    Eigen::MatrixXd Wc(n_a, n_a + n_x);
    Wc << -1.23312074e+00, -1.18231813e+00, -6.65754518e-01, -1.67419581e+00, 8.25029824e-01, -4.98213564e-01, -3.10984978e-01, -1.89148284e-03,
          -1.39662042e+00, -8.61316361e-01, 6.74711526e-01, 6.18539131e-01, -4.43171931e-01, 1.81053491e+00, -1.30572692e+00, -3.44987210e-01,
          -2.30839743e-01, -2.79308500e+00, 1.93752881e+00, 3.66332015e-01, -1.04458938e+00, 2.05117344e+00, 5.85662000e-01, 4.29526140e-01,
          -6.06998398e-01, 1.06222724e-01, -1.52568032e+00, 7.95026094e-01, -3.74438319e-01, 1.34048197e-01, 1.20205486e+00, 2.84748111e-01,
          2.62467445e-01, 2.76499305e-01, -7.33271604e-01, 8.36004719e-01, 1.54335911e+00, 7.58805660e-01, 8.84908814e-01, -8.77281519e-01;
    Eigen::MatrixXd bc(n_a, 1);
    bc << -0.86778722, -1.44087602, 1.23225307, -0.25417987, 1.39984394;

    // Output gate.
    Eigen::MatrixXd Wo(n_a, n_a + n_x);
    Wo << 1.16033857, 0.36949272, 1.90465871, 1.1110567, 0.6590498, -1.62743834, 0.60231928, 0.4202822,
          0.81095167, 1.04444209, -0.40087819, 0.82400562, -0.56230543, 1.95487808, -1.33195167, -1.76068856,
          -1.65072127, -0.89055558, -1.1191154, 1.9560789, -0.3264995, -1.34267579, 1.11438298, -0.58652394,
          -1.23685338, 0.87583893, 0.62336218, -0.43495668, 1.40754, 0.12910158, 1.6169496, 0.50274088,
          1.55880554, 0.1094027, -1.2197444, 2.44936865, -0.54577417, -0.19883786, -0.7003985, -0.20339445;
    Eigen::MatrixXd bo(n_a, 1);
    bo << 0.24266944, 0.20183018, 0.66102029, 1.79215821, -0.12046457;

    // Output projection: (n_y, n_a) weights and (n_y, 1) bias.
    Eigen::MatrixXd Wy(n_y, n_a);
    Wy << -0.78191168, -0.43750898, 0.09542509, 0.92145007, 0.0607502,
          0.21112476, 0.01652757, 0.17718772, -1.11647002, 0.0809271;
    Eigen::MatrixXd by(n_y, 1);
    by << -0.18657899, -0.05682448;

    // Hand the (non-owning) weight pointers to the LSTM.
    parameters p{};
    p.Wf = &Wf;
    p.bf = &bf;
    p.Wi = &Wi;
    p.bi = &bi;
    p.Wc = &Wc;
    p.bc = &bc;
    p.Wo = &Wo;
    p.bo = &bo;
    p.Wy = &Wy;
    p.by = &by;

    // Whole input sequence: each of the T_x columns is n_x runs of m values.
    // Generated before a0 so the Random() draws happen in the same order.
    Eigen::MatrixXd x = Eigen::MatrixXd::Random(n_x * m, T_x);
    Eigen::MatrixXd a0 = Eigen::MatrixXd::Random(n_a, m);
    lstm_forward(x, m, a0, p);

猜你喜欢

转载自blog.csdn.net/flyfish1986/article/details/80156875