当前位置: 首页>>代码示例>>Python>>正文


Python tools.IteratorTimer方法代码示例

本文整理汇总了Python中utils.tools.IteratorTimer方法的典型用法代码示例。如果您正苦于以下问题:Python tools.IteratorTimer方法的具体用法?Python tools.IteratorTimer怎么用?Python tools.IteratorTimer使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils.tools的用法示例。


在下文中一共展示了tools.IteratorTimer方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test

# 需要导入模块: from utils import tools [as 别名]
# 或者: from utils.tools import IteratorTimer [as 别名]
def test(args, epoch, model, data_loader):
    """
    TESTING PROCEDURE

    Parameters:
    -----------
        - args: various arguments
        - epoch: number of epochs 
        - model: specified model to test
        - data_loader: specified test data_loader

    Returns:
    --------
        - average_loss: average loss per batch
        - pck: Percentage of Correct Keypoints metric

    """
    
    statistics = []
    total_loss = 0

    model.eval()
    title = 'Validating Epoch {}'.format(epoch)
    progress = tqdm(tools.IteratorTimer(data_loader), ncols=120, total=len(data_loader), smoothing=.9, miniters=1, leave=True, desc=title)
    predictions = []
    gt = []

    sys.stdout.flush()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(progress):

            d = model(data[0].to(args.device), im_2 = data[1].to(args.device))
            loss = _apply_loss(d, target).mean()
            total_loss += loss.item()
            predictions.extend(d.numpy())
            gt.extend(target.numpy())

            # Print out statistics
            statistics.append(loss.item())
            title = '{} Epoch {}'.format('Validating', epoch)

            progress.set_description(title + '\tLoss:\t'+ str(statistics[-1]))
            sys.stdout.flush()


    progress.close()
    pck = tools.calc_pck(np.asarray(predictions), np.asarray(gt))
    print('PCK for epoch %d is %f'%(epoch, pck))

    return total_loss / float(batch_idx + 1), pck 
开发者ID:yannadani,项目名称:dlgm,代码行数:52,代码来源:main.py

示例2: train

# 需要导入模块: from utils import tools [as 别名]
# 或者: from utils.tools import IteratorTimer [as 别名]
def train(args, epoch, model, data_loader, optimizer):
    """
    TRAINING PROCEDURE

     Parameters:
    -----------
        - args: various arguments
        - epoch: number of epochs 
        - model: specified model to test
        - data_loader: specified train data_loader
        - optimizer: specified optimizer to use

    Returns:
    --------
        - average_loss: average loss per batch

    """
    
    statistics = []
    total_loss = 0

    model.train()
    title = 'Training Epoch {}'.format(epoch)
    progress = tqdm(tools.IteratorTimer(data_loader), ncols=120, total=len(data_loader), smoothing=.9, miniters=1, leave=True, desc=title)

    sys.stdout.flush()

    for batch_idx, (data, target) in enumerate(progress):

        #data, target = data.to(args.device), target.to(args.device)

        optimizer.zero_grad()
        d = model(data[0].to(args.device), im_2 = data[1].to(args.device))
        loss = _apply_loss(d, target).mean()

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        assert not np.isnan(total_loss)

        # Print out statistics
        statistics.append(loss.item())
        title = '{} Epoch {}'.format('Training', epoch)

        progress.set_description(title + '\tLoss:\t'+ str(statistics[-1]))
        sys.stdout.flush()


    progress.close()

    return total_loss / float(batch_idx + 1)

# ====================================================================================================================================
# MAIN PROCEDURE
# ========================= 
开发者ID:yannadani,项目名称:dlgm,代码行数:57,代码来源:main.py


注:本文中的utils.tools.IteratorTimer方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。