本文整理汇总了Python中roi_data.minibatch.get_minibatch函数的典型用法代码示例。如果您正苦于以下问题：Python minibatch.get_minibatch函数的具体用法？Python minibatch.get_minibatch怎么用？Python minibatch.get_minibatch的使用示例？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该函数所在模块roi_data.minibatch的用法示例。
下文共展示了minibatch.get_minibatch函数的10个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。
示例1: get_next_parallel_minibatch
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def get_next_parallel_minibatch(self, augmentation_process_pool, parallel_size):
    """Return the blobs for the next ``parallel_size`` minibatches.

    Draws ``parallel_size`` index sets from ``self._get_next_minibatch_inds``,
    resolves them against ``self._roidb``, and builds each minibatch with
    ``get_minibatch`` via ``augmentation_process_pool.map``. Minibatches that
    ``get_minibatch`` flags as invalid are dropped, so the returned list may
    contain fewer than ``parallel_size`` entries.

    The original author documents this as thread safe; that depends on
    ``self._get_next_minibatch_inds`` (not visible here).

    Args:
        augmentation_process_pool: object exposing ``map(func, iterable)``
            (presumably a ``multiprocessing.Pool`` / ``ThreadPool`` -- TODO
            confirm against the caller).
        parallel_size: number of minibatches to request.

    Returns:
        list of blobs for the valid minibatches, in request order.
    """
    db_inds_list = [self._get_next_minibatch_inds() for _ in range(parallel_size)]
    minibatch_db_list = [[self._roidb[i] for i in db_inds] for db_inds in db_inds_list]
    # Deep-copy the entries handed to the pool, presumably so that building a
    # minibatch can never alter the shared self._roidb entries in place
    # (relevant if the pool runs in-process, e.g. a ThreadPool).
    copy_minibatch_db_list = copy.deepcopy(minibatch_db_list)
    # threading.currentThread() is a deprecated alias (DeprecationWarning
    # since Python 3.10); current_thread() is the supported spelling.
    thread = threading.current_thread()
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info('get_next_parallel_minibatch MAP POOL Thread: %s', thread)
    res = augmentation_process_pool.map(get_minibatch, copy_minibatch_db_list)
    logger.info('get_next_parallel_minibatch MAP POOL COMPLETE Thread: %s', thread)
    # get_minibatch returns (blobs, valid); keep only valid minibatches.
    return [blobs for blobs, valid in res if valid]
示例2: __getitem__
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def __getitem__(self, index_tuple):
    """Build the training blobs for a single roidb entry.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data`` when
            that entry has ``need_crop`` set.

    Returns:
        dict of blobs. Every value except ``'roidb'`` has its leading axis
        squeezed (``squeeze(axis=0)``). Boxes whose column 0 equals column 2
        or column 1 equals column 3 are removed together with their per-box
        annotations. ``blobs['roidb']`` is serialized with
        ``blob_utils.serialize``.
    """
    index, ratio = index_tuple
    single_db = [self._roidb[index]]
    blobs, valid = get_minibatch(single_db)
    # TODO: Check if minibatch is valid? If not, abandon it.
    # Need to change _worker_loop in torch.utils.data.dataloader.py.
    # (``valid`` is currently ignored.)

    # Squeeze the batch dim: exactly one roidb entry is processed per call.
    for key in blobs:
        if key != 'roidb':
            blobs[key] = blobs[key].squeeze(axis=0)

    if self._roidb[index]['need_crop']:
        self.crop_data(blobs, ratio)

    # Check bounding boxes: drop degenerate boxes (zero width or height,
    # assuming columns are x1, y1, x2, y2) and their per-box annotations.
    entry = blobs['roidb'][0]
    boxes = entry['boxes']
    invalid = (boxes[:, 0] == boxes[:, 2]) | (boxes[:, 1] == boxes[:, 3])
    valid_inds = np.nonzero(~invalid)[0]
    if len(valid_inds) < len(boxes):
        for key in ['boxes', 'gt_classes', 'seg_areas', 'gt_overlaps', 'is_crowd',
                    'box_to_gt_ind_map', 'gt_keypoints']:
            if key in entry:
                entry[key] = entry[key][valid_inds]
        # 'segms' is filtered with a comprehension (presumably a plain list).
        # Guard it like the keys above so entries without masks don't raise
        # KeyError.
        if 'segms' in entry:
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])  # CHECK: maybe we can serialize in collate_fn
    return blobs
