当前位置: 首页>>代码示例>>C++>>正文


C++ Trsm函数代码示例

本文整理汇总了C++中Trsm函数的典型用法代码示例。如果您正苦于以下问题:C++ Trsm函数的具体用法?C++ Trsm怎么用?C++ Trsm使用的例子?那么,这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了Trsm函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。

示例1: SolveAfterCholesky

// Applies the two triangular solves of a Cholesky factorization to B,
// overwriting B with the result.
//
// uplo        : LOWER solves with the lower triangle of A (NORMAL, then
//               ADJOINT); any other value solves with the upper triangle
//               (ADJOINT, then NORMAL).
// orientation : TRANSPOSE conjugates B before and after the solves; NORMAL
//               and ADJOINT run the solves on B directly.
// A           : square matrix holding the triangular factor (presumably the
//               output of a prior Cholesky call -- not verified here).
// B           : right-hand sides on entry, solutions on exit. A single column
//               is routed through Trsv, several columns through Trsm.
//
// In debug builds (RELEASE undefined) a std::logic_error is thrown when A and
// B live on different grids, A is not square, or the heights disagree.
inline void
SolveAfterCholesky
( UpperOrLower uplo, Orientation orientation, 
  const DistMatrix<F>& A, DistMatrix<F>& B )
{
#ifndef RELEASE
    // The label must name this routine so debug call-stack traces point here.
    PushCallStack("SolveAfterCholesky");
    if( A.Grid() != B.Grid() )
        throw std::logic_error("{A,B} must be distributed over the same grid");
    if( A.Height() != A.Width() )
        throw std::logic_error("A must be square");
    if( A.Height() != B.Height() )
        throw std::logic_error("A and B must be the same height");
#endif
    const bool conjugateB = ( orientation == TRANSPOSE );
    const bool singleRhs = ( B.Width() == 1 );

    // The transposed system is handled by conjugating B around the same pair
    // of solves used for the untransposed one.
    if( conjugateB )
        Conj( B );

    if( uplo == LOWER )
    {
        if( singleRhs )
        {
            Trsv( LOWER, NORMAL, NON_UNIT, A, B );
            Trsv( LOWER, ADJOINT, NON_UNIT, A, B );
        }
        else
        {
            Trsm( LEFT, LOWER, NORMAL, NON_UNIT, F(1), A, B );
            Trsm( LEFT, LOWER, ADJOINT, NON_UNIT, F(1), A, B );
        }
    }
    else
    {
        if( singleRhs )
        {
            Trsv( UPPER, ADJOINT, NON_UNIT, A, B );
            Trsv( UPPER, NORMAL, NON_UNIT, A, B );
        }
        else
        {
            Trsm( LEFT, UPPER, ADJOINT, NON_UNIT, F(1), A, B );
            Trsm( LEFT, UPPER, NORMAL, NON_UNIT, F(1), A, B );
        }
    }

    if( conjugateB )
        Conj( B );
#ifndef RELEASE
    PopCallStack();
#endif
}
开发者ID:jimgoo,项目名称:Elemental,代码行数:60,代码来源:SolveAfterCholesky.hpp

示例2: SolveAfter

// Solves a system using an existing LQ factorization stored in A together
// with its Householder scalars and signature, writing the result to X.
//
// orientation : NORMAL solves with the factorization as-is; TRANSPOSE and
//               ADJOINT share the second code path, with TRANSPOSE wrapped in
//               a conjugation of X.
// A           : m x n factored matrix with m <= n; its leading m x m block is
//               used as the lower-triangular factor.
// B           : right-hand sides (height m for NORMAL, height n otherwise).
// X           : output; resized by this routine.
//
// LogicError is raised when m > n or when B's height does not conform.
void SolveAfter
( Orientation orientation, 
  const Matrix<F>& A,
  const Matrix<F>& householderScalars, 
  const Matrix<Base<F>>& signature,
  const Matrix<F>& B,       
        Matrix<F>& X )
{
    DEBUG_CSE
    const Int numRows = A.Height();
    const Int numCols = A.Width();
    if( numRows > numCols )
        LogicError("Must have full row rank");

    // TODO: Add scaling
    auto lowerFactor = A( IR(0,numRows), IR(0,numRows) );

    if( orientation == NORMAL )
    {
        if( numRows != B.Height() )
            LogicError("A and B do not conform");

        // X := [B; 0], stored in an n-row buffer
        X.Resize( numCols, B.Width() );
        auto XTop = X( IR(0,numRows), ALL );
        auto XBottom = X( IR(numRows,numCols), ALL );
        XTop = B;
        Zero( XBottom );

        // Triangular solve against L (with the singularity check enabled)
        Trsm
        ( LEFT, LOWER, NORMAL, NON_UNIT, F(1), lowerFactor, XTop, true );

        // Finish by applying Q' to all of X
        lq::ApplyQ( LEFT, ADJOINT, A, householderScalars, signature, X );
        return;
    }

    // Remaining orientations: TRANSPOSE or ADJOINT
    if( numCols != B.Height() )
        LogicError("A and B do not conform");

    const bool wrapInConjugation = ( orientation == TRANSPOSE );

    // Start from a copy of B
    X = B;
    if( wrapInConjugation )
        Conjugate( X );

    // Apply Q, then keep only the leading m rows
    lq::ApplyQ( LEFT, NORMAL, A, householderScalars, signature, X );
    X.Resize( numRows, X.Width() );

    // Triangular solve against L' (with the singularity check enabled)
    Trsm( LEFT, LOWER, ADJOINT, NON_UNIT, F(1), lowerFactor, X, true );

    if( wrapInConjugation )
        Conjugate( X );
}
开发者ID:jeffhammond,项目名称:Elemental,代码行数:58,代码来源:SolveAfter.hpp

