当前位置: 首页>>代码示例>>Python>>正文


Python utils.train方法代码示例

本文整理汇总了Python中utils.train方法的典型用法代码示例。如果您正苦于以下问题:Python utils.train方法的具体用法?Python utils.train怎么用?Python utils.train使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils的用法示例。


在下文中一共展示了utils.train方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train_fine_tuning

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def train_fine_tuning(net, optimizer, batch_size=128, num_epochs=4):
    """Fine-tune ``net`` on the image folders found under ``data_dir``.

    Two loaders are built from the ``train`` and ``test`` sub-folders (the
    training loader is created with ``shuffle=True``; the test loader uses the
    ``DataLoader`` defaults) and passed to ``utils.train`` together with a
    cross-entropy loss.

    Args:
        net: model handed to ``utils.train``.
        optimizer: optimizer handed to ``utils.train``.
        batch_size: batch size used for both loaders.
        num_epochs: number of epochs forwarded to ``utils.train``.
    """
    train_set = ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs)
    train_iter = DataLoader(train_set, batch_size, shuffle=True)

    test_set = ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs)
    test_iter = DataLoader(test_set, batch_size)

    criterion = torch.nn.CrossEntropyLoss()
    utils.train(train_iter, test_iter, net, criterion, optimizer, device, num_epochs)
开发者ID:wdxtub,项目名称:deep-learning-note,代码行数:7,代码来源:48_fine_tune_hotdog.py

示例2: main

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def main():
    """Dump per-variable statistics of the latest checkpoint into a spreadsheet.

    Loads the config and the logging setup, locates the checkpoint through
    ``utils.train.load_model`` and writes a worksheet with one column per
    class found in this script (instantiated and keyed by its underscored
    name) and one row per entry of the checkpoint's state dict. The workbook
    is written inside the model directory as ``<script name>.xlsx``.
    """
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    for cmd in args.modify:
        utils.modify_config(config, cmd)
    with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
        # safe_load instead of a bare yaml.load(f): the latter requires an
        # explicit Loader on PyYAML >= 6 and can build arbitrary Python
        # objects on older releases.
        logging.config.dictConfig(yaml.safe_load(f))
    model_dir = utils.get_model_dir(config)
    path, step, epoch = utils.train.load_model(model_dir)
    # Keep every tensor on the CPU regardless of where it was saved.
    state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    # Re-load this very script as a module and instantiate every class it exposes.
    mapper = [
        (inflection.underscore(name), member())
        for name, member in inspect.getmembers(importlib.machinery.SourceFileLoader('', __file__).load_module())
        if inspect.isclass(member)
    ]
    path = os.path.join(model_dir, os.path.basename(os.path.splitext(__file__)[0])) + '.xlsx'
    with xlsxwriter.Workbook(path, {'strings_to_urls': False, 'nan_inf_to_errors': True}) as workbook:
        worksheet = workbook.add_worksheet(args.worksheet)
        for j, (key, m) in enumerate(mapper):
            worksheet.write(0, j, key)  # header row
            for i, (name, variable) in enumerate(state_dict.items()):
                value = m(name, variable)
                worksheet.write(1 + i, j, value)
            if hasattr(m, 'format'):
                m.format(workbook, worksheet, i, j)
        # NOTE(review): ``i`` is the last index of the inner loop; it is unbound
        # when the state dict (or the mapper list) is empty -- confirm that
        # cannot happen for real checkpoints.
        worksheet.autofilter(0, 0, i, len(mapper) - 1)
        worksheet.freeze_panes(1, 0)
    logging.info(path)
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:27,代码来源:variable_stat.py

