当前位置: 首页>>代码示例>>Python>>正文


Python provider.jitter_point_cloud方法代码示例

本文整理汇总了Python中provider.jitter_point_cloud方法的典型用法代码示例。如果您正苦于以下问题:Python provider.jitter_point_cloud方法的具体用法?Python provider.jitter_point_cloud怎么用?Python provider.jitter_point_cloud使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在provider的用法示例。


在下文中一共展示了provider.jitter_point_cloud方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train(net, opt, scheduler, train_loader, dev):
    """Train ``net`` for one pass over ``train_loader``, then step the scheduler.

    Every batch is converted to a NumPy array and augmented with the
    ``provider`` helpers (point dropout on the whole array; random scale,
    jitter and shift on channels ``0:3``), converted back to a tensor, and
    used for one optimizer step with cross-entropy loss. The running average
    loss and accuracy are shown on the tqdm progress bar.

    Args:
        net: model, called as ``net(data)``; its output is treated as logits
            with the class scores along dimension 1.
        opt: optimizer; ``zero_grad()`` and ``step()`` are called per batch.
        scheduler: learning-rate scheduler; ``step()`` is called once, after
            the whole loader has been consumed.
        train_loader: iterable yielding ``(data, label)`` tensor pairs.
        dev: device the batch is moved to before the forward pass.
    """

    net.train()

    total_loss = 0
    num_batches = 0
    total_correct = 0
    count = 0
    loss_f = nn.CrossEntropyLoss()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            # The provider helpers are applied to NumPy arrays, so leave
            # torch for the augmentation and come back afterwards.
            data = data.data.numpy()
            data = provider.random_point_dropout(data)
            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)
            label = label[:, 0]

            num_examples = label.shape[0]
            # reshape(-1) instead of squeeze(): squeeze() turns a batch that
            # holds a single sample into a 0-d tensor, which no longer
            # matches the (1, num_classes) logits in the loss below.
            data, label = data.to(dev), label.to(dev).reshape(-1).long()
            opt.zero_grad()
            logits = net(data)
            loss = loss_f(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)

            num_batches += 1
            count += num_examples
            loss = loss.item()
            correct = (preds == label).sum().item()
            total_loss += loss
            total_correct += correct

            tq.set_postfix({
                'AvgLoss': '%.5f' % (total_loss / num_batches),
                'AvgAcc': '%.5f' % (total_correct / count)})
    scheduler.step()
开发者ID:dmlc,项目名称:dgl,代码行数:42,代码来源:train_cls.py

示例2: get_example

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def get_example(self, i):
        """Fetch the ``i``-th sample as a ``(points, label)`` pair.

        When ``self.augment`` is truthy the sample is passed through
        ``provider.rotate_point_cloud`` and ``provider.jitter_point_cloud``
        first. The points are returned as float32 with the two axes swapped
        and a trailing axis of length one appended.
        """
        if not self.augment:
            cloud = self.data[i]
        else:
            # Slice instead of index so the provider helpers receive an
            # array that still has its leading axis (of length one).
            single = self.data[i:i + 1, :, :]
            single = provider.rotate_point_cloud(single)
            single = provider.jitter_point_cloud(single)
            cloud = single[0]

        # (num_point, k) --> (k, num_point, 1)
        cloud = cloud.astype(np.float32).transpose(1, 0)[:, :, np.newaxis]
        target = self.label[i]
        assert cloud.dtype == np.float32
        assert target.dtype == np.int32
        return cloud, target
开发者ID:corochann,项目名称:chainer-pointnet,代码行数:17,代码来源:ply_dataset.py