示例3: SolveAfter

// Applies the pair of triangular solves belonging to a Cholesky
// factorization to B in place.
//
// uplo        : LOWER uses the lower triangle of A (NORMAL then ADJOINT
//               solve); every other value uses the upper triangle (ADJOINT
//               then NORMAL solve).
// orientation : TRANSPOSE conjugates B before and after the solves.
// A, B        : factor and right-hand sides; debug builds check that they
//               share a grid, that A is square and that the heights match.
inline void
SolveAfter
( UpperOrLower uplo, Orientation orientation, 
  const DistMatrix<F>& A, DistMatrix<F>& B )
{
#ifndef RELEASE
    CallStackEntry entry("cholesky::SolveAfter");
    if( A.Grid() != B.Grid() )
        LogicError("{A,B} must be distributed over the same grid");
    if( A.Height() != A.Width() )
        LogicError("A must be square");
    if( A.Height() != B.Height() )
        LogicError("A and B must be the same height");
#endif
    // Everything that differs between the lower and upper cases is captured
    // by the triangle and the order of the two orientations.
    const bool useLower = ( uplo == LOWER );
    const UpperOrLower triangle = ( useLower ? LOWER : UPPER );
    const Orientation firstSolve = ( useLower ? NORMAL : ADJOINT );
    const Orientation secondSolve = ( useLower ? ADJOINT : NORMAL );
    const bool wrapInConjugation = ( orientation == TRANSPOSE );

    if( wrapInConjugation )
        Conjugate( B );

    if( B.Width() == 1 )
    {
        // Single right-hand side: vector solves
        Trsv( triangle, firstSolve, NON_UNIT, A, B );
        Trsv( triangle, secondSolve, NON_UNIT, A, B );
    }
    else
    {
        // Several right-hand sides: matrix solves
        Trsm( LEFT, triangle, firstSolve, NON_UNIT, F(1), A, B );
        Trsm( LEFT, triangle, secondSolve, NON_UNIT, F(1), A, B );
    }

    if( wrapInConjugation )
        Conjugate( B );
}
开发者ID:khalid-hasanov,项目名称:Elemental,代码行数:57,代码来源:SolveAfter.hpp

示例4: LowerBlocked