示例3: __getitem__
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def __getitem__(self, index_tuple):
    """Return the blobs for one roidb entry with only ``'data'`` squeezed.

    ``index_tuple`` must unpack to ``(index, ratio)``; only ``index`` is
    used. ``get_minibatch`` is called with ``self._num_classes``, and the
    validity flag it returns is currently ignored.
    """
    index, _ratio = index_tuple
    blobs, _valid = get_minibatch([self._roidb[index]], self._num_classes)
    # TODO: Check if minibatch is valid? If not, abandon it.
    # Need to change _worker_loop in torch.utils.data.dataloader.py.

    # Remove the leading batch axis of the image blob only; every other blob
    # is returned exactly as get_minibatch produced it.
    image_blob = blobs['data']
    blobs['data'] = image_blob.squeeze(axis=0)
    return blobs
示例4: get_next_minibatch
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def get_next_minibatch(self):
    """Return the blobs for the next minibatch that ``get_minibatch`` accepts.

    Repeatedly draws an index set from ``self._get_next_minibatch_inds``,
    looks the entries up in ``self._roidb`` and builds a minibatch, until
    ``get_minibatch`` reports it valid. There is no retry limit. The original
    author describes this as thread safe.
    """
    while True:
        entries = [self._roidb[i] for i in self._get_next_minibatch_inds()]
        blobs, valid = get_minibatch(entries)
        if valid:
            return blobs
示例5: __getitem__
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def __getitem__(self, index_tuple):
    """Build the training blobs for a single roidb entry.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data`` when
            that entry has ``need_crop`` set.

    Returns:
        dict of blobs. Every value except ``'roidb'`` has its leading axis
        squeezed (``squeeze(axis=0)``). Boxes whose column 0 equals column 2
        or column 1 equals column 3 are removed together with their per-box
        annotations (including keypoint/action/role arrays when present).
        ``blobs['roidb']`` is serialized with ``blob_utils.serialize``.
    """
    index, ratio = index_tuple
    single_db = [self._roidb[index]]
    blobs, valid = get_minibatch(single_db)
    # TODO: Check if minibatch is valid? If not, abandon it.
    # Need to change _worker_loop in torch.utils.data.dataloader.py.
    # (``valid`` is currently ignored.)

    # Squeeze the batch dim: exactly one roidb entry is processed per call.
    for key in blobs:
        if key != 'roidb':
            blobs[key] = blobs[key].squeeze(axis=0)

    if self._roidb[index]['need_crop']:
        self.crop_data(blobs, ratio)

    # Check bounding boxes: drop degenerate boxes (zero width or height,
    # assuming columns are x1, y1, x2, y2) and their per-box annotations.
    entry = blobs['roidb'][0]
    boxes = entry['boxes']
    invalid = (boxes[:, 0] == boxes[:, 2]) | (boxes[:, 1] == boxes[:, 3])
    valid_inds = np.nonzero(~invalid)[0]
    if len(valid_inds) < len(boxes):
        for key in ['boxes', 'precomp_keypoints', 'gt_classes', 'seg_areas', 'gt_overlaps', 'is_crowd',
                    'box_to_gt_ind_map', 'gt_keypoints', 'gt_actions', 'gt_role_id']:
            if key in entry:
                entry[key] = entry[key][valid_inds]
        # 'segms' is filtered with a comprehension (presumably a plain list).
        # Guard it like the keys above so entries without masks don't raise
        # KeyError.
        if 'segms' in entry:
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])  # CHECK: maybe we can serialize in collate_fn
    return blobs
示例6: get_next_minibatch
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def get_next_minibatch(self):
    """Return the blobs of the next minibatch that ``get_minibatch`` accepts.

    Index sets come from ``self._get_next_minibatch_inds`` and are resolved
    against ``self._roidb``. A minibatch flagged invalid is discarded and a
    new index set is drawn; there is no retry limit. The original author
    describes this as thread safe.
    """
    while True:
        batch_entries = [self._roidb[idx] for idx in self._get_next_minibatch_inds()]
        blobs, is_valid = get_minibatch(batch_entries)
        if is_valid:
            return blobs
示例7: get_next_minibatch
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def get_next_minibatch(self):
    """Return the blobs for the next valid minibatch across several roidbs.

    ``self._get_next_minibatch_inds`` yields ``(db_inds, dataset_ind)``; the
    indices are looked up in ``self._roidbs[dataset_ind]``. Invalid
    minibatches are discarded and another draw is made, with no retry
    limit. The original author describes this as thread safe.
    """
    while True:
        inds, which_dataset = self._get_next_minibatch_inds()
        blobs, ok = get_minibatch([self._roidbs[which_dataset][i] for i in inds])
        if ok:
            return blobs
