This article collects typical usage examples of the Python function neon.util.argparser.extract_valid_args: what the function does, how to call it, and what real-world uses look like. Three code examples are shown below, ordered by popularity by default.
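All three examples share the same idiom: a NeonArgparser builds the command-line namespace, and extract_valid_args filters that namespace down to the keyword arguments that gen_backend accepts. A minimal sketch of the pattern, assuming a standard neon script (the docstring passed to NeonArgparser is a placeholder):

from neon.backends import gen_backend
from neon.util.argparser import NeonArgparser, extract_valid_args

parser = NeonArgparser(__doc__)          # standard neon CLI (backend, batch size, ...)
args = parser.parse_args(gen_be=False)   # defer backend creation, as in Example 3 below

# keep only those entries of `args` that gen_backend() accepts,
# then build the CPU/GPU backend from them
be = gen_backend(**extract_valid_args(args, gen_backend))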
Example 1: main
import numpy as np

from neon.backends import gen_backend
from neon.util.argparser import extract_valid_args

# get_args_and_hyperparameters, MultiscaleSampler, create_model,
# load_imagenet_weights, train and test are helpers defined elsewhere
# in this example's project.
def main():
    # Collect the user arguments and hyperparameters
    args, hyper_params = get_args_and_hyperparameters()
    np.set_printoptions(precision=8, suppress=True, edgeitems=6, threshold=2048)

    # set up the CPU or GPU backend
    be = gen_backend(**extract_valid_args(args, gen_backend))

    # Load the training dataset. This will download the dataset from the web
    # and cache it locally for subsequent use.
    train_set = MultiscaleSampler('trainval', '2007',
                                  samples_per_img=hyper_params.samples_per_img,
                                  sample_height=224, path=args.data_dir,
                                  samples_per_batch=hyper_params.samples_per_batch,
                                  max_imgs=hyper_params.max_train_imgs,
                                  shuffle=hyper_params.shuffle)

    # Create the model by replacing the classification layer of AlexNet
    # with new adaptation layers.
    model, opt = create_model(args, hyper_params)

    # Seed the AlexNet conv layers with pre-trained weights.
    if args.model_file is None and hyper_params.use_pre_trained_weights:
        load_imagenet_weights(model, args.data_dir)

    train(args, hyper_params, model, opt, train_set)

    # Load the test dataset; it is downloaded and cached the same way.
    test_set = MultiscaleSampler('test', '2007',
                                 samples_per_img=hyper_params.samples_per_img,
                                 sample_height=224, path=args.data_dir,
                                 samples_per_batch=hyper_params.samples_per_batch,
                                 max_imgs=hyper_params.max_test_imgs,
                                 shuffle=hyper_params.shuffle)
    test(args, hyper_params, model, test_set)
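The `be = gen_backend(...)` line above is the function's entire role: it inspects gen_backend's signature and drops every namespace entry that the backend constructor would reject. A rough, hypothetical re-implementation to illustrate the idea (this is not neon's actual code):

import inspect

def extract_valid_args_sketch(args, func):
    # Illustration only: keep the attributes of `args` whose names match
    # parameters of `func`, so that func(**result) cannot fail with an
    # unexpected keyword argument.
    valid = inspect.signature(func).parameters
    return {k: v for k, v in vars(args).items() if k in valid}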
Example 2: gen_backend
args = parser.parse_args()

# hyperparameters
hidden_size = 128
embedding_dim = 128
vocab_size = 20000
sentence_length = 128
batch_size = 32
gradient_limit = 5
clip_gradients = True
num_epochs = args.epochs
embedding_update = True

# set up backend
be = gen_backend(**extract_valid_args(args, gen_backend))

# get the preprocessed and tokenized data
fname_h5, fname_vocab = build_data_train(filepath=args.review_file,
                                         vocab_file=args.vocab_file,
                                         skip_headers=True)

# optionally initialize the embedding layer from Google News word2vec vectors
if args.use_w2v:
    w2v_file = args.w2v
    vocab, rev_vocab = cPickle.load(open(fname_vocab, 'rb'))
    init_emb, embedding_dim, _ = get_google_word2vec_W(w2v_file, vocab,
                                                       vocab_size=vocab_size,
                                                       index_from=3)
    print("Done loading the Word2Vec vectors: embedding size - {}".format(embedding_dim))
    embedding_update = True
else:
Example 3: buffering
                    help='How many backend buffers to use, 1 for no double '
                         'buffering (saves gpu memory, slower)')
parser.add_argument('--plot_weight_layer', type=int, default=-1,
                    help='Plot weights for specified layer (must specify model_file)')
parser.add_argument('--plot_norm_per_filter', action="store_true",
                    help='With plotting weights, normalize each filter over range')
parser.add_argument('--plot_combine_chans', action="store_true",
                    help='With plotting weights, make plot that combines first '
                         'three channels into colors')
parser.add_argument('--plot_log', action="store_true",
                    help='Plot weights on log scale')
parser.add_argument('--plot_save_path', type=str, default='',
                    help='Path to save weight plots instead of displaying')

# parse the command line arguments (gen_be=False defers backend creation)
args = parser.parse_args(gen_be=False)
print('emneon / neon options:')
print(args)

# set up backend
be_args = extract_valid_args(args, gen_backend)
# multiple gpus accessing the cache dir for autotuning winograd
# was causing crashes / reboots
#be_args['cache_dir'] = tempfile.mkdtemp()  # create temp dir
be_args['deterministic'] = None  # xxx - why was this set?
be = gen_backend(**be_args)

# xxx - this doesn't work: the interrupt is caught by neon for saving the
# model, which then raises KeyboardInterrupt
#def signal_handler(signal, frame):
#    #print('You pressed Ctrl+C!')
#    shutil.rmtree(be_args['cache_dir'])  # delete directory
#signal.signal(signal.SIGINT, signal_handler)

# this function was modified from cuda-convnets2 shownet.py
def make_filter_fig(filters, filter_start, fignum, _title, num_filters,
                    combine_chans, FILTERS_PER_ROW=None, plot_border=0.0):
    MAX_ROWS = 24