当前位置: 首页>>代码示例>>C++>>正文


C++ Forest::GetWeights方法代码示例

本文整理汇总了C++中Forest::GetWeights方法的典型用法代码示例。如果您正苦于以下问题:C++ Forest::GetWeights方法的具体用法?C++ Forest::GetWeights怎么用?C++ Forest::GetWeights使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在Forest的用法示例。


在下文中一共展示了Forest::GetWeights方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。

示例1: total_energy

  inline double total_energy()
  {

    double energy = 0.0;
    double *Dp = D;
    for ( int m=0; m<M; m++ ) {
      energy += norm2( Dp, LabelSet::classes );
      Dp += LabelSet::classes;
    }

    double t0[LabelSet::classes];

    for ( int l=0; l<L; l++ ) {
      auto& _to_j = forest->GetWeights( l );
      for ( auto& ele : _to_j ) {
        int j = ele.first;
        if ( j<l ) {
          double wt = static_cast<double>( ele.second );
          minus( q + l * LabelSet::classes, q + j * LabelSet::classes, t0, LabelSet::classes );
          energy += options.beta * wt * norm2( t0, LabelSet::classes );
        }
      }
    }

    return energy;
    
  }
开发者ID:breakds,项目名称:PatTk,代码行数:27,代码来源:TrainLabel.cpp

示例2: restrict_energy

  /*
   * Restricted Energy on q(l) is
   *     sum_m ( D(m) - alpha(l,m) * q(l) + alpha(l,m) * q'(l) )^2
   *   + sum_j w(i,j) * ( q'(l) - q(j) )^2
   */
  inline double restrict_energy( int l, double *q_l = nullptr )
  {

    auto& _to_m = m_to_l->getFromSet( l );
    auto& _to_j = forest->GetWeights( l );
    
    double energy = 0.0;
    double t0[LabelSet::classes];

    if ( nullptr == q_l ) {
      for ( auto& ele : _to_m ) {
        int m = ele.first;
        energy += norm2( D + m * LabelSet::classes, LabelSet::classes );
      }
      
      for ( auto& ele : _to_j ) {
        int j = ele.first;
        double wt = static_cast<double>( ele.second );

        minus( q + l * LabelSet::classes, q + j * LabelSet::classes, t0, LabelSet::classes );
        energy += options.beta * wt * norm2( t0, LabelSet::classes );
      }
    } else {

      double t0[LabelSet::classes];

      for ( auto& ele : _to_m ) {
        int m = ele.first;
        double alpha = ele.second;

        memcpy( t0, D + m * LabelSet::classes, sizeof(double) * LabelSet::classes );
        minusScaledFrom( t0, q + l * LabelSet::classes, LabelSet::classes, alpha );
        addScaledTo( t0, q_l, LabelSet::classes, alpha );
        energy += norm2( t0, LabelSet::classes );
      }

      
      for ( auto& ele : _to_j ) {
        int j = ele.first;
        double wt = static_cast<double>( ele.second );
        minus( q_l, q + j * LabelSet::classes, t0, LabelSet::classes );
        energy += options.beta * wt * norm2( t0, LabelSet::classes );
      }

    }

    return energy;

  }
开发者ID:breakds,项目名称:PatTk,代码行数:54,代码来源:TrainLabel.cpp

示例3: update_q

  inline void update_q( int l )
  {

    auto& _to_m = m_to_l->getFromSet( l );
    auto& _to_j = forest->GetWeights( l );
    
    double t0[LabelSet::classes];
    memset( t0, 0, sizeof(double) * LabelSet::classes );
    double t1[LabelSet::classes];
        
    


    // t0 = sum_m alpha(l,m) * D(m) 
    for ( auto& ele : _to_m ) {
      int m = ele.first;
      double alpha = ele.second;
      addScaledTo( t0, D + m * LabelSet::classes, LabelSet::classes, alpha );
    }


    
    // t0 += sum_m wt(l,j) (q(l) - q(j) )
    for ( auto& ele : _to_j ) {
      int j = ele.first;
      double wt = static_cast<double>( ele.second );
      minus( q + l * LabelSet::classes, q + j * LabelSet::classes, t1, LabelSet::classes );
      addScaledTo( t0, t1, LabelSet::classes, wt );
    }

    // negate t0 to get negative gradient direction
    negate( t0, LabelSet::classes );
    
    // Line Search
    double energy_old = restrict_energy( l ) * options.wolf;

    bool updated = false;

    
    normalize_vec( t0, t0, LabelSet::classes );
    double energy_new = 0.0;
    for ( int i=0; i<40; i++ ) {
      scale( t0, LabelSet::classes, options.shrinkRatio );
      add( t0, q + l * LabelSet::classes, t1, LabelSet::classes );
      // Simplex Projection
      watershed( t1, t0, LabelSet::classes );

      energy_new = restrict_energy( l, t0 );
      if ( energy_new < energy_old ) {
        updated = true;
        break;
      }
    }

    if ( updated ) {
      for ( auto& ele : _to_m ) {
        int m = ele.first;
        double alpha = ele.second;
        update_D( m, l, t0, alpha );
      }
      memcpy( q + l * LabelSet::classes, t0, sizeof(double) * LabelSet::classes );
    }
  }
开发者ID:breakds,项目名称:PatTk,代码行数:63,代码来源:TrainLabel.cpp


注:本文中的Forest::GetWeights方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。