// Blocked driver: sweeps over A in panels of at most Blocksize() columns,
// runs hessenberg::LowerPanel on each trailing block and then applies the
// accumulated two-sided update to the remaining parts of A.
// NOTE(review): presumably this reduces A to lower Hessenberg form with the
// reflector scalars returned in householderScalars -- confirm against
// hessenberg::LowerPanel.
//
// A                  : overwritten in place. Only A.Height() is queried, so
//                      squareness is assumed rather than checked here.
// householderScalars : resized to Max(n-1,0) x 1; a view of rows [k,k+nb) is
//                      handed to each panel call.
void LowerBlocked( Matrix<F>& A, Matrix<F>& householderScalars )
{
    DEBUG_CSE
    const Int n = A.Height();
    householderScalars.Resize( Max(n-1,0), 1 );

    // Workspaces, declared once and reused by every panel iteration
    Matrix<F> UB1, V01, VB1, G11;

    const Int bsize = Blocksize();
    for( Int k=0; k<n-1; k+=bsize )
    {
        // The loop covers n-1 columns in total, so the last panel may be
        // narrower than bsize
        const Int nb = Min(bsize,n-1-k);

        // ind0: already-processed leading range, ind1: current panel,
        // indB/indR: current panel through the end, ind2: trailing range
        const Range<Int> ind0( 0,    k    ),
                         ind1( k,    k+nb ),
                         indB( k,    n    ), indR( k, n ),
                         ind2( k+nb, n    );

        auto ABR = A( indB, indR );
        // NOTE(review): A22 is not referenced again in this routine
        auto A22 = A( ind2, ind2 );

        auto householderScalars1 = householderScalars( ind1, ALL );
        UB1.Resize( n-k, nb );
        VB1.Resize( n-k, nb );
        G11.Resize( nb,  nb );
        hessenberg::LowerPanel( ABR, householderScalars1, UB1, VB1, G11 );

        auto AB0 = A( indB, ind0 );
        auto A2R = A( ind2, indR );
        // Rows of the panel workspaces below the leading nb x nb block
        auto U21 = UB1( IR(nb,END), ALL );
        auto V21 = VB1( IR(nb,END), ALL );

        // AB0 := AB0 - (UB1 inv(G11)^H UB1^H AB0)
        //      = AB0 - (UB1 ((AB0^H UB1) inv(G11))^H)
        // -------------------------------------------
        Gemm( ADJOINT, NORMAL, F(1), AB0, UB1, V01 );
        Trsm( RIGHT, UPPER, NORMAL, NON_UNIT, F(1), G11, V01 );
        Gemm( NORMAL, ADJOINT, F(-1), UB1, V01, F(1), AB0 );

        // A2R := (A2R - U21 inv(G11)^H VB1^H)(I - UB1 inv(G11) UB1^H)
        // -----------------------------------------------------------
        // A2R := A2R - U21 inv(G11)^H VB1^H
        // (note: VB1 is overwritten)
        Trsm( RIGHT, UPPER, NORMAL, NON_UNIT, F(1), G11, VB1 );
        Gemm( NORMAL, ADJOINT, F(-1), U21, VB1, F(1), A2R );
        // A2R := A2R - ((A2R UB1) inv(G11)) UB1^H
        // (V21 is a view into VB1, reused here as scratch space)
        Gemm( NORMAL, NORMAL, F(1), A2R, UB1, F(0), V21 );
        Trsm( RIGHT, UPPER, NORMAL, NON_UNIT, F(1), G11, V21 );
        Gemm( NORMAL, ADJOINT, F(-1), V21, UB1, F(1), A2R );
    }
}
开发者ID:timmoon10,项目名称:Elemental,代码行数:51,代码来源:LowerBlocked.hpp

示例5: SolveAfterLU