示例8: __getitem__
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def __getitem__(self, index_tuple):
    """Build the training blobs for a single roidb entry.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data`` when
            that entry has ``need_crop`` set.

    Returns:
        dict of blobs. Every value except ``'roidb'`` has its leading axis
        squeezed (``squeeze(axis=0)``). Boxes whose column 0 equals column 2
        or column 1 equals column 3 are removed together with their per-box
        annotations (including ``gt_scores``/``dataset_id``/``gt_source``
        when present). ``blobs['roidb']`` is serialized with
        ``blob_utils.serialize``.
    """
    index, ratio = index_tuple
    single_db = [self._roidb[index]]
    blobs, valid = get_minibatch(single_db)
    # TODO: Check if minibatch is valid? If not, abandon it.
    # Need to change _worker_loop in torch.utils.data.dataloader.py.
    # (``valid`` is currently ignored.)

    # Squeeze the batch dim: exactly one roidb entry is processed per call.
    for key in blobs:
        if key != 'roidb':
            blobs[key] = blobs[key].squeeze(axis=0)

    if self._roidb[index]['need_crop']:
        self.crop_data(blobs, ratio)

    # Check bounding boxes: drop degenerate boxes (zero width or height,
    # assuming columns are x1, y1, x2, y2) and their per-box annotations.
    entry = blobs['roidb'][0]
    boxes = entry['boxes']
    invalid = (boxes[:, 0] == boxes[:, 2]) | (boxes[:, 1] == boxes[:, 3])
    valid_inds = np.nonzero(~invalid)[0]
    if len(valid_inds) < len(boxes):
        for key in ['boxes', 'gt_classes', 'seg_areas', 'gt_overlaps', 'is_crowd',
                    'box_to_gt_ind_map', 'gt_keypoints', 'gt_scores', 'dataset_id',
                    'gt_source']:  # EDIT: gt_scores
            if key in entry:
                entry[key] = entry[key][valid_inds]
        # 'segms' is filtered with a comprehension (presumably a plain list).
        # Guard it like the keys above so entries without masks don't raise
        # KeyError.
        if 'segms' in entry:
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])  # CHECK: maybe we can serialize in collate_fn
    return blobs
示例9: _get_next_minibatch
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def _get_next_minibatch(self):
    """Return the blobs to be used for the next minibatch. DEPRECATED.

    This only exists for debugging (in train_net.py) and for benchmarking.
    It calls the same index helper the worker path uses,
    ``self._get_next_minibatch_inds(shared_readonly_dict, lock, mp_cur,
    mp_perm)``. The cursor ``self._cur`` (presumably a plain int, since it
    seeds a ``multiprocessing.Value('i', ...)``) is wrapped in a ``Value``
    for that call. Draws repeat until ``get_minibatch`` reports a valid
    minibatch; there is no retry limit.
    """
    roidb = self._roidb
    # Loop-invariant: build the wrapper dict once.
    shared_readonly_dict = {'roidb': roidb}
    # Create the cursor Value once. The original built a fresh Value from
    # self._cur on every retry and never read it back. If the helper advances
    # mp_cur (as its signature suggests), that advance was discarded, so every
    # retry and every call restarted from the same cursor position.
    mp_cur = multiprocessing.Value('i', self._cur, lock=False)
    valid = False
    while not valid:
        db_inds = self._get_next_minibatch_inds(
            shared_readonly_dict, self._lock, mp_cur, self._perm)
        minibatch_db = [roidb[i] for i in db_inds]
        blobs, valid = get_minibatch(minibatch_db)
    # Persist the (possibly advanced) cursor so the next call continues from it.
    self._cur = mp_cur.value
    return blobs
示例10: _get_next_minibatch2
# 需要导入模块: from roi_data import minibatch [as 别名]
# 或者: from roi_data.minibatch import get_minibatch [as 别名]
def _get_next_minibatch2(shared_readonly_dict, lock, mp_cur, mp_perm):
    """Return the blobs for the next valid minibatch from shared loader state.

    Index sets are drawn with ``RoIDataLoader._get_next_minibatch_inds``
    using the given lock, cursor and permutation, then resolved against
    ``shared_readonly_dict['roidb']``. Invalid minibatches are discarded and
    another set is drawn, with no retry limit. The original author
    describes this as thread safe.
    """
    entries = shared_readonly_dict['roidb']
    while True:
        picked = RoIDataLoader._get_next_minibatch_inds(
            shared_readonly_dict, lock, mp_cur, mp_perm)
        blobs, valid = get_minibatch([entries[i] for i in picked])
        if valid:
            return blobs