示例3: __init__

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def __init__(self, env):
        """Set up the summary worker: queue, summary timers and drawing helpers.

        Args:
            env: object exposing ``config`` (a config parser) and ``category``.
        """
        super(SummaryWorker, self).__init__()
        self.env = env
        self.config = env.config
        self.queue = multiprocessing.Queue()
        # One timer per summary kind; when the option is missing from the
        # [summary] section the timer is replaced by a callable that always
        # returns False.
        for kind in ('scalar', 'image', 'histogram'):
            try:
                timer = utils.train.Timer(env.config.getfloat('summary', kind))
            except configparser.NoOptionError:
                timer = lambda: False
            setattr(self, 'timer_' + kind, timer)
        # The configured file lists one histogram parameter pattern per line.
        pattern_file = os.path.expanduser(os.path.expandvars(env.config.get('summary_histogram', 'parameters')))
        with open(pattern_file, 'r') as f:
            patterns = [line.rstrip() for line in f]
            self.histogram_parameters = utils.RegexList(patterns)
        self.draw_bbox = utils.visualize.DrawBBox(env.category)
        self.draw_feature = utils.visualize.DrawFeature()
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:23,代码来源:train.py

示例4: get_loader

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def get_loader(self):
        """Build the training ``DataLoader`` from the cached pickles and the config.

        Returns:
            A ``torch.utils.data.DataLoader`` over ``utils.data.Dataset`` using a
            ``utils.data.Collate`` instance as ``collate_fn``.
        """
        phases = self.config.get('train', 'phase').split()
        paths = [os.path.join(self.cache_dir, phase + '.pkl') for phase in phases]

        samples = utils.data.load_pickles(paths)
        augmentation = transform.augmentation.get_transform(self.config, self.config.get('transform', 'augmentation').split())
        # one_hot is disabled (None) when cross entropy is configured,
        # otherwise it is the number of categories.
        if self.config.getboolean('train', 'cross_entropy'):
            one_hot = None
        else:
            one_hot = len(self.category)
        dataset = utils.data.Dataset(
            samples,
            transform=augmentation,
            one_hot=one_hot,
            shuffle=self.config.getboolean('data', 'shuffle'),
            dir=os.path.join(self.model_dir, 'exception'),
        )
        logging.info('num_examples=%d' % len(dataset))

        # Configured worker count is multiplied by the number of CUDA devices;
        # without the option, fall back to the CPU count.
        try:
            workers = self.config.getint('data', 'workers')
        except configparser.NoOptionError:
            workers = multiprocessing.cpu_count()
        else:
            if torch.cuda.is_available():
                workers *= torch.cuda.device_count()

        resize = transform.parse_transform(self.config, self.config.get('transform', 'resize_train'))
        sizes = utils.train.load_sizes(self.config)
        collate_fn = utils.data.Collate(
            resize,
            sizes,
            maintain=self.config.getint('data', 'maintain'),
            transform_image=transform.get_transform(self.config, self.config.get('transform', 'image_train').split()),
            transform_tensor=transform.get_transform(self.config, self.config.get('transform', 'tensor').split()),
            dir=os.path.join(self.model_dir, 'exception'),
        )

        # The batch size is multiplied by the CUDA device count when CUDA is available.
        if torch.cuda.is_available():
            batch_size = self.args.batch_size * torch.cuda.device_count()
        else:
            batch_size = self.args.batch_size
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=workers,
            collate_fn=collate_fn,
            pin_memory=torch.cuda.is_available(),
        )
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:27,代码来源:train.py