// Overwrites B with the solution of a system whose coefficient matrix has
// already been LU-factored with row pivoting.
//
// orientation : NORMAL applies the row pivots to B, then solves with the
//               unit-lower and the upper triangle of A. Any other value
//               solves with the upper and then the unit-lower triangle in
//               that orientation and finally applies the inverse row pivots.
// A           : square matrix holding both triangular factors (unit-diagonal
//               lower part, non-unit upper part).
// p           : pivot vector, passed to ApplyRowPivots/ApplyInverseRowPivots.
// B           : right-hand sides on entry, solutions on exit. One column is
//               routed through Trsv, several columns through Trsm.
//
// In debug builds (RELEASE undefined) a std::logic_error is thrown when A, B
// and p are not all on the same grid, A is not square, or the heights of B or
// p differ from that of A.
inline void
SolveAfterLU
( Orientation orientation, 
  const DistMatrix<F>& A, const DistMatrix<int,VC,STAR>& p, DistMatrix<F>& B )
{
#ifndef RELEASE
    PushCallStack("SolveAfterLU");
    // The check covers the pivot vector too, so the message names all three.
    if( A.Grid() != B.Grid() || A.Grid() != p.Grid() )
        throw std::logic_error
        ("{A,B,p} must be distributed over the same grid");
    if( A.Height() != A.Width() )
        throw std::logic_error("A must be square");
    if( A.Height() != B.Height() )
        throw std::logic_error("A and B must be the same height");
    if( A.Height() != p.Height() )
        throw std::logic_error("A and p must be the same height");
#endif
    if( B.Width() == 1 )
    {
        if( orientation == NORMAL )
        {
            // Permute first, then forward and backward substitution
            ApplyRowPivots( B, p );
            Trsv( LOWER, NORMAL, UNIT, A, B );
            Trsv( UPPER, NORMAL, NON_UNIT, A, B );
        }
        else
        {
            // Factors are applied in the opposite order and the permutation
            // is undone last
            Trsv( UPPER, orientation, NON_UNIT, A, B );
            Trsv( LOWER, orientation, UNIT, A, B );
            ApplyInverseRowPivots( B, p );
        }
    }
    else
    {
        if( orientation == NORMAL )
        {
            ApplyRowPivots( B, p );
            Trsm( LEFT, LOWER, NORMAL, UNIT, F(1), A, B );
            Trsm( LEFT, UPPER, NORMAL, NON_UNIT, F(1), A, B );
        }
        else
        {
            Trsm( LEFT, UPPER, orientation, NON_UNIT, F(1), A, B );
            Trsm( LEFT, LOWER, orientation, UNIT, F(1), A, B );
            ApplyInverseRowPivots( B, p );
        }
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
开发者ID:jimgoo,项目名称:Elemental,代码行数:50,代码来源:SolveAfterLU.hpp

示例6: CholeskyUVar2

// Blocked, in-place routine registered as "hpd_inverse::CholeskyUVar2" in the
// debug call stack. It walks down the diagonal of A one block at a time; each
// iteration factors the diagonal block, updates the neighbouring blocks with
// triangular solves and rank-k updates, and finally inverts the diagonal
// block and forms its Trtrmm product.
// NOTE(review): presumably this overwrites the upper triangle of a Hermitian
// positive-definite A with the upper triangle of inv(A) -- confirm against
// the caller. Only upper-triangle blocks (A00, A01, A02, A11, A12, A22) are
// touched here.
//
// A : square matrix, modified in place. Debug builds throw std::logic_error
//     when it is not square.
inline void
CholeskyUVar2( Matrix<F>& A )
{
#ifndef RELEASE
    PushCallStack("hpd_inverse::CholeskyUVar2");
    if( A.Height() != A.Width() )
        throw std::logic_error("Nonsquare matrices cannot be triangular");
#endif
    // Matrix views
    Matrix<F> 
        ATL, ATR,  A00, A01, A02,
        ABL, ABR,  A10, A11, A12,
                   A20, A21, A22;

    // Start the algorithm
    PartitionDownDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    // Iterate until the top-left quadrant has grown to cover all of A
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
         /*************/ /******************/
               /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        //--------------------------------------------------------------------//
        // Factor the diagonal block (upper-triangular factor)
        Cholesky( UPPER, A11 );
        // BLAS-style solves: A01 := A01 inv(A11), A12 := inv(A11)^H A12
        Trsm( RIGHT, UPPER, NORMAL, NON_UNIT, F(1), A11, A01 );
        Trsm( LEFT, UPPER, ADJOINT, NON_UNIT, F(1), A11, A12 );
        // Rank-k updates built from the solved panels:
        // A00 += A01 A01^H, A02 -= A01 A12, A22 -= A12^H A12
        Herk( UPPER, NORMAL, F(1), A01, F(1), A00 );
        Gemm( NORMAL, NORMAL, F(-1), A01, A12, F(1), A02 );
        Herk( UPPER, ADJOINT, F(-1), A12, F(1), A22 );
        // Second pass of solves: A01 := A01 inv(A11)^H, A12 := -inv(A11) A12
        Trsm( RIGHT, UPPER, ADJOINT, NON_UNIT, F(1), A11, A01 );
        Trsm( LEFT, UPPER, NORMAL, NON_UNIT, F(-1), A11, A12 );
        // Invert the triangular diagonal block, then overwrite it with the
        // triangular-times-adjoint product formed by Trtrmm
        TriangularInverse( UPPER, NON_UNIT, A11 );
        Trtrmm( ADJOINT, UPPER, A11 );
        //--------------------------------------------------------------------//

        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
               /**/       A10, A11, /**/ A12,
         /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
开发者ID:mcg1969,项目名称:Elemental,代码行数:49,代码来源:CholeskyUVar2.hpp

示例7: RLHF

// Applies the transforms packed in H to the columns of A, one panel of H at
// a time, working from the top-left of H towards the bottom-right.
// NOTE(review): the call-stack tag is "apply_packed_reflectors::RLHF";
// presumably the letters stand for Right / Lower / Horizontal / Forward --
// confirm against the dispatching routine.
//
// offset : diagonal offset of the packed transforms; must lie in
//          [-H.Width(), 0] (checked in debug builds).
// H      : packed transforms, read-only. Its width must equal A's width
//          (checked in debug builds).
// A      : target matrix, updated in place.
inline void
RLHF( int offset, const Matrix<R>& H, Matrix<R>& A )
{
#ifndef RELEASE
    CallStackEntry entry("apply_packed_reflectors::RLHF");
    if( offset > 0 || offset < -H.Width() )
        throw std::logic_error("Transforms out of bounds");
    if( H.Width() != A.Width() )
        throw std::logic_error
        ("Width of transforms must equal width of target matrix");
#endif
    Matrix<R>
        HTL, HTR,  H00, H01, H02,  HPan, HPanCopy,
        HBL, HBR,  H10, H11, H12,
                   H20, H21, H22;
    Matrix<R> ALeft;

    // Workspaces: SInv is filled by Syrk, Z by the first Gemm
    Matrix<R> SInv, Z;

    LockedPartitionDownDiagonal
    ( H, HTL, HTR,
         HBL, HBR, 0 );
    // Stop as soon as either dimension of H has been exhausted
    while( HTL.Height() < H.Height() && HTL.Width() < H.Width() )
    {
        LockedRepartitionDownDiagonal
        ( HTL, /**/ HTR,  H00, /**/ H01, H02,
         /*************/ /******************/
               /**/       H10, /**/ H11, H12,
          HBL, /**/ HBR,  H20, /**/ H21, H22 );

        // The panel spans the columns of H10 and H11; HPanOffset skips the
        // leading rows of this block that lie above the offset diagonal
        const int HPanWidth = H10.Width() + H11.Width();
        const int HPanOffset = 
            std::min( H11.Height(), std::max(-offset-H00.Height(),0) );
        const int HPanHeight = H11.Height()-HPanOffset;
        LockedView
        ( HPan, H, H00.Height()+HPanOffset, 0, HPanHeight, HPanWidth );

        // Only the first HPanWidth columns of A are touched by this panel
        View( ALeft, A, 0, 0, A.Height(), HPanWidth );

        //--------------------------------------------------------------------//
        // Work on a copy so H stays untouched: clear the entries on one side
        // of the offset diagonal and set that diagonal to one
        HPanCopy = HPan;
        MakeTrapezoidal( RIGHT, LOWER, offset, HPanCopy );
        SetDiagonal( RIGHT, offset, HPanCopy, R(1) );

        // SInv := upper triangle of HPanCopy HPanCopy^T, with its main
        // diagonal halved
        Syrk( UPPER, NORMAL, R(1), HPanCopy, SInv );
        HalveMainDiagonal( SInv );

        // ALeft := ALeft - (ALeft HPanCopy^T) inv(SInv) HPanCopy
        Gemm( NORMAL, TRANSPOSE, R(1), ALeft, HPanCopy, Z );
        Trsm( RIGHT, UPPER, NORMAL, NON_UNIT, R(1), SInv, Z );
        Gemm( NORMAL, NORMAL, R(-1), Z, HPanCopy, R(1), ALeft );
        //--------------------------------------------------------------------//

        SlideLockedPartitionDownDiagonal
        ( HTL, /**/ HTR,  H00, H01, /**/ H02,
               /**/       H10, H11, /**/ H12,
         /*************/ /******************/
          HBL, /**/ HBR,  H20, H21, /**/ H22 );
    }
}
开发者ID:ahmadia,项目名称:Elemental-1,代码行数:59,代码来源:RLHF.hpp

示例8: LogDetDivergence

// Returns frobNorm^2 - logDet - n, where both quantities come from the
// triangular matrix obtained by solving the Cholesky factor of B against the
// Cholesky factor of A, and n is the common dimension.
//
// uplo : which triangle the Cholesky factorizations use. LOWER goes through
//        Trtrsm; any other value zeroes the unused triangle and uses Trsm.
// A, B : square matrices of equal size on the same grid; both are copied, so
//        the inputs are left unmodified.
//
// Throws std::logic_error when the grids differ or the shapes do not match.
inline typename Base<F>::type 
LogDetDivergence
( UpperOrLower uplo, const DistMatrix<F>& A, const DistMatrix<F>& B )
{
#ifndef RELEASE
    PushCallStack("LogDetDivergence");
#endif
    if( A.Grid() != B.Grid() )
        throw std::logic_error("A and B must use the same grid");
    const bool squareAndConformal =
        A.Height() == A.Width() && B.Height() == B.Width() &&
        A.Height() == B.Height();
    if( !squareAndConformal )
        throw std::logic_error
        ("A and B must be square matrices of the same size");

    typedef typename Base<F>::type R;
    const int dimension = A.Height();
    const Grid& grid = A.Grid();

    // Factor private copies of both inputs
    DistMatrix<F> factorA( A );
    DistMatrix<F> factorB( B );
    Cholesky( uplo, factorA );
    Cholesky( uplo, factorB );

    // Solve the factor of B against the factor of A
    if( uplo != LOWER )
    {
        MakeTrapezoidal( LEFT, uplo, 0, factorA );
        Trsm( LEFT, uplo, NORMAL, NON_UNIT, F(1), factorB, factorA );
    }
    else
    {
        Trtrsm( LEFT, uplo, NORMAL, NON_UNIT, F(1), factorB, factorA );
    }

    MakeTrapezoidal( LEFT, uplo, 0, factorA );
    const R frobNorm = Norm( factorA, FROBENIUS_NORM );

    // Each process sums 2*log over its share of the diagonal; the partial
    // sums are then combined across the grid
    R logDet;
    R partialLogDet(0);
    DistMatrix<F,MD,STAR> diagonal(grid);
    factorA.GetDiagonal( diagonal );
    if( diagonal.InDiagonal() )
    {
        const int numLocalEntries = diagonal.LocalHeight();
        int localRow = 0;
        while( localRow < numLocalEntries )
        {
            const R entry = RealPart(diagonal.GetLocal(localRow,0));
            partialLogDet += 2*Log(entry);
            ++localRow;
        }
    }
    mpi::AllReduce( &partialLogDet, &logDet, 1, mpi::SUM, grid.VCComm() );

    const R divergence = frobNorm*frobNorm - logDet - R(dimension);
#ifndef RELEASE
    PopCallStack();
#endif
    return divergence;
}
开发者ID:certik,项目名称:Elemental,代码行数:58,代码来源:LogDetDivergence.hpp

示例9: BackwardMany

void BackwardMany
( const DistMatrix<F,VC,STAR>& L,
        DistMatrix<F,VC,STAR>& X,
  bool conjugate=false )
{
    // TODO: Replace this with modified inline code?
    const Orientation orientation = ( conjugate ? ADJOINT : TRANSPOSE );
    Trsm( LEFT, LOWER, orientation, UNIT, F(1), L, X, false, TRSM_SMALL );
}
开发者ID:elemental,项目名称:Elemental,代码行数:9,代码来源:FrontBackward.hpp

示例10: TriangularInverseUVar3

// Blocked, in-place routine that walks up the diagonal of the upper
// triangular matrix U from the bottom-right corner. Each iteration updates
// the blocks above and to the right of the current diagonal block and then
// hands the diagonal block to TriangularInverseUVar3Unb.
// NOTE(review): presumably U is overwritten by its inverse (going by the
// routine name) -- confirm against the dispatching TriangularInverse.
//
// diag : forwarded to every Trsm call and to the unblocked kernel; selects
//        unit vs. non-unit diagonal handling.
// U    : square matrix, modified in place. Debug builds throw
//        std::logic_error when it is not square.
inline void
TriangularInverseUVar3( UnitOrNonUnit diag, Matrix<F>& U )
{
#ifndef RELEASE
    PushCallStack("internal::TriangularInverseUVar3");
    if( U.Height() != U.Width() )
        throw std::logic_error("Nonsquare matrices cannot be triangular");
#endif
    // Matrix views
    Matrix<F> 
        UTL, UTR,  U00, U01, U02,
        UBL, UBR,  U10, U11, U12,
                   U20, U21, U22;

    // Start the algorithm
    PartitionUpDiagonal
    ( U, UTL, UTR,
         UBL, UBR, 0 );
    // Iterate until the bottom-right quadrant has grown to cover all of U
    while( UBR.Height() < U.Height() )
    {
        RepartitionUpDiagonal
        ( UTL, /**/ UTR,  U00, U01, /**/ U02,
               /**/       U10, U11, /**/ U12,
         /*************/ /******************/
          UBL, /**/ UBR,  U20, U21, /**/ U22 );

        //--------------------------------------------------------------------//
        // BLAS-style updates; the order matters because U01 and U12 are
        // overwritten while U11 still holds its original contents:
        //   U01 := -U01 inv(U11)
        //   U02 := U02 + U01 U12   (uses the updated U01, original U12)
        //   U12 := inv(U11) U12
        Trsm( RIGHT, UPPER, NORMAL, diag, F(-1), U11, U01 );
        Gemm( NORMAL, NORMAL, F(1), U01, U12, F(1), U02 );
        Trsm( LEFT, UPPER, NORMAL, diag, F(1), U11, U12 );
        // The diagonal block is processed last, by the unblocked kernel
        TriangularInverseUVar3Unb( diag, U11 );
        //--------------------------------------------------------------------//

        SlidePartitionUpDiagonal
        ( UTL, /**/ UTR,  U00, /**/ U01, U02,
         /*************/ /******************/
               /**/       U10, /**/ U11, U12,
          UBL, /**/ UBR,  U20, /**/ U21, U22 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
开发者ID:jimgoo,项目名称:Elemental,代码行数:43,代码来源:UVar3.hpp

示例11: RunRoutine

 // Describes how to run the CLBlast routine: forwards the arguments and
 // buffers to Trsm on the given queue and, when the call reports success,
 // blocks until the returned event has completed before releasing it.
 // Returns the status code produced by Trsm.
 static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
   auto raw_queue = queue();
   auto completion = cl_event{};
   const auto result = Trsm(args.layout, args.side, args.triangle,
                            args.a_transpose, args.diagonal,
                            args.m, args.n, args.alpha,
                            buffers.a_mat(), args.a_offset, args.a_ld,
                            buffers.b_mat(), args.b_offset, args.b_ld,
                            &raw_queue, &completion);
   if (result != StatusCode::kSuccess) {
     // On failure the event is neither waited on nor released
     return result;
   }
   clWaitForEvents(1, &completion);
   clReleaseEvent(completion);
   return result;
 }
开发者ID:dividiti,项目名称:CLBlast,代码行数:12,代码来源:xtrsm.hpp

示例12: LU

// Blocked, in-place LU-style factorization that walks down the diagonal of A.
// Each iteration factors the diagonal block with internal::LUUnb, solves the
// block column below and the block row to the right against it, and updates
// the trailing submatrix. No pivot vector is taken or produced by this
// routine.
// NOTE(review): presumably A ends up holding the unit-lower factor below the
// diagonal and the upper factor on and above it -- confirm against
// internal::LUUnb.
//
// A : matrix to factor, modified in place; the loop stops once either
//     dimension is exhausted, so A need not be square.
inline void
LU( Matrix<F>& A )
{
#ifndef RELEASE
    PushCallStack("LU");
#endif
    // Matrix views
    Matrix<F>
        ATL, ATR,  A00, A01, A02, 
        ABL, ABR,  A10, A11, A12,  
                   A20, A21, A22;

    // Start the algorithm
    PartitionDownDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    while( ATL.Height() < A.Height() && ATL.Width() < A.Width() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
         /*************/ /******************/
               /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        //--------------------------------------------------------------------//
        // Factor the diagonal block in place
        internal::LUUnb( A11 );
        // BLAS-style solves against the two triangles of the factored block:
        //   A21 := A21 inv(upper(A11)),  A12 := inv(unit-lower(A11)) A12
        Trsm( RIGHT, UPPER, NORMAL, NON_UNIT, F(1), A11, A21 );
        Trsm( LEFT, LOWER, NORMAL, UNIT, F(1), A11, A12 );
        // Trailing update: A22 := A22 - A21 A12
        Gemm( NORMAL, NORMAL, F(-1), A21, A12, F(1), A22 );
        //--------------------------------------------------------------------//

        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
               /**/       A10, A11, /**/ A12,
         /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
开发者ID:jimgoo,项目名称:Elemental,代码行数:41,代码来源:LU.hpp

示例13: SolveAfterLU

// Overwrites B with the result of solving against the two triangular factors
// packed in A (unit-diagonal lower part, non-unit upper part). No pivoting is
// involved in this overload.
//
// orientation : NORMAL solves with the lower and then the upper triangle; any
//               other value solves with the upper and then the lower
//               triangle, both in that orientation.
// A           : square matrix holding the factors.
// B           : right-hand sides on entry, solutions on exit. One column is
//               routed through Trsv, several columns through Trsm.
//
// Debug builds throw std::logic_error when A is not square or when the
// heights of A and B differ.
inline void
SolveAfterLU( Orientation orientation, const Matrix<F>& A, Matrix<F>& B )
{
#ifndef RELEASE
    PushCallStack("SolveAfterLU");
    if( A.Height() != A.Width() )
        throw std::logic_error("A must be square");
    if( A.Height() != B.Height() )
        throw std::logic_error("A and B must be the same height");
#endif
    const bool singleColumn = ( B.Width() == 1 );

    if( orientation == NORMAL )
    {
        // Lower triangle first, then upper
        if( singleColumn )
        {
            Trsv( LOWER, NORMAL, UNIT, A, B );
            Trsv( UPPER, NORMAL, NON_UNIT, A, B );
        }
        else
        {
            Trsm( LEFT, LOWER, NORMAL, UNIT, F(1), A, B );
            Trsm( LEFT, UPPER, NORMAL, NON_UNIT, F(1), A, B );
        }
    }
    else
    {
        // Upper triangle first, then lower
        if( singleColumn )
        {
            Trsv( UPPER, orientation, NON_UNIT, A, B );
            Trsv( LOWER, orientation, UNIT, A, B );
        }
        else
        {
            Trsm( LEFT, UPPER, orientation, NON_UNIT, F(1), A, B );
            Trsm( LEFT, LOWER, orientation, UNIT, F(1), A, B );
        }
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
开发者ID:jimgoo,项目名称:Elemental,代码行数:40,代码来源:SolveAfterLU.hpp

示例14: GaussianElimination

// Runs RowEchelon on the pair (A, B) and then performs an upper-triangular
// solve of A against B, leaving the result in B. Both arguments are modified.
//
// A : square coefficient matrix.
// B : right-hand sides with the same height as A. One column is routed
//     through Trsv, several columns through Trsm.
//
// Debug builds raise LogicError when A is not square or when the heights of
// A and B differ.
inline void
GaussianElimination( Matrix<F>& A, Matrix<F>& B )
{
#ifndef RELEASE
    CallStackEntry entry("GaussianElimination");
    if( A.Height() != A.Width() )
        LogicError("A must be square");
    if( A.Height() != B.Height() )
        LogicError("A and B must be the same height");
#endif
    RowEchelon( A, B );

    // Back substitution: matrix solve unless there is exactly one column
    if( B.Width() != 1 )
    {
        Trsm( LEFT, UPPER, NORMAL, NON_UNIT, F(1), A, B );
        return;
    }
    Trsv( UPPER, NORMAL, NON_UNIT, A, B );
}
开发者ID:khalid-hasanov,项目名称:Elemental,代码行数:16,代码来源:GaussianElimination.hpp

示例15: LogDetDivergence

// Returns frobNorm^2 - logDet - n, where both quantities come from the
// triangular matrix obtained by solving the Cholesky factor of B against the
// Cholesky factor of A, and n is the common dimension.
//
// uplo : which triangle the Cholesky factorizations use. LOWER goes through
//        Trtrsm; any other value zeroes the unused triangle and uses Trsm.
// A, B : square matrices of equal size; both are copied, so the inputs are
//        left unmodified.
//
// Throws std::logic_error when the shapes do not match.
inline typename Base<F>::type 
LogDetDivergence( UpperOrLower uplo, const Matrix<F>& A, const Matrix<F>& B )
{
#ifndef RELEASE
    PushCallStack("LogDetDivergence");
#endif
    const bool squareAndConformal =
        A.Height() == A.Width() && B.Height() == B.Width() &&
        A.Height() == B.Height();
    if( !squareAndConformal )
        throw std::logic_error
        ("A and B must be square matrices of the same size");

    typedef typename Base<F>::type R;
    const int dimension = A.Height();

    // Factor private copies of both inputs
    Matrix<F> factorA( A );
    Matrix<F> factorB( B );
    Cholesky( uplo, factorA );
    Cholesky( uplo, factorB );

    // Solve the factor of B against the factor of A
    if( uplo != LOWER )
    {
        MakeTrapezoidal( LEFT, uplo, 0, factorA );
        Trsm( LEFT, uplo, NORMAL, NON_UNIT, F(1), factorB, factorA );
    }
    else
    {
        Trtrsm( LEFT, uplo, NORMAL, NON_UNIT, F(1), factorB, factorA );
    }

    MakeTrapezoidal( LEFT, uplo, 0, factorA );
    const R frobNorm = Norm( factorA, FROBENIUS_NORM );

    // Accumulate 2*log of the real part of every diagonal entry
    Matrix<F> diagonal;
    factorA.GetDiagonal( diagonal );
    R logDet(0);
    int row = 0;
    while( row < dimension )
    {
        const R entry = RealPart(diagonal.Get(row,0));
        logDet += 2*Log( entry );
        ++row;
    }

    const R divergence = frobNorm*frobNorm - logDet - R(dimension);
#ifndef RELEASE
    PopCallStack();
#endif
    return divergence;
}
开发者ID:certik,项目名称:Elemental,代码行数:45,代码来源:LogDetDivergence.hpp


注:本文中的Trsm函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。