

Python Cluster.train Method Code Examples

This article collects typical usage examples of the Python method cluster.Cluster.train. If you are wondering what cluster.Cluster.train does, or how to use it in practice, the curated code examples below may help. You can also explore further usage examples of the containing class, cluster.Cluster.


Below are 2 code examples of Cluster.train, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.

Example 1: train

# Required import: from cluster import Cluster [as alias]
# Or: from cluster.Cluster import train [as alias]
def train(config):
  '''Train loop for the TDM algorithm'''

  train_rawdata_url = config["train_rawdata_url"]
  test_rawdata_url = config["test_rawdata_url"]
  data_dir = config['data_dir']
  raw_train_data = os.path.join(data_dir, train_rawdata_url.split('/')[-1])
  raw_test_data = os.path.join(data_dir, test_rawdata_url.split('/')[-1])
  tree_filename = os.path.join(data_dir, config['tree_filename'])
  train_sample = os.path.join(data_dir, config['train_sample'])
  test_sample = os.path.join(data_dir, config['test_sample'])
  stat_file = os.path.join(data_dir, config['stat_file'])

  print("Start generating initialization data")
  # Download the raw data
  hdfs_download(train_rawdata_url, raw_train_data)
  hdfs_download(test_rawdata_url, raw_test_data)

  generator = Generator(raw_train_data,
                        raw_test_data,
                        tree_filename,
                        train_sample,
                        test_sample,
                        config['feature_conf'],
                        stat_file,
                        config['seq_len'],
                        config['min_seq_len'],
                        config['parall'],
                        config['train_id_label'],
                        config['test_id_label'])
  generator.generate()

  # Upload generating data to hdfs
  hdfs_upload(data_dir, config["upload_url"])

  # TDM train
  model_embed = os.path.join(data_dir, 'model.embed')
  tree_upload_dir = os.path.join(config['upload_url'], os.path.split(data_dir)[-1])
  for i in range(config['epocs']):
    print('Training, iteration: {iteration}'.format(iteration=i))

    # TODO(genbao.cgb): Train with xdl

    # Download the model file
    hdfs_download(config['model_url'], model_embed)

    # Tree clustering
    cluster = Cluster(model_embed, tree_filename,
                      parall=config['parall'], stat_file=stat_file)
    cluster.train()

    # Upload clustered tree to hdfs
    hdfs_upload(tree_filename, tree_upload_dir, over_write=True)
Contributor: q64545, Project: x-deeplearning, Lines of code: 55, Source file: tdm.py
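The example above derives each local download path by joining the data directory with the last path segment of the HDFS URL (e.g. the `raw_train_data` line). A minimal standalone sketch of that pattern, using a hypothetical URL and directory:

```python
import os

def local_path(data_dir, rawdata_url):
    """Derive a local download path from a remote URL by taking the
    URL's last path segment, as the raw_train_data/raw_test_data lines do."""
    return os.path.join(data_dir, rawdata_url.split('/')[-1])

print(local_path('/tmp/tdm_data', 'hdfs://ns/user/tdm/userbehavior_train.csv'))
# -> /tmp/tdm_data/userbehavior_train.csv
```

Note this only splits on `/`, so it assumes a plain URL with no query string or trailing slash.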

Example 2: train

# Required import: from cluster import Cluster [as alias]
# Or: from cluster.Cluster import train [as alias]
def train(config):
  '''Train loop for the TDM algorithm'''

  data_dir = os.path.join(DIR, config['data_dir'])
  tree_filename = os.path.join(data_dir, config['tree_filename'])
  stat_file = os.path.join(data_dir, config['stat_file'])

  print("Start tree clustering")
  # Download item id
  upload_dir = os.path.join(config['upload_url'], os.path.split(data_dir)[-1])
  item_id_url = os.path.join(upload_dir, config['item_id_file'])
  item_id_file = os.path.join(data_dir, 'item.id')
  hdfs_download(item_id_url, item_id_file)
  model_embed_tmp = os.path.join(data_dir, 'model.embed.tmp')
  hdfs_download(config['model_url'] + '/item_emb', model_embed_tmp)

  # Read max item id from item id file
  max_item_id = 0
  with open(item_id_file) as f:
    for line in f:
      item_id = int(line)
      if item_id > max_item_id:
        max_item_id = item_id
  max_item_id += 1

  model_embed = os.path.join(data_dir, 'model.embed')
  item_count = 0
  id_set = set()
  with open(model_embed_tmp) as f:
    with open(model_embed, 'w') as fo:  # text mode, matching the text-mode read above
      for line in f:
        arr = line.split(",")
        item_id = int(arr[0])
        if (len(arr) > 2) and (item_id < max_item_id) and (item_id not in id_set):
          id_set.add(item_id)
          item_count += 1
          fo.write(line)

  os.remove(model_embed_tmp)
  print("Filter embedding done, records: {}, max_leaf_id: {}".format(
      item_count, max_item_id))

  # Tree clustering
  cluster = Cluster(model_embed, tree_filename,
                    parall=config['parall'], stat_file=stat_file)
  cluster.train()

  # Upload clustered tree to hdfs
  tree_upload_dir = os.path.join(config['upload_url'], os.path.split(data_dir)[-1])
  hdfs_upload(tree_filename, tree_upload_dir, over_write=True)
Contributor: q64545, Project: x-deeplearning, Lines of code: 52, Source file: tree_cluster.py
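The embedding-filtering loop in Example 2 keeps only the first row per item id, drops ids at or above the leaf-id bound, and skips malformed rows. That logic can be isolated as a small helper (a sketch with a hypothetical name, operating on in-memory lines rather than files):

```python
def filter_embeddings(lines, max_item_id):
    """Keep the first embedding row per item id, dropping rows whose id
    is >= max_item_id and rows with fewer than 3 comma-separated fields."""
    seen = set()
    kept = []
    for line in lines:
        arr = line.split(',')
        # A valid row has an id plus at least two embedding values
        if len(arr) > 2:
            item_id = int(arr[0])
            if item_id < max_item_id and item_id not in seen:
                seen.add(item_id)
                kept.append(line)
    return kept

rows = ['0,0.1,0.2', '0,0.1,0.2', '5,0.3', '1,0.4,0.5', '9,0.6,0.7']
print(filter_embeddings(rows, 3))
# -> ['0,0.1,0.2', '1,0.4,0.5']
```

Unlike the original, this version parses the id only after the field-count check, so a short malformed row never reaches `int()`.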


Note: The cluster.Cluster.train examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub/MSDocs. The snippets are drawn from open-source projects contributed by various developers; copyright remains with the original authors. Please follow each project's License when distributing or using the code, and do not reproduce this article without permission.