示例3: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Run one training epoch over ``TRAIN_DATASET``.

    Sample indices are shuffled, then consumed in whole batches of
    ``BATCH_SIZE`` (a trailing partial batch is not used). Channels ``0:3``
    of every batch are jittered before being fed. Mean loss and accuracy
    are logged, and their accumulators reset, after every 10 batches.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        train_writer: summary writer; receives one summary per batch.
    """
    is_training = True

    # Shuffle train samples
    train_idxs = np.arange(0, len(TRAIN_DATASET))
    np.random.shuffle(train_idxs)
    # Floor division: on Python 3 a plain "/" produces a float, which
    # range() below refuses with a TypeError.
    num_batches = len(TRAIN_DATASET) // BATCH_SIZE

    log_string(str(datetime.now()))

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx+1) * BATCH_SIZE
        batch_data, batch_label = get_batch(TRAIN_DATASET, train_idxs, start_idx, end_idx)
        # Augment batched point clouds by jittering the first three channels
        batch_data[:,:,0:3] = provider.jitter_point_cloud(batch_data[:,:,0:3])
        feed_dict = {ops['pointclouds_pl']: batch_data,
                     ops['labels_pl']: batch_label,
                     ops['is_training_pl']: is_training,}
        summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
            ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        # Predictions are reduced over axis 2 and compared element-wise with
        # the labels, so accuracy is counted per point, not per cloud.
        pred_val = np.argmax(pred_val, 2)
        correct = np.sum(pred_val == batch_label)
        total_correct += correct
        total_seen += (BATCH_SIZE*NUM_POINT)
        loss_sum += loss_val

        if (batch_idx+1)%10 == 0:
            log_string(' -- %03d / %03d --' % (batch_idx+1, num_batches))
            log_string('mean loss: %f' % (loss_sum / 10))
            log_string('accuracy: %f' % (total_correct / float(total_seen)))
            # Statistics cover only the last 10 batches.
            total_correct = 0
            total_seen = 0
            loss_sum = 0
开发者ID:pubgeo,项目名称:dfc2019,代码行数:43,代码来源:train.py

示例4: _augment_batch_data

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def _augment_batch_data(self, batch_data):
        """Augment a batch with the ``provider`` helpers and shuffle its points.

        The whole batch goes through ``rotate_point_cloud`` and
        ``rotate_perturbation_point_cloud``; channels ``0:3`` of the result
        are then randomly scaled, shifted and jittered and written back in
        place. The return value is whatever ``provider.shuffle_points``
        yields for the augmented batch.
        """
        augmented = provider.rotate_point_cloud(batch_data)
        augmented = provider.rotate_perturbation_point_cloud(augmented)

        # Scale / shift / jitter touch only the first three channels.
        xyz = augmented[:, :, 0:3]
        xyz = provider.random_scale_point_cloud(xyz)
        xyz = provider.shift_point_cloud(xyz)
        xyz = provider.jitter_point_cloud(xyz)
        augmented[:, :, 0:3] = xyz

        return provider.shuffle_points(augmented)
开发者ID:pubgeo,项目名称:dfc2019,代码行数:10,代码来源:modelnet_h5_dataset.py

示例5: _augment_batch_data

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def _augment_batch_data(self, batch_data):
        """Augment a batch with the ``provider`` helpers and shuffle its points.

        The rotation helpers used depend on ``self.normal_channel``: the
        ``*_with_normal`` variants when it is truthy, the plain ones
        otherwise. Channels ``0:3`` of the rotated batch are then randomly
        scaled, shifted and jittered and written back in place. The return
        value is whatever ``provider.shuffle_points`` yields.
        """
        if not self.normal_channel:
            augmented = provider.rotate_point_cloud(batch_data)
            augmented = provider.rotate_perturbation_point_cloud(augmented)
        else:
            augmented = provider.rotate_point_cloud_with_normal(batch_data)
            augmented = provider.rotate_perturbation_point_cloud_with_normal(augmented)

        # Scale / shift / jitter touch only the first three channels.
        xyz = augmented[:, :, 0:3]
        xyz = provider.random_scale_point_cloud(xyz)
        xyz = provider.shift_point_cloud(xyz)
        xyz = provider.jitter_point_cloud(xyz)
        augmented[:, :, 0:3] = xyz

        return provider.shuffle_points(augmented)
开发者ID:pubgeo,项目名称:dfc2019,代码行数:15,代码来源:modelnet_dataset.py

示例6: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch on the part-labelled data and log the results.

    Data, labels and part labels come from
    ``data_utils.get_current_data_parts_h5``. Only whole batches of
    ``BATCH_SIZE`` are used. Each batch is rotated and jittered before it is
    fed. Mean loss and segmentation accuracy are logged once at the end.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        train_writer: summary writer; receives one summary per batch.
    """
    clouds, labels, parts = data_utils.get_current_data_parts_h5(
        TRAIN_DATA, TRAIN_LABELS, TRAIN_PARTS, NUM_POINT)
    labels = np.squeeze(labels)
    parts = np.squeeze(parts)

    num_batches = clouds.shape[0]//BATCH_SIZE

    seen = 0
    loss_total = 0
    seg_hits = 0

    for b in range(num_batches):
        lo = b * BATCH_SIZE
        hi = lo + BATCH_SIZE

        # Augment batched point clouds by rotation and jittering
        batch_points = provider.rotate_point_cloud(clouds[lo:hi, :, :])
        batch_points = provider.jitter_point_cloud(batch_points)
        batch_parts = parts[lo:hi]

        feed = {
            ops['pointclouds_pl']: batch_points,
            ops['labels_pl']: labels[lo:hi],
            ops['parts_pl']: batch_parts,
            ops['is_training_pl']: True,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'],
                   ops['loss'], ops['seg_pred']]
        summary, step, _, loss_val, seg_scores = sess.run(fetches, feed_dict=feed)
        train_writer.add_summary(summary, step)

        # Count element-wise agreement between predicted and true parts.
        seg_hits += np.sum(np.argmax(seg_scores, 2) == batch_parts)
        seen += BATCH_SIZE
        loss_total += loss_val

    log_string('mean loss: %f' % (loss_total / float(num_batches)))
    log_string('seg accuracy: %f' % (seg_hits / (float(seen)*NUM_POINT)))
