This article collects typical usage examples of config.model in Python. If you have been wondering what config.model is for, how to call it, or what working code that uses it looks like, the curated examples below may help. You can also explore the config module itself for further usage.
The 15 code examples that follow are ordered by popularity by default.
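Throughout these examples, config is an ordinary project-level settings module that each script imports. Its exact contents differ from project to project; the sketch below is a hypothetical config.py whose attribute names mirror the ones the examples read (config.model, config.image_height, config.classes, the MODEL_* constants, and so on), with placeholder values.

# config.py -- hypothetical sketch; only the attribute names come from the examples below
model = "resnet50"                 # base architecture name, read everywhere as config.model

# input settings used by get_model()
image_height = 224
image_width = 224
channels = 3

# file-system paths used by several examples
activations_path = "activations.csv"
plots_dir = "plots"
train_dir = "data/train"

# settings used by the prediction server
classes = ["class_a", "class_b"]   # placeholder class list
buffer_size = 4096

# architecture name constants referenced by parse_args()
MODEL_RESNET50 = "resnet50"
MODEL_RESNET152 = "resnet152"
MODEL_INCEPTION_V3 = "inception_v3"
MODEL_VGG16 = "vgg16"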
Example 1: load_session
# Required import: import config [as alias]
# Or: from config import model [as alias]
def load_session():
    global sess_path, model_config, device, learning_rate, reset_optimizer
    try:
        sess = torch.load(sess_path)
        if 'model_config' in sess and sess['model_config'] != model_config:
            model_config = sess['model_config']
            print('Use session config instead:')
            print(utils.dict2params(model_config))
        model_state = sess['model_state']
        optimizer_state = sess['model_optimizer_state']
        print('Session is loaded from', sess_path)
        sess_loaded = True
    except Exception:  # no usable session file; start a fresh one
        print('New session')
        sess_loaded = False
    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if sess_loaded:
        model.load_state_dict(model_state)
        if not reset_optimizer:
            optimizer.load_state_dict(optimizer_state)
    return model, optimizer
Example 2: save_activations
# Required import: import config [as alias]
# Or: from config import model [as alias]
def save_activations(model, inputs, files, layer, batch_number):
    all_activations = []
    ids = []
    af = get_activation_function(model, layer)
    for i in range(len(inputs)):
        acts = get_activations(af, [inputs[i]])
        all_activations.append(acts)
        ids.append(files[i].split('/')[-2])
    submission = pd.DataFrame(all_activations)
    submission.insert(0, 'class', ids)
    # reset_index() returns a new frame, so assign it back
    # (it has no effect on the CSV since to_csv is called with index=False)
    submission = submission.reset_index(drop=True)
    if batch_number > 0:
        submission.to_csv(config.activations_path, index=False, mode='a', header=False)
    else:
        submission.to_csv(config.activations_path, index=False)
Example 3: parse_args
# Required import: import config [as alias]
# Or: from config import model [as alias]
def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', dest='path', help='Path to image', default=None, type=str)
    parser.add_argument('--accuracy', action='store_true', help='To print accuracy score')
    parser.add_argument('--plot_confusion_matrix', action='store_true')
    parser.add_argument('--execution_time', action='store_true')
    parser.add_argument('--store_activations', action='store_true')
    parser.add_argument('--novelty_detection', action='store_true')
    parser.add_argument('--model', type=str, required=True, help='Base model architecture',
                        choices=[config.MODEL_RESNET50, config.MODEL_RESNET152,
                                 config.MODEL_INCEPTION_V3, config.MODEL_VGG16])
    parser.add_argument('--data_dir', help='Path to data train directory')
    parser.add_argument('--batch_size', default=500, type=int, help='How many files to predict on at once')
    args = parser.parse_args()
    return args
Example 4: save_model
# Required import: import config [as alias]
# Or: from config import model [as alias]
def save_model():
    global model, optimizer, model_config, sess_path
    print('Saving to', sess_path)
    torch.save({'model_config': model_config,
                'model_state': model.state_dict(),
                'model_optimizer_state': optimizer.state_dict()}, sess_path)
    print('Done saving')

# ========================================================================
# Training
# ========================================================================
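Examples 1 and 4 come from the same training script and pair into a resume-and-checkpoint pattern. A minimal sketch of a driver that ties them together, assuming the module-level globals both functions expect (sess_path, model_config, device, learning_rate, reset_optimizer) are set first; the hyper-parameter values and the epoch loop are placeholders, not part of the original code:

import torch

sess_path = 'save/train.sess'              # where the session checkpoint lives
model_config = {'hidden_dim': 512}         # placeholder hyper-parameters for PerformanceRNN
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 0.001
reset_optimizer = False

model, optimizer = load_session()          # resumes from sess_path if a session exists
for epoch in range(10):                    # placeholder outer loop
    ...                                    # forward/backward passes go here
    save_model()                           # checkpoint after every epoch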
Example 5: get_model
# Required import: import config [as alias]
# Or: from config import model [as alias]
def get_model():
    model = resnet_50()
    if config.model == "resnet18":
        model = resnet_18()
    if config.model == "resnet34":
        model = resnet_34()
    if config.model == "resnet101":
        model = resnet_101()
    if config.model == "resnet152":
        model = resnet_152()
    model.build(input_shape=(None, config.image_height, config.image_width, config.channels))
    model.summary()
    return model
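The chain of if statements above silently falls back to resnet_50 for any name it does not recognize. If that fallback is intended, a dictionary dispatch expresses the same mapping from config.model to a constructor more compactly; a sketch using the same constructors:

def get_model():
    # Map architecture names to constructors; unknown names fall back to resnet_50.
    constructors = {
        "resnet18": resnet_18,
        "resnet34": resnet_34,
        "resnet50": resnet_50,
        "resnet101": resnet_101,
        "resnet152": resnet_152,
    }
    model = constructors.get(config.model, resnet_50)()
    model.build(input_shape=(None, config.image_height, config.image_width, config.channels))
    model.summary()
    return model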
Example 6: train_step
# Required import: import config [as alias]
# Or: from config import model [as alias]
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(y_true=labels, y_pred=predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, predictions)
Example 7: valid_step
# Required import: import config [as alias]
# Or: from config import model [as alias]
def valid_step(images, labels):
    predictions = model(images, training=False)
    v_loss = loss_object(labels, predictions)
    valid_loss(v_loss)
    valid_accuracy(labels, predictions)

# start training
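train_step and valid_step update module-level metric objects (train_loss, train_accuracy, valid_loss, valid_accuracy) rather than returning values, so they are meant to be driven by an outer epoch loop; in TensorFlow 2 such steps are also commonly decorated with @tf.function. A hedged sketch of that loop, where train_dataset, valid_dataset, and EPOCHS are assumptions and the metric objects are the same ones the steps update:

EPOCHS = 10                                   # placeholder epoch count
for epoch in range(EPOCHS):
    # reset the running metrics at the start of every epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    valid_loss.reset_states()
    valid_accuracy.reset_states()

    for images, labels in train_dataset:      # assumed tf.data.Dataset of (image, label) batches
        train_step(images, labels)
    for images, labels in valid_dataset:
        valid_step(images, labels)

    print("epoch {}: loss {:.4f}, acc {:.4f}, val_loss {:.4f}, val_acc {:.4f}".format(
        epoch + 1,
        float(train_loss.result()), float(train_accuracy.result()),
        float(valid_loss.result()), float(valid_accuracy.result())))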
Example 8: save_history
# Required import: import config [as alias]
# Or: from config import model [as alias]
def save_history(history, prefix):
    if 'acc' not in history.history:
        return
    if not os.path.exists(config.plots_dir):
        os.mkdir(config.plots_dir)
    img_path = os.path.join(config.plots_dir, '{}-%s.jpg'.format(prefix))

    # summarize history for accuracy
    plt.plot(history.history['acc'])
    plt.plot(history.history['val_acc'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.savefig(img_path % 'accuracy')
    plt.close()

    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper right')
    plt.savefig(img_path % 'loss')
    plt.close()
Example 9: get_model_class_instance
# Required import: import config [as alias]
# Or: from config import model [as alias]
def get_model_class_instance(*args, **kwargs):
    module = importlib.import_module("models.{}".format(config.model))
    return module.inst_class(*args, **kwargs)
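get_model_class_instance resolves config.model to a module under the models package via importlib, so each architecture lives in models/<name>.py and exposes an inst_class callable. A hypothetical sketch of what one such module could look like; the module and class names are illustrative, and only the inst_class convention comes from the example:

# models/resnet50.py -- hypothetical module resolved when config.model == "resnet50"
class ResNet50Model:
    def __init__(self, *args, **kwargs):
        # keep whatever the caller passed through get_model_class_instance()
        self.args = args
        self.kwargs = kwargs

    def train(self):
        ...  # build and fit the model here

# name looked up by get_model_class_instance(); must be a callable returning the instance
inst_class = ResNet50Model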
Example 10: parse_args
# Required import: import config [as alias]
# Or: from config import model [as alias]
def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='Base model architecture',
                        choices=[config.MODEL_RESNET50,
                                 config.MODEL_RESNET152,
                                 config.MODEL_INCEPTION_V3,
                                 config.MODEL_VGG16])
    parser.add_argument('--use_nn', action='store_true')
    args = parser.parse_args()
    return args
Example 11: parse_args
# Required import: import config [as alias]
# Or: from config import model [as alias]
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='Path to data dir')
    parser.add_argument('--model', type=str, required=True, help='Base model architecture', choices=[
        config.MODEL_RESNET50,
        config.MODEL_RESNET152,
        config.MODEL_INCEPTION_V3,
        config.MODEL_VGG16])
    parser.add_argument('--nb_epoch', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--freeze_layers_number', type=int,
                        help='will freeze the first N layers and unfreeze the rest')
    return parser.parse_args()
Example 12: train
# Required import: import config [as alias]
# Or: from config import model [as alias]
def train():
    model = util.get_model_class_instance(
        class_weight=util.get_class_weight(config.train_dir),
        nb_epoch=args.nb_epoch,
        batch_size=args.batch_size,
        freeze_layers_number=args.freeze_layers_number)
    model.train()
    print('Training is finished!')
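train() reads args at module level, so it expects parse_args() (Example 11) to have run first, and get_model_class_instance() (Example 9) in turn reads config.model. A hypothetical entry point wiring the three together; the config.model assignment is an assumption about how the selected architecture reaches the loader:

if __name__ == '__main__':
    args = parse_args()          # Example 11
    config.model = args.model    # assumed hand-off: the --model choice drives get_model_class_instance()
    train()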
Example 13: parse_args
# Required import: import config [as alias]
# Or: from config import model [as alias]
def parse_args():
    parser = argparse.ArgumentParser('train net')
    parser.add_argument('gpu_id', type=int)
    parser.add_argument('model', type=str)
    parser.add_argument('--restore', dest='restore', type=str)
    # note: type=bool turns any non-empty string into True; action='store_true' is the usual idiom
    parser.add_argument('--debug', dest='debug', type=bool, default=False)
    parser.add_argument('--init_weights', dest='init_weights', type=str,
                        default='ResNet-50-model.caffemodel')
    parser.add_argument('--step', dest='step', type=int, default=int(1e6))
    parser.add_argument('--process', dest='process', type=int, default=3)
    args = parser.parse_args()
    return args
Example 14: handle
# Required import: import config [as alias]
# Or: from config import model [as alias]
def handle(clientsocket):
    while 1:
        buf = clientsocket.recv(config.buffer_size)
        if buf == 'exit'.encode():
            return  # client terminated connection

        response = ''
        if os.path.isfile(buf):
            try:
                img = [model_module.load_img(buf)]
                out = model.predict(np.array(img))
                prediction = np.argmax(out)
                top10 = out[0].argsort()[-10:][::-1]

                class_indices = dict(zip(config.classes, range(len(config.classes))))
                keys = list(class_indices.keys())
                values = list(class_indices.values())
                answer = keys[values.index(prediction)]

                try:
                    acts = util.get_activations(af, img)
                    predicted_relativity = novelty_detection_clf.predict(acts)[0]
                    nd_class = novelty_detection_clf.__classes[predicted_relativity]
                except Exception as e:
                    print(e)  # Exception.message no longer exists in Python 3
                    nd_class = 'related'

                top10_json = "["
                for i, t in enumerate(top10):
                    top10_json += '{"probability":"%s", "class":"%s"}%s' % (
                        out[0][t], keys[values.index(t)], '' if i == 9 else ',')
                top10_json += "]"

                response = '{"probability":"%s","class":"%s","relativity":"%s","top10":%s}' % (
                    out[0][prediction], answer, nd_class, top10_json)
                print(response)
            except Exception as e:
                print('Error', e)
                traceback.print_stack()
                response = UNKNOWN_ERROR
        else:
            response = FILE_DOES_NOT_EXIST

        clientsocket.sendall(response.encode())
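handle() implements one side of a tiny request/response protocol: the client sends an image path, the server answers with a JSON string, and the literal bytes 'exit' end the session. A hedged sketch of the accept loop that would hand connections to it; the host, port, and one-thread-per-client choice are assumptions, not part of the original code:

import socket
import threading

def serve(host='0.0.0.0', port=4444):      # placeholder host and port
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind((host, port))
    server.listen(5)
    print('Listening on {}:{}'.format(host, port))
    while True:
        clientsocket, addr = server.accept()
        # one thread per client; each thread runs the handle() loop above
        threading.Thread(target=handle, args=(clientsocket,), daemon=True).start()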
Example 15: predict
# Required import: import config [as alias]
# Or: from config import model [as alias]
def predict(path):
    files = get_files(path)
    n_files = len(files)
    print('Found {} files'.format(n_files))

    if args.novelty_detection:
        activation_function = util.get_activation_function(model, model_module.noveltyDetectionLayerName)
        novelty_detection_clf = joblib.load(config.get_novelty_detection_model_path())

    y_trues = []
    predictions = np.zeros(shape=(n_files,))
    nb_batch = int(np.ceil(n_files / float(args.batch_size)))
    for n in range(0, nb_batch):
        print('Batch {}'.format(n))
        n_from = n * args.batch_size
        n_to = min(args.batch_size * (n + 1), n_files)

        y_true, inputs = get_inputs_and_trues(files[n_from:n_to])
        y_trues += y_true

        if args.store_activations:
            util.save_activations(model, inputs, files[n_from:n_to], model_module.noveltyDetectionLayerName, n)

        if args.novelty_detection:
            activations = util.get_activations(activation_function, [inputs[0]])
            nd_preds = novelty_detection_clf.predict(activations)[0]
            print(novelty_detection_clf.__classes[nd_preds])

        if not args.store_activations:
            # Warm up the model (time.clock() was removed in Python 3.8; perf_counter is used instead)
            if n == 0:
                print('Warming up the model')
                start = time.perf_counter()
                model.predict(np.array([inputs[0]]))
                end = time.perf_counter()
                print('Warming up took {} s'.format(end - start))

            # Make predictions
            start = time.perf_counter()
            out = model.predict(np.array(inputs))
            end = time.perf_counter()
            predictions[n_from:n_to] = np.argmax(out, axis=1)
            print('Prediction on batch {} took: {}'.format(n, end - start))

    if not args.store_activations:
        for i, p in enumerate(predictions):
            recognized_class = list(classes_in_keras_format.keys())[list(classes_in_keras_format.values()).index(p)]
            print('| should be {} ({}) -> predicted as {} ({})'.format(
                y_trues[i], files[i].split(os.sep)[-2], p, recognized_class))
        if args.accuracy:
            print('Accuracy {}'.format(accuracy_score(y_true=y_trues, y_pred=predictions)))
        if args.plot_confusion_matrix:
            cnf_matrix = confusion_matrix(y_trues, predictions)
            util.plot_confusion_matrix(cnf_matrix, config.classes, normalize=False)
            util.plot_confusion_matrix(cnf_matrix, config.classes, normalize=True)