示例5: iterate

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def iterate(self, data):
        """Run one optimisation step on ``data`` and return the intermediate results.

        Args:
            data: mapping holding at least the ``'tensor'`` and ``'image'``
                entries; tensor values are replaced in place by the result of
                ``utils.ensure_device``.

        Returns:
            dict with the image/feature sizes, the batch, the prediction, the
            debug output of the loss, and the raw, weighted and total losses.
        """
        for name in data:
            value = data[name]
            if torch.is_tensor(value):
                data[name] = utils.ensure_device(value)
        tensor = torch.autograd.Variable(data['tensor'])

        run_inference = pybenchmark.profile('inference')(model._inference)
        pred = run_inference(self.inference, tensor)

        height, width = data['image'].size()[1:3]
        rows, cols = pred['feature'].size()[-2:]

        compute_loss = pybenchmark.profile('loss')(model.loss)
        normalized = norm_data(data, height, width, rows, cols)
        threshold = self.config.getfloat('model', 'threshold')
        loss, debug = compute_loss(self.anchors, normalized, pred, threshold)

        # Weight every loss term by the factor of the same name in [hparam].
        loss_hparam = {}
        for term in loss:
            loss_hparam[term] = loss[term] * self.config.getfloat('hparam', term)
        loss_total = sum(loss_hparam.values())

        self.optimizer.zero_grad()
        loss_total.backward()
        # Gradient clipping is optional: it is skipped when [train] has no 'clip' option.
        try:
            max_norm = self.config.getfloat('train', 'clip')
            nn.utils.clip_grad_norm(self.inference.parameters(), max_norm)
        except configparser.NoOptionError:
            pass
        self.optimizer.step()

        return {
            'height': height, 'width': width, 'rows': rows, 'cols': cols,
            'data': data, 'pred': pred, 'debug': debug,
            'loss_total': loss_total, 'loss': loss, 'loss_hparam': loss_hparam,
        }
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:27,代码来源:train.py

示例6: main

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def main():
    """Entry point of the training script.

    Loads the config (applying command-line modifications), configures
    logging, derives a run name when none was given, then builds and runs a
    ``Train`` instance and logs the ``pybenchmark`` statistics.
    """
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    for cmd in args.modify:
        utils.modify_config(config, cmd)
    with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
        # safe_load instead of a bare yaml.load(f): the latter requires an
        # explicit Loader on PyYAML >= 6 and can build arbitrary Python
        # objects on older releases.
        logging.config.dictConfig(yaml.safe_load(f))
    if args.run is None:
        # No run name given: use the MD5 of the serialized config as the name.
        buffer = io.StringIO()
        config.write(buffer)
        args.run = hashlib.md5(buffer.getvalue().encode()).hexdigest()
    # Log a command line that reproduces this invocation from the current directory.
    logging.info('cd ' + os.getcwd() + ' && ' + subprocess.list2cmdline([sys.executable] + sys.argv))
    train = Train(args, config)
    train()
    logging.info(pybenchmark.stats)
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:18,代码来源:train.py

示例7: __init__

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def __init__(self, env):
        """Set up the summary worker: queue, summary timers and drawing helpers.

        Args:
            env: object exposing ``config`` (a config parser) and ``limbs_index``.
        """
        super(SummaryWorker, self).__init__()
        self.env = env
        self.config = env.config
        self.queue = multiprocessing.Queue()

        def make_timer(option):
            # Timer built from the [summary] option, or a callable that always
            # returns False when the option is not configured.
            try:
                return utils.train.Timer(env.config.getfloat('summary', option))
            except configparser.NoOptionError:
                return lambda: False

        self.timer_scalar = make_timer('scalar')
        self.timer_image = make_timer('image')
        self.timer_histogram = make_timer('histogram')

        # The configured file lists one histogram parameter pattern per line.
        pattern_file = os.path.expanduser(os.path.expandvars(env.config.get('summary_histogram', 'parameters')))
        with open(pattern_file, 'r') as f:
            patterns = [line.rstrip() for line in f]
            self.histogram_parameters = utils.RegexList(patterns)

        point_colors = env.config.get('draw_points', 'colors').split()
        self.draw_points = utils.visualize.DrawPoints(env.limbs_index, colors=point_colors)
        self._draw_points = utils.visualize.DrawPoints(env.limbs_index, thickness=1)
        self.draw_bbox = utils.visualize.DrawBBox()
        self.draw_feature = utils.visualize.DrawFeature()
        self.draw_cluster = utils.visualize.DrawCluster()
开发者ID:ruiminshen,项目名称:openpose-pytorch,代码行数:26,代码来源:train.py