开发者ID:hkust-vgd,项目名称:scanobjectnn,代码行数:42,代码来源:train_partseg.py

示例7: get_kmeans_init

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def get_kmeans_init(n_gaussians, cov_type='farthest'):
    """Derive initial mixture parameters from a k-means clustering of points.

    Up to 10 samples are taken for each label id in ``range(40)``, jittered,
    and flattened into one array of points. That array is clustered into
    ``n_gaussians`` clusters; a weight and a per-axis standard deviation are
    then computed for every cluster.

    Args:
        n_gaussians: number of k-means clusters.
        cov_type: ``'compute_cov'`` takes the square root of the diagonal of
            each cluster's sample covariance; ``'farthest'`` uses one third
            of the distance from the center to the farthest cluster member
            on all ``D`` axes. Any other value appends no standard
            deviations at all.

    Returns:
        Tuple ``(w, centers.T, stdev.T)`` where ``w`` is a list holding the
        fraction of points assigned to each cluster and ``stdev`` is a
        float32 array.

    Note:
        An empty cluster still raises ``ValueError`` in ``'farthest'`` mode,
        because the maximum of an empty distance array is undefined.
    """
    D = 3

    # Get the training data for initialization
    # Load multiple models from the dataset
    points, labels, _, _ = provider.load_dataset( num_points=1024)
    mask = []
    for i in range(40):
        mask.append(np.squeeze(np.where(labels == i))[0:10])
    mask = np.concatenate(mask, axis=0)
    points = points[mask, :, :]
    points = provider.jitter_point_cloud(points, sigma=0.01, clip=0.05)
    # Merge the per-sample axis into the point axis: one row per point.
    points = np.concatenate(points, axis=0)

    #input function for kmeans clustering
    def input_fn():
        return tf.constant(points, dtype=tf.float32), None

    ## construct model
    kmeans = tf.contrib.learn.KMeansClustering(num_clusters=n_gaussians, relative_tolerance=0.0001)
    kmeans.fit(input_fn=input_fn)
    centers = kmeans.clusters()
    assignments = np.squeeze(list(kmeans.predict_cluster_idx(input_fn=input_fn)))

    # float() keeps the weights fractional under Python 2 integer division.
    n_points = float(points.shape[0])
    stdev = []
    w = []
    for i in range(n_gaussians):
        # Take the index array directly rather than squeezing the np.where
        # tuple: squeezing gives a 0-d value for a one-member cluster, which
        # breaks both len() and the row indexing below.
        idx = np.where(assignments == i)[0]
        w.append(len(idx) / n_points)
        if cov_type == 'compute_cov':
            samples = points[idx, :].T
            stdev.append(np.sqrt(np.diag(np.cov(samples))))
        elif cov_type == 'farthest':
            d = np.sqrt(np.sum(np.power(points[idx, :] - centers[i, :], 2), axis=1))
            stdev.append((np.max(d) / 3.) * np.ones(D))

    return    w, centers.T, np.array(stdev, dtype=np.float32).T
