当前位置: 首页>>代码示例>>Python>>正文


Python DataManager.continueTrainingLoop方法代码示例

本文整理汇总了Python中DataManager.DataManager.continueTrainingLoop方法的典型用法代码示例。如果您正苦于以下问题:Python DataManager.continueTrainingLoop方法的具体用法?Python DataManager.continueTrainingLoop怎么用?Python DataManager.continueTrainingLoop使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在DataManager.DataManager的用法示例。


在下文中一共展示了DataManager.continueTrainingLoop方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: mainTF

# 需要导入模块: from DataManager import DataManager [as 别名]
# 或者: from DataManager.DataManager import continueTrainingLoop [as 别名]

#.........这里部分代码省略.........
  ptps[ptps < 1e-10] = 1.0

  ##Create data manager, this class controls how data is fed to the network for training
  #                 DataSet(fileGlob, xsec, Nevts, kFactor, sig, prescale, rescale)
  signalDataSets = [DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6p1/trainingTuple_*_division_0_TTbarSingleLepT_training_*.h5",      365.4,  61878989, 1.0, True,  0, 1.0, 1.0, 8),
                    DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6p1/trainingTuple_*_division_0_TTbarSingleLepTbar_training_*.h5",   365.4,  61901450, 1.0, True,  0, 1.0, 1.0, 8),]

  #pt reweighting histograms 
  ttbarRatio = (numpy.array([0.7976347,  1.010679,  1.0329635,  1.0712056,  1.1147588,  1.0072196,  0.79854023, 0.7216115,  0.7717652,  0.851551,   0.8372917 ]), numpy.array([  0.,  50., 100., 150., 200., 250., 300., 350., 400., 450., 500., 1e10]))
  QCDDataRatio = (numpy.array([0.50125164, 0.70985824, 1.007087,   1.6701245,  2.5925348,  3.6850858, 4.924969,   6.2674766,  7.5736594,  8.406105,   7.7529635 ]), numpy.array([  0.,  50., 100., 150., 200., 250., 300., 350., 400., 450., 500., 1e10]))
  QCDMCRatio = (numpy.array([0.75231355, 1.0563549,  1.2571484,  1.3007764,  1.0678109,  0.83444154, 0.641499,   0.49130705, 0.36807108, 0.24333349, 0.06963781]), numpy.array([  0.,  50., 100., 150., 200., 250., 300., 350., 400., 450., 500., 1e10]))

  backgroundDataSets = [DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_TTbarSingleLepT_training_*.h5",    365.4,  61878989, 1.0, False, 0, 1.0, 1.0, 8,  ttbarRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_TTbarSingleLepTbar_training_*.h5", 365.4,  61901450, 1.0, False, 0, 1.0, 1.0, 8,  ttbarRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_Data_JetHT_2016_training_*.h5",      1.0,         1, 1.0, False, 1, 1.0, 1.0, 8,  include = False),#QCDDataRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_QCD_HT100to200_training_*.h5",   27990000,  80684349, 0.0, False, 2, 1.0, 1.0, 1, include = False),#QCDMCRatio), 
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_QCD_HT200to300_training_*.h5",   1712000 ,  57580393, 0.0, False, 2, 1.0, 1.0, 1, include = False),#QCDMCRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_QCD_HT300to500_training_*.h5",   347700  ,  54537903, 0.0, False, 2, 1.0, 1.0, 1, include = False),#QCDMCRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_QCD_HT500to700_training_*.h5",   32100   ,  62271343, 0.0, False, 2, 1.0, 1.0, 1, include = False),#QCDMCRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_QCD_HT700to1000_training_*.h5",  6831    ,  45232316, 0.0, False, 2, 1.0, 1.0, 1, include = False),#QCDMCRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_QCD_HT1000to1500_training_*.h5", 1207    ,  15127293, 0.0, False, 2, 1.0, 1.0, 1, include = False),#QCDMCRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_QCD_HT1500to2000_training_*.h5", 119.9   ,  11826702, 0.0, False, 2, 1.0, 1.0, 1, include = False),#QCDMCRatio),
                        DataSet("/cms/data/pastika/trainData_pt20_30_40_dRPi_tightMass_deepFlavor_v6/trainingTuple_*_division_0_QCD_HT2000toInf_training_*.h5",  25.24   ,   6039005, 0.0, False, 2, 1.0, 1.0, 1, include = False),#QCDMCRatio),
                        ]

  dm = DataManager(options.netOp.vNames, nEpoch, nFeatures, nLabels, 2, nWeights, options.runOp.ptReweight, signalDataSets, backgroundDataSets)

  # Build the graph
  denseNetwork = [nFeatures]+options.netOp.denseLayers+[nLabels]
  convLayers = options.netOp.convLayers
  rnnNodes = options.netOp.rnnNodes
  rnnLayers = options.netOp.rnnLayers
  mlp = CreateModel(options, denseNetwork, convLayers, rnnNodes, rnnLayers, dm.inputDataQueue, MiniBatchSize, mins, 1.0/ptps)

  #summary writer
  summary_writer = tf.summary.FileWriter(options.runOp.directory + "log_graph", graph=tf.get_default_graph())

  print "TRAINING NETWORK"

  with tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=8) ) as sess:
    sess.run(tf.global_variables_initializer())

    #start queue runners
    dm.launchQueueThreads(sess)

    print "Reporting validation loss every %i batches with %i events per batch for %i epochs"%(ReportInterval, MiniBatchSize, nEpoch)

    #preload the first data into staging area
    sess.run([mlp.stagingOp], feed_dict={mlp.reg: l2Reg, mlp.keep_prob:options.runOp.keepProb})

    i = 0
    N_TRAIN_SUMMARY = 10

    #flush queue until the sample fraction is approximately equal 
    flushctr = 200
    while dm.continueTrainingLoop():
      result = sess.run(dm.inputDataQueue.dequeue_many(MiniBatchSize))
      signalCount = result[1][:,0].sum()
      bgCount = result[1][:,1].sum()
      signalFraction =  signalCount/(signalCount+bgCount)
      #once flushctr batches (not necessarily consecutive) have had a signal fraction below 0.5, we are close enough to an equal signal/bg fraction
      if signalFraction < 0.5:
        flushctr -= 1
        if flushctr <= 0:
          break

    try:
      while dm.continueTrainingLoop():

        grw = 2/(1+exp(-i/10000.0)) - 1

        #run validation operations 
        if i == 0 or not i % ReportInterval:
          #run validation operations 
          validation_loss, accuracy, summary_vl = sess.run([mlp.loss_ph, mlp.accuracy, mlp.merged_valid_summary_op], feed_dict={mlp.x_ph: validDataTTbar["data"][:validationCount], mlp.y_ph_: validDataTTbar["labels"][:validationCount], mlp.p_ph_: validDataTTbar["domain"][:validationCount], mlp.reg: l2Reg, mlp.gradientReversalWeight:grw, mlp.wgt_ph: validDataTTbar["weights"][:validationCount]})
          summary_writer.add_summary(summary_vl, i/N_TRAIN_SUMMARY)

          print('Interval %d, validation accuracy %0.6f, validation loss %0.6f' % (i/ReportInterval, accuracy, validation_loss))

          validation_loss, accuracy, summary_vl_QCDMC = sess.run([mlp.loss_ph, mlp.accuracy, mlp.merged_valid_QCDMC_summary_op], feed_dict={mlp.x_ph: validDataQCDMC["data"][:validationCount], mlp.y_ph_: validDataQCDMC["labels"][:validationCount], mlp.p_ph_: validDataQCDMC["domain"][:validationCount], mlp.reg: l2Reg, mlp.gradientReversalWeight:grw, mlp.wgt_ph: validDataQCDMC["weights"][:validationCount]})
          summary_writer.add_summary(summary_vl_QCDMC, i/N_TRAIN_SUMMARY)

          validation_loss, accuracy, summary_vl_QCDData = sess.run([mlp.loss_ph, mlp.accuracy, mlp.merged_valid_QCDData_summary_op], feed_dict={mlp.x_ph: validDataQCDData["data"][:validationCount], mlp.y_ph_: validDataQCDData["labels"][:validationCount], mlp.p_ph_: validDataQCDData["domain"][:validationCount], mlp.reg: l2Reg, mlp.gradientReversalWeight:grw, mlp.wgt_ph: validDataQCDData["weights"][:validationCount]})
          summary_writer.add_summary(summary_vl_QCDData, i/N_TRAIN_SUMMARY)

        #run training operations 
        if i % N_TRAIN_SUMMARY == 0:
          _, _, summary = sess.run([mlp.stagingOp, mlp.train_step, mlp.merged_train_summary_op], feed_dict={mlp.reg: l2Reg, mlp.keep_prob:options.runOp.keepProb, mlp.training: True, mlp.gradientReversalWeight:grw})
          summary_writer.add_summary(summary, i/N_TRAIN_SUMMARY)
        else:
          sess.run([mlp.stagingOp, mlp.train_step], feed_dict={mlp.reg: l2Reg, mlp.keep_prob:options.runOp.keepProb, mlp.training: True})
        i += 1

      while dm.continueFlushingQueue():
        sess.run(dm.inputDataQueue.dequeue_many(MiniBatchSize))

    except Exception, e:
      # Report exceptions to the coordinator.
      dm.requestStop(e)
    finally:
开发者ID:susy2015,项目名称:TopTagger,代码行数:104,代码来源:MVATrainers.py


注:本文中的DataManager.DataManager.continueTrainingLoop方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。