示例8: iterate

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def iterate(self, data):
        """Run one optimisation step on ``data`` and return the intermediate results.

        Args:
            data: mapping holding at least the ``'tensor'`` and ``'image'``
                entries; tensor values are moved to ``self.device`` in place.

        Returns:
            dict with the image size, the batch, the network outputs, and the
            raw, weighted and total losses.
        """
        for key in data:
            t = data[key]
            if torch.is_tensor(t):
                data[key] = t.to(self.device)
        tensor = data['tensor']
        outputs = pybenchmark.profile('inference')(self.inference)(tensor)
        height, width = data['image'].size()[1:3]
        loss = pybenchmark.profile('loss')(model.Loss(self.config, data, self.limbs_index, height, width))
        # One dict of loss terms per network output.
        losses = [loss(**output) for output in outputs]
        # Weight each term through self.loss_hparam(output index, term name, value).
        losses_hparam = [
            {name: self.loss_hparam(i, name, value) for name, value in terms.items()}
            for i, terms in enumerate(losses)
        ]
        loss_total = sum(sum(terms.values()) for terms in losses_hparam)
        self.optimizer.zero_grad()
        loss_total.backward()
        # Gradient clipping is optional: it is skipped when [train] has no
        # 'clip' option. Only the config lookup is guarded.
        try:
            clip = self.config.getfloat('train', 'clip')
        except configparser.NoOptionError:
            pass
        else:
            # clip_grad_norm_ is the in-place successor of the deprecated clip_grad_norm.
            nn.utils.clip_grad_norm_(self.inference.parameters(), clip)
        self.optimizer.step()
        return dict(
            height=height, width=width,
            data=data, outputs=outputs,
            loss_total=loss_total, losses=losses, losses_hparam=losses_hparam,
        )
开发者ID:ruiminshen,项目名称:openpose-pytorch,代码行数:27,代码来源:train.py

示例9: load_cifar10

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def load_cifar10(is_train, augs, batch_size, root='data/CIFAR-10'):
    """Return a ``DataLoader`` over the CIFAR-10 split selected by ``is_train``.

    Args:
        is_train: passed as ``train`` to the dataset and as ``shuffle`` to the loader.
        augs: transform applied by the dataset.
        batch_size: loader batch size.
        root: dataset directory; the data is not downloaded (``download=False``).
    """
    cifar = torchvision.datasets.CIFAR10(
        root=root,
        train=is_train,
        transform=augs,
        download=False,
    )
    loader = torch.utils.data.DataLoader(
        cifar,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
    )
    return loader
开发者ID:wdxtub,项目名称:deep-learning-note,代码行数:5,代码来源:47_aug_cifar10.py

示例10: train_with_data_aug

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def train_with_data_aug(train_augs, test_augs, lr=0.001):
    """Train ``utils.resnet18(10)`` on CIFAR-10 for one epoch with the given augmentations.

    Args:
        train_augs: transform for the training split.
        test_augs: transform for the test split.
        lr: learning rate of the Adam optimizer.
    """
    batch_size = 256
    net = utils.resnet18(10)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    utils.train(train_iter, test_iter, net, criterion, optimizer, device, num_epochs=1)
开发者ID:wdxtub,项目名称:deep-learning-note,代码行数:9,代码来源:47_aug_cifar10.py

示例11: make_args

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def make_args():
    """Parse the command line and return the resulting namespace.

    Options: ``-c/--config`` (one or more config files), ``-m/--modify``
    (config modifications), ``-p/--phase`` (phases), ``--rows`` / ``--cols``
    (integers) and ``--logging`` (logging config file).
    """
    # (flags, keyword arguments) for every option, registered in this order.
    options = (
        (('-c', '--config'), dict(nargs='+', default=['config.ini'], help='config file')),
        (('-m', '--modify'), dict(nargs='+', default=[], help='modify config')),
        (('-p', '--phase'), dict(nargs='+', default=['train', 'val', 'test'])),
        (('--rows',), dict(default=3, type=int)),
        (('--cols',), dict(default=3, type=int)),
        (('--logging',), dict(default='logging.yml', help='logging config')),
    )
    parser = argparse.ArgumentParser()
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:11,代码来源:demo_data.py