开发者ID:hkust-vgd,项目名称:scanobjectnn,代码行数:43,代码来源:tf_gmm_utils.py

示例8: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train the classifier for one epoch and log mean loss and accuracy.

    The data loader is ``data_utils.get_current_data_h5`` when ``TRAIN_FILE``
    contains ``".h5"`` and ``data_utils.get_current_data`` otherwise. Only
    whole batches of ``BATCH_SIZE`` are used. Each batch is rotated and
    jittered before it is fed.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        train_writer: summary writer; receives one summary per batch.
    """
    if ".h5" in TRAIN_FILE:
        load = data_utils.get_current_data_h5
    else:
        load = data_utils.get_current_data
    samples, targets = load(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    targets = np.squeeze(targets)

    num_batches = samples.shape[0]//BATCH_SIZE

    hits = 0
    seen = 0
    loss_total = 0

    for b in range(num_batches):
        lo = b * BATCH_SIZE
        hi = lo + BATCH_SIZE

        # Augment batched point clouds by rotation and jittering
        batch_points = provider.rotate_point_cloud(samples[lo:hi, :, :])
        batch_points = provider.jitter_point_cloud(batch_points)
        batch_targets = targets[lo:hi]

        feed = {
            ops['pointclouds_pl']: batch_points,
            ops['labels_pl']: batch_targets,
            ops['is_training_pl']: True,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'],
                   ops['loss'], ops['pred']]
        summary, step, _, loss_val, scores = sess.run(fetches, feed_dict=feed)
        train_writer.add_summary(summary, step)

        hits += np.sum(np.argmax(scores, 1) == batch_targets)
        seen += BATCH_SIZE
        loss_total += loss_val

    log_string('mean loss: %f' % (loss_total / float(num_batches)))
    log_string('accuracy: %f' % (hits / float(seen)))
开发者ID:hkust-vgd,项目名称:scanobjectnn,代码行数:42,代码来源:train.py

示例9: get_batch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def get_batch(self, data_aug=False):
        """Take the next ``(data, sem_label, ins_label)`` item off the queue.

        When ``data_aug`` is truthy and ``self.split`` is ``'train'``,
        channels ``0:3`` of ``data`` are replaced in place by their jittered
        version; otherwise the item is returned untouched.
        """
        points, semantic, instance = self.data_queue.get()

        if not data_aug or self.split != 'train':
            return points, semantic, instance

        xyz = points[:, :, 0:3]
        points[:, :, 0:3] = provider.jitter_point_cloud(xyz)
        return points, semantic, instance
开发者ID:dlinzhao,项目名称:JSNet,代码行数:9,代码来源:dataset_s3dis.py

示例10: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer, epoch):
    """Train for one epoch on point-cloud pairs and log the mean loss.

    A shuffled copy of ``TRAIN_INDICES`` is consumed in whole batches of
    ``cfg.training.batch_size``. Both clouds of every pair are jittered
    before being fed. The last batch loss is shown on the tqdm bar, and the
    writer is flushed at the end.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        train_writer: summary writer; receives one summary per batch.
        epoch: not used by this function.
    """
    batch_size = cfg.training.batch_size

    order = copy.deepcopy(TRAIN_INDICES)
    np.random.shuffle(order)
    num_batches = len(order) // batch_size

    loss_total = 0

    progress = tqdm(range(num_batches), desc='train', postfix=dict(last_loss_str=''))
    for b in progress:
        first = b * batch_size
        chunk = order[first:first + batch_size]

        (pcs1, pcs2, translations, rel_angles,
         pc1centers, pc2centers, pc1angles, pc2angles) = provider.load_batch(chunk)

        # Augment batched point clouds by jittering
        pcs1 = provider.jitter_point_cloud(pcs1)
        pcs2 = provider.jitter_point_cloud(pcs2)

        feed = {
            ops['pcs1']: pcs1,
            ops['pcs2']: pcs2,
            ops['translations']: translations,
            ops['rel_angles']: rel_angles,
            ops['is_training_pl']: True,
            ops['pc1centers']: pc1centers,
            ops['pc2centers']: pc2centers,
            ops['pc1angles']: pc1angles,
            ops['pc2angles']: pc2angles,
        }
        fetches = [
            ops['merged'],
            ops['step'],
            ops['train_op'],
            ops['loss'],
            ops['pred_translations'],
            ops['pred_remaining_angle_logits'],
        ]
        # The two prediction outputs are fetched but not used afterwards.
        (summary, step, _, loss_val,
         pred_translations, pred_remaining_angle_logits) = sess.run(fetches, feed_dict=feed)
        train_writer.add_summary(summary, step)

        loss_total += loss_val
        progress.set_postfix(last_loss_str=f'{loss_val:.5f}')

    logger.info('train mean loss: %f' % (loss_total / float(num_batches)))
    train_writer.flush()
开发者ID:grossjohannes,项目名称:AlignNet-3D,代码行数:51,代码来源:train.py

示例11: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch, visiting the training files in random order.

    Every file is loaded, truncated to the first ``NUM_POINT`` points per
    sample, shuffled, and consumed in whole batches of ``BATCH_SIZE``.
    Batches are rotated and jittered before being fed. Mean loss and
    accuracy are logged once per file.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        train_writer: summary writer; receives one summary per batch.
    """
    # Shuffle train files
    file_order = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(file_order)

    for fn, file_idx in enumerate(file_order):
        log_string('----' + str(fn) + '-----')
        # The third value returned by the loader is not needed here.
        samples, targets, _ = provider.loadDataFile_with_normal(TRAIN_FILES[file_idx])
        samples = samples[:, 0:NUM_POINT, :]
        samples, targets, _ = provider.shuffle_data(samples, np.squeeze(targets))
        targets = np.squeeze(targets)

        num_batches = samples.shape[0] // BATCH_SIZE

        hits = 0
        seen = 0
        loss_total = 0

        for b in range(num_batches):
            lo = b * BATCH_SIZE
            hi = lo + BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            batch_points = provider.rotate_point_cloud(samples[lo:hi, :, :])
            batch_points = provider.jitter_point_cloud(batch_points)
            batch_targets = targets[lo:hi]

            feed = {
                ops['pointclouds_pl']: batch_points,
                ops['labels_pl']: batch_targets,
                ops['is_training_pl']: True,
            }
            fetches = [ops['merged'], ops['step'], ops['train_op'],
                       ops['loss'], ops['pred']]
            summary, step, _, loss_val, scores = sess.run(fetches, feed_dict=feed)
            train_writer.add_summary(summary, step)

            hits += np.sum(np.argmax(scores, 1) == batch_targets)
            seen += BATCH_SIZE
            loss_total += loss_val

        log_string('mean loss: %f' % (loss_total / float(num_batches)))
        log_string('accuracy: %f' % (hits / float(seen)))