示例12: main

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def main():
    """Build the configured network and render its graph with ``utils.visualize.Graph``.

    A checkpoint is loaded when one can be found; otherwise the network is
    built without pretrained weights. A random batch is pushed through the
    network and the graph reachable from ``output.grad_fn`` is drawn; state
    dict entries that were not drawn are reported as a warning.
    """
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    for cmd in args.modify:
        utils.modify_config(config, cmd)
    with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
        # safe_load instead of a bare yaml.load(f): the latter requires an
        # explicit Loader on PyYAML >= 6 and can build arbitrary Python
        # objects on older releases.
        logging.config.dictConfig(yaml.safe_load(f))
    model_dir = utils.get_model_dir(config)
    category = utils.get_category(config)
    anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
    try:
        path, step, epoch = utils.train.load_model(model_dir)
        # Keep every tensor on the CPU regardless of where it was saved.
        state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    except (FileNotFoundError, ValueError):
        logging.warning('model cannot be loaded')
        state_dict = None
    dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), anchors, len(category))
    # Log the total size in bytes of all state-dict tensors, human readable.
    logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values())))
    if state_dict is not None:
        dnn.load_state_dict(state_dict)
    height, width = tuple(map(int, config.get('image', 'size').split()))
    image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width))
    output = dnn(image)
    state_dict = dnn.state_dict()
    graph = utils.visualize.Graph(config, state_dict)
    graph(output.grad_fn)
    diff = [key for key in state_dict if key not in graph.drawn]
    if diff:
        logging.warning('variables not shown: ' + str(diff))
    path = graph.dot.view(os.path.basename(model_dir) + '.gv', os.path.dirname(model_dir))
    logging.info(path)
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:34,代码来源:demo_graph.py

示例13: main

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def main():
    """Print checksums of the checkpoint's variables and of one forward pass.

    For every entry of the network's state dict, and then for a seeded random
    input tensor and the corresponding output, one tab-separated line is
    printed: name, shape, ``utils.abs_mean`` of the array and the MD5 of its
    raw bytes.
    """
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    for cmd in args.modify:
        utils.modify_config(config, cmd)
    with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
        # safe_load instead of a bare yaml.load(f): the latter requires an
        # explicit Loader on PyYAML >= 6 and can build arbitrary Python
        # objects on older releases.
        logging.config.dictConfig(yaml.safe_load(f))
    torch.manual_seed(args.seed)
    cache_dir = utils.get_cache_dir(config)
    model_dir = utils.get_model_dir(config)
    category = utils.get_category(config, cache_dir if os.path.exists(cache_dir) else None)
    anchors = utils.get_anchors(config)
    anchors = torch.from_numpy(anchors).contiguous()
    path, step, epoch = utils.train.load_model(model_dir)
    # Keep every tensor on the CPU regardless of where it was saved.
    state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), anchors, len(category))
    dnn.load_state_dict(state_dict)
    height, width = tuple(map(int, config.get('image', 'size').split()))
    tensor = torch.randn(1, 3, height, width)

    def report(key, a):
        # tobytes() replaces the deprecated ndarray.tostring(); same bytes, same digest.
        print('\t'.join(map(str, [key, a.shape, utils.abs_mean(a), hashlib.md5(a.tobytes()).hexdigest()])))

    # Checksum
    for key, var in dnn.state_dict().items():
        report(key, var.cpu().numpy())
    # NOTE(review): ``volatile`` is a legacy Variable flag; if the project moves
    # to a newer PyTorch, torch.no_grad() is the replacement -- confirm target version.
    output = dnn(torch.autograd.Variable(tensor, volatile=True)).data
    for key, a in [
        ('tensor', tensor.cpu().numpy()),
        ('output', output.cpu().numpy()),
    ]:
        report(key, a)
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:32,代码来源:checksum_torch.py

示例14: main

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def main():
    """Prune the parameter named ``args.name`` and propagate the change through the network.

    The checkpoint is loaded, a random batch is pushed through the network,
    and ``utils.channel.Modifier`` is applied starting from ``output.grad_fn``.
    The pruned state dict is checked by rebuilding the network from it and
    running the same batch again. Unless ``args.debug`` is set, the state dict
    is then saved to the path the checkpoint was loaded from.
    """
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    for cmd in args.modify:
        utils.modify_config(config, cmd)
    with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
        # safe_load instead of a bare yaml.load(f): the latter requires an
        # explicit Loader on PyYAML >= 6 and can build arbitrary Python
        # objects on older releases.
        logging.config.dictConfig(yaml.safe_load(f))
    model_dir = utils.get_model_dir(config)
    category = utils.get_category(config)
    anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
    path, step, epoch = utils.train.load_model(model_dir)
    # Keep every tensor on the CPU regardless of where it was saved.
    state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    _model = utils.parse_attr(config.get('model', 'dnn'))
    dnn = _model(model.ConfigChannels(config, state_dict), anchors, len(category))
    # Log the total size in bytes of all state-dict tensors, human readable.
    logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values())))
    dnn.load_state_dict(state_dict)
    height, width = tuple(map(int, config.get('image', 'size').split()))
    image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width))
    output = dnn(image)
    state_dict = dnn.state_dict()
    d = utils.dense(state_dict[args.name])
    # Indices of the first ``len(d) * args.keep`` entries of ``d`` in ascending
    # argsort order are the ones kept.
    keep = torch.LongTensor(np.argsort(d)[:int(len(d) * args.keep)])
    modifier = utils.channel.Modifier(
        args.name, state_dict, dnn,
        lambda name, var: var[keep],
        lambda name, var, mapper: var[mapper(keep, len(d))],
        debug=args.debug,
    )
    modifier(output.grad_fn)
    if args.debug:
        # In debug mode ``path`` is rebound to the rendered graph file; nothing is saved below.
        path = modifier.dot.view('%s.%s.gv' % (os.path.basename(model_dir), os.path.basename(os.path.splitext(__file__)[0])), os.path.dirname(model_dir))
        logging.info(path)
    # presumably the modifier shrinks the state-dict entries in place -- verify against utils.channel.Modifier
    assert len(keep) == len(state_dict[args.name])
    # Rebuild the network from the pruned state dict and run it once as a sanity check.
    dnn = _model(model.ConfigChannels(config, state_dict), anchors, len(category))
    dnn.load_state_dict(state_dict)
    dnn(image)
    if not args.debug:
        # ``path`` is still the checkpoint path here, so the checkpoint is overwritten.
        torch.save(state_dict, path)
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:41,代码来源:pruner.py

示例15: load

# 需要导入模块: import utils [as 别名]
# 或者: from utils import train [as 别名]
def load(self):
        """Build the configured network, restoring the latest checkpoint when available.

        Returns:
            ``(step, epoch, dnn)``; ``step`` and ``epoch`` are both 0 when no
            checkpoint could be loaded.
        """
        try:
            checkpoint, step, epoch = utils.train.load_model(self.model_dir)
            # Keep every tensor on the CPU regardless of where it was saved.
            weights = torch.load(checkpoint, map_location=lambda storage, loc: storage)
            channels = model.ConfigChannels(self.config, weights)
        except (FileNotFoundError, ValueError):
            # No usable checkpoint: start from scratch.
            step = epoch = 0
            channels = model.ConfigChannels(self.config)
        build = utils.parse_attr(self.config.get('model', 'dnn'))
        dnn = build(channels, self.anchors, len(self.category))
        if channels.state_dict is not None:
            dnn.load_state_dict(channels.state_dict)
        return step, epoch, dnn
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:14,代码来源:train.py


注:本文中的utils.train方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。