开发者ID:xyf513,项目名称:SpiderCNN,代码行数:46,代码来源:train_xyz.py

示例12: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Run one training epoch over ``TRAIN_DATASET``.

    Sample indices are shuffled, then consumed in whole batches of
    ``BATCH_SIZE`` (a trailing partial batch is not used). Channels ``0:3``
    of every batch are jittered before being fed, together with the labels
    and the per-sample class labels returned by ``get_batch``. Mean loss and
    accuracy are logged, and their accumulators reset, after every 10
    batches.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        train_writer: summary writer; receives one summary per batch.
    """
    is_training = True

    # Shuffle train samples
    train_idxs = np.arange(0, len(TRAIN_DATASET))
    np.random.shuffle(train_idxs)
    # Floor division: on Python 3 a plain "/" produces a float, which
    # range() below refuses with a TypeError.
    num_batches = len(TRAIN_DATASET) // BATCH_SIZE

    log_string(str(datetime.now()))

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx+1) * BATCH_SIZE
        batch_data, batch_label, batch_cls_label = get_batch(TRAIN_DATASET, train_idxs, start_idx, end_idx)
        # Augment batched point clouds by jittering the first three channels
        batch_data[:,:,0:3] = provider.jitter_point_cloud(batch_data[:,:,0:3])
        feed_dict = {ops['pointclouds_pl']: batch_data,
                     ops['labels_pl']: batch_label,
                     ops['cls_labels_pl']: batch_cls_label,
                     ops['is_training_pl']: is_training,}
        summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
            ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        # Predictions are reduced over axis 2 and compared element-wise with
        # the labels, so accuracy is counted per point, not per cloud.
        pred_val = np.argmax(pred_val, 2)
        correct = np.sum(pred_val == batch_label)
        total_correct += correct
        total_seen += (BATCH_SIZE*NUM_POINT)
        loss_sum += loss_val

        if (batch_idx+1)%10 == 0:
            log_string(' -- %03d / %03d --' % (batch_idx+1, num_batches))
            log_string('mean loss: %f' % (loss_sum / 10))
            log_string('accuracy: %f' % (total_correct / float(total_seen)))
            # Statistics cover only the last 10 batches.
            total_correct = 0
            total_seen = 0
            loss_sum = 0
开发者ID:xyf513,项目名称:SpiderCNN,代码行数:44,代码来源:train.py

示例13: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch on points plus normals, one file at a time.

    Files are visited in random order. For each file the points and normals
    are truncated to the first ``NUM_POINT`` entries per sample and shuffled
    with the same permutation. Batches of ``BATCH_SIZE`` are built by
    rotating and jittering the points, appending the normals along axis 2,
    and applying random point dropout. Mean loss and accuracy are logged
    once per file.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        train_writer: summary writer; receives one summary per batch.
    """
    # Shuffle train files
    file_order = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(file_order)

    for fn, file_idx in enumerate(file_order):
        log_string('----' + str(fn) + '-----')
        samples, targets, normals = provider.loadDataFile_with_normal(TRAIN_FILES[file_idx])
        normals = normals[:, 0:NUM_POINT, :]
        samples = samples[:, 0:NUM_POINT, :]
        samples, targets, perm = provider.shuffle_data(samples, np.squeeze(targets))
        targets = np.squeeze(targets)
        # Reorder the normals with the permutation applied to the points.
        normals = normals[perm, ...]

        num_batches = samples.shape[0] // BATCH_SIZE

        hits = 0
        seen = 0
        loss_total = 0

        for b in range(num_batches):
            lo = b * BATCH_SIZE
            hi = lo + BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            # NOTE(review): the normals are concatenated without being
            # rotated alongside the points - confirm this is intended.
            xyz = provider.rotate_point_cloud(samples[lo:hi, :, :])
            xyz = provider.jitter_point_cloud(xyz)
            batch_input = np.concatenate((xyz, normals[lo:hi, :, :]), 2)
            # random point dropout
            batch_input = provider.random_point_dropout(batch_input)
            batch_targets = targets[lo:hi]

            feed = {
                ops['pointclouds_pl']: batch_input,
                ops['labels_pl']: batch_targets,
                ops['is_training_pl']: True,
            }
            fetches = [ops['merged'], ops['step'], ops['train_op'],
                       ops['loss'], ops['pred']]
            summary, step, _, loss_val, scores = sess.run(fetches, feed_dict=feed)
            train_writer.add_summary(summary, step)

            hits += np.sum(np.argmax(scores, 1) == batch_targets)
            seen += BATCH_SIZE
            loss_total += loss_val

        log_string('mean loss: %f' % (loss_total / float(num_batches)))
        log_string('accuracy: %f' % (hits / float(seen)))
开发者ID:xyf513,项目名称:SpiderCNN,代码行数:52,代码来源:train.py

示例14: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train the joint classification / mask-segmentation model for one epoch.

    Data, labels and masks come from
    ``data_utils.get_current_data_withmask_h5``. Only whole batches of
    ``BATCH_SIZE`` are used, each rotated and jittered before being fed.
    The total, classification and segmentation mean losses and both
    accuracies are logged once at the end.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        train_writer: summary writer; receives one summary per batch.
    """
    samples, targets, masks = data_utils.get_current_data_withmask_h5(
        TRAIN_DATA, TRAIN_LABELS, TRAIN_MASKS, NUM_POINT)
    targets = np.squeeze(targets)
    masks = np.squeeze(masks)

    num_batches = samples.shape[0]//BATCH_SIZE

    cls_hits = 0
    seen = 0
    loss_total = 0
    seg_hits = 0
    cls_loss_total = 0
    seg_loss_total = 0

    for b in range(num_batches):
        lo = b * BATCH_SIZE
        hi = lo + BATCH_SIZE

        # Augment batched point clouds by rotation and jittering
        batch_points = provider.rotate_point_cloud(samples[lo:hi, :, :])
        batch_points = provider.jitter_point_cloud(batch_points)
        batch_targets = targets[lo:hi]
        batch_masks = masks[lo:hi]

        feed = {
            ops['pointclouds_pl']: batch_points,
            ops['labels_pl']: batch_targets,
            ops['masks_pl']: batch_masks,
            ops['is_training_pl']: True,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                   ops['pred'], ops['seg_pred'],
                   ops['classify_loss'], ops['seg_loss']]
        (summary, step, _, loss_val, cls_scores, seg_scores,
         classify_loss, seg_loss) = sess.run(fetches, feed_dict=feed)
        train_writer.add_summary(summary, step)

        # Classification is scored per sample (axis 1), segmentation per
        # element of the mask (axis 2).
        cls_hits += np.sum(np.argmax(cls_scores, 1) == batch_targets)
        seg_hits += np.sum(np.argmax(seg_scores, 2) == batch_masks)

        seen += BATCH_SIZE
        loss_total += loss_val
        cls_loss_total += classify_loss
        seg_loss_total += seg_loss

    batches = float(num_batches)
    log_string('mean loss: %f' % (loss_total / batches))
    log_string('classify mean loss: %f' % (cls_loss_total / batches))
    log_string('seg mean loss: %f' % (seg_loss_total / batches))
    log_string('accuracy: %f' % (cls_hits / float(seen)))
    log_string('seg accuracy: %f' % (seg_hits / (float(seen)*NUM_POINT)))
开发者ID:hkust-vgd,项目名称:scanobjectnn,代码行数:51,代码来源:train_seg.py

示例15: train_one_epoch

# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, gmm, train_writer):
    """Train the GMM-based classifier for one epoch and log loss and accuracy.

    The data loader is ``data_utils.get_current_data_h5`` when ``TRAIN_FILE``
    contains ``".h5"`` and ``data_utils.get_current_data`` otherwise. Only
    whole batches of ``BATCH_SIZE`` are used. Each augmentation step (scale,
    rotation, translation, jitter, outlier insertion) is applied only when
    its module-level ``augment_*`` switch is truthy, in that order.

    Args:
        sess: session object; ``sess.run`` executes the fetches.
        ops: dict mapping from string to tf ops.
        gmm: object exposing ``weights_``, ``means_`` and ``covariances_``;
            the square root of the covariances is fed as sigma.
        train_writer: summary writer; receives one summary per batch.
    """
    if ".h5" in TRAIN_FILE:
        load = data_utils.get_current_data_h5
    else:
        load = data_utils.get_current_data
    samples, targets = load(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    targets = np.squeeze(targets)

    num_batches = samples.shape[0]//BATCH_SIZE

    hits = 0
    seen = 0
    loss_total = 0

    for b in range(num_batches):
        lo = b * BATCH_SIZE
        hi = lo + BATCH_SIZE

        # Optional augmentations, each gated by its own switch.
        batch_points = samples[lo:hi, :, :]
        if augment_scale:
            batch_points = provider.scale_point_cloud(batch_points, smin=0.66, smax=1.5)
        if augment_rotation:
            batch_points = provider.rotate_point_cloud(batch_points)
        if augment_translation:
            batch_points = provider.translate_point_cloud(batch_points, tval = 0.2)
        if augment_jitter:
            batch_points = provider.jitter_point_cloud(
                batch_points, sigma=0.01, clip=0.05)
        if augment_outlier:
            batch_points = provider.insert_outliers_to_point_cloud(batch_points, outlier_ratio=0.02)

        batch_targets = targets[lo:hi]
        feed = {
            ops['points_pl']: batch_points,
            ops['labels_pl']: batch_targets,
            ops['w_pl']: gmm.weights_,
            ops['mu_pl']: gmm.means_,
            ops['sigma_pl']: np.sqrt(gmm.covariances_),
            ops['is_training_pl']: True,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'],
                   ops['loss'], ops['pred']]
        summary, step, _, loss_val, scores = sess.run(fetches, feed_dict=feed)

        train_writer.add_summary(summary, step)
        hits += np.sum(np.argmax(scores, 1) == batch_targets)
        seen += BATCH_SIZE
        loss_total += loss_val

    log_string('mean loss: %f' % (loss_total / float(num_batches)))
    log_string('accuracy: %f' % (hits / float(seen)))
开发者ID:hkust-vgd,项目名称:scanobjectnn,代码行数:58,代码来源:train.py


注:本文中的provider.jitter_point_cloud方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。