本文整理汇总了Python中model.model.Model类的典型用法代码示例。如果您正苦于以下问题：Python model.Model类的具体用法？Python model.Model怎么用？Python model.Model使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块model.model的用法示例。
在下文中一共展示了model.Model类的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_save_load_network
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def test_save_load_network(self):
    """Round-trip the network weights through a ``.pt`` checkpoint.

    Saves ``self.model``'s network with ``save_network``. Points
    ``hp.load.network_chkpt_path`` at the resulting ``<name>_<step>.pt`` file
    inside ``hp.log.chkpt_dir``. Loads that file into a freshly built
    ``Model`` and checks that every parameter tensor matches the original in
    count, shape and value.
    """
    local_net = Net_arch(self.hp)
    self.loss_f = nn.MSELoss()
    local_model = Model(self.hp, local_net, self.loss_f)
    self.model.save_network(self.logger)
    save_filename = "%s_%d.pt" % (self.hp.log.name, self.model.step)
    save_path = os.path.join(self.hp.log.chkpt_dir, save_filename)
    # NOTE(review): load_network() below receives only the logger, so it
    # presumably reads the checkpoint location from this hp field.
    self.hp.load.network_chkpt_path = save_path
    assert os.path.exists(save_path) and os.path.isfile(save_path)
    assert os.path.exists(self.hp.log.log_file_path) and os.path.isfile(
        self.hp.log.log_file_path
    )
    local_model.load_network(logger=self.logger)
    loaded_params = list(local_model.net.parameters())
    original_params = list(self.model.net.parameters())
    # zip() stops at the shorter sequence. Compare the counts first so that a
    # missing or extra parameter fails the test instead of being skipped.
    assert len(loaded_params) == len(original_params)
    for load, origin in zip(loaded_params, original_params):
        # Element-wise == broadcasts, so check shapes explicitly as well.
        assert load.shape == origin.shape
        assert (load == origin).all()
示例2: test_save_load_state
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def test_save_load_state(self):
    """Round-trip the full training state through a ``.state`` file.

    Saves ``self.model``'s training state with ``save_training_state``.
    Points ``hp.load.resume_state_path`` at the resulting
    ``<name>_<step>.state`` file inside ``hp.log.chkpt_dir``. Restores it into
    a freshly built ``Model`` and checks that the parameters (count, shape and
    value), ``epoch`` and ``step`` all match the original.
    """
    local_net = Net_arch(self.hp)
    self.loss_f = nn.MSELoss()
    local_model = Model(self.hp, local_net, self.loss_f)
    self.model.save_training_state(self.logger)
    save_filename = "%s_%d.state" % (self.hp.log.name, self.model.step)
    save_path = os.path.join(self.hp.log.chkpt_dir, save_filename)
    # NOTE(review): load_training_state() below receives only the logger, so
    # it presumably reads the state file location from this hp field.
    self.hp.load.resume_state_path = save_path
    assert os.path.exists(save_path) and os.path.isfile(save_path)
    assert os.path.exists(self.hp.log.log_file_path) and os.path.isfile(
        self.hp.log.log_file_path
    )
    local_model.load_training_state(logger=self.logger)
    loaded_params = list(local_model.net.parameters())
    original_params = list(self.model.net.parameters())
    # zip() stops at the shorter sequence. Compare the counts first so that a
    # missing or extra parameter fails the test instead of being skipped.
    assert len(loaded_params) == len(original_params)
    for load, origin in zip(loaded_params, original_params):
        # Element-wise == broadcasts, so check shapes explicitly as well.
        assert load.shape == origin.shape
        assert (load == origin).all()
    assert local_model.epoch == self.model.epoch
    assert local_model.step == self.model.step
示例3: main
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def main(argv=None):
    """Configure, build and train the model on the 'train' dataset split.

    ``argv`` is accepted but not used.
    """
    cfg = Config()
    cfg.DATA_DIR = ['/data/']
    cfg.LOG_DIR = './log/model'
    cfg.MODE = 'training'
    cfg.STEPS_PER_EPOCH_VAL = 180
    cfg.display()

    # Training images and labels.
    train_data = Dataset(cfg, 'train')

    net = Model(cfg)
    net.compile()
    # The second argument (a validation dataset) is deliberately None here.
    net.train(train_data, None)
示例4: main
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def main():
    """Build the Model variant selected by ``--set`` and run ``--phase``.

    'gta' uses ``model.model.Model`` and 'kitti' uses
    ``model.model_cen.Model``. The model is wrapped in ``nn.DataParallel``
    and moved to ``args.device``. Phase 'train' runs training and 'test'
    runs evaluation; any other phase does nothing.

    Raises:
        ValueError: if ``args.set`` is neither 'gta' nor 'kitti'.
    """
    torch.set_num_threads(multiprocessing.cpu_count())
    args = parse_args()

    if args.set not in ('gta', 'kitti'):
        raise ValueError("Model not found")
    # Each dataset has its own Model implementation; import only the one needed.
    if args.set == 'gta':
        from model.model import Model
    else:
        from model.model_cen import Model

    net = Model(args.arch, args.roi_name, args.down_ratio, args.roi_kernel)
    net = nn.DataParallel(net).to(args.device)

    if args.phase == 'train':
        run_training(net, args)
    elif args.phase == 'test':
        test_model(net, args)
示例5: setup_method
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def setup_method(self, method):
    """Give each test a fresh network, cross-entropy loss and Model.

    Runs the parent ``setup_method`` first, because ``self.hp`` is read
    afterwards. ``method`` is not used.
    """
    super(TestModel, self).setup_method()
    net = Net_arch(self.hp)
    criterion = nn.CrossEntropyLoss()
    self.net, self.loss_f = net, criterion
    self.model = Model(self.hp, net, criterion)
示例6: convert_layer_to_tensor
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def convert_layer_to_tensor(layer, dtype=None, name=None, as_ref=False):
    """Return ``layer.output`` when ``layer`` is a ``Layer`` or ``Model``.

    Any other object yields ``NotImplemented``. ``dtype``, ``name`` and
    ``as_ref`` are accepted but ignored.

    NOTE(review): the signature presumably matches TensorFlow's
    tensor-conversion-function protocol, where returning ``NotImplemented``
    lets other registered converters be tried. Confirm where it is
    registered.
    """
    if isinstance(layer, (Layer, Model)):
        return layer.output
    return NotImplemented
示例7: loadModelAndData
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def _load_json_utf8(path):
    """Parse and return the JSON document stored (as UTF-8) at ``path``."""
    # io.open accepts ``encoding`` on both Python 2 and 3. Without it, the
    # platform default codec is used, which breaks on non-UTF-8 locales.
    import io
    with io.open(path, encoding='utf-8') as f:
        return json.load(f)


def _recreate_dir(path):
    """Make ``path`` an empty directory, deleting any existing tree first."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def loadModelAndData(num):
    """Build the Model and load the validation/test dialogue lists.

    Reads the four vocabulary mappings from ``data/``. Constructs ``Model``
    from the module-level ``args`` and those mappings. If
    ``args.load_param`` is set, restores the checkpoint for iteration
    ``num``. Recreates ``args.decode_output`` and ``args.valid_output`` as
    empty directories; any previous contents are deleted.

    Args:
        num: checkpoint iteration passed to ``Model.loadModel(iter=...)``.

    Returns:
        ``(model, val_dials, test_dials)``, where the last two are the parsed
        contents of ``data/val_dials.json`` and ``data/test_dials.json``.
    """
    # Vocabulary dictionaries.
    input_lang_index2word = _load_json_utf8('data/input_lang.index2word.json')
    input_lang_word2index = _load_json_utf8('data/input_lang.word2index.json')
    output_lang_index2word = _load_json_utf8('data/output_lang.index2word.json')
    output_lang_word2index = _load_json_utf8('data/output_lang.word2index.json')

    model = Model(args, input_lang_index2word, output_lang_index2word,
                  input_lang_word2index, output_lang_word2index)
    if args.load_param:
        # Reload an existing checkpoint.
        model.loadModel(iter=num)

    # Start every run with empty output directories.
    _recreate_dir(args.decode_output)
    _recreate_dir(args.valid_output)

    val_dials = _load_json_utf8('data/val_dials.json')
    test_dials = _load_json_utf8('data/test_dials.json')
    return model, val_dials, test_dials
示例8: main
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def main():
    """Predict the text in a single image with a trained checkpoint.

    Command-line options:
        --path: image file to read (required).
        --checkpoint: checkpoint to load (default ``data/model.ckpt``).

    Builds ``Model`` in ``ModeKeys.INFER`` mode and converts the image to RGB
    resized to 100x32. Runs one session and prints
    ``<path>: <predicted text>``. Exits through ``parser.error`` (status 2)
    if ``--path`` does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True,
                        help='path to image file.')
    parser.add_argument('--checkpoint', type=str, default='data/model.ckpt',
                        help='path to checkpoint file.')
    args = parser.parse_args()
    # Use parser.error rather than assert: asserts are stripped under
    # ``python -O``. Checking here also avoids building the graph for a bad path.
    if not os.path.exists(args.path):
        parser.error('%s does not exist!' % args.path)

    params = {
        'checkpoint': args.checkpoint,
        'dataset': {
            'dataset_dir': 'data',
            'charset_filename': 'charset_size=63.txt',
            'max_sequence_length': 30,
        },
        'beam_width': 1,
        'summary': False,
    }
    model = Model(params, ModeKeys.INFER)
    # One uint8 image of 32 rows x 100 columns x 3 channels.
    image = tf.placeholder(tf.uint8, (1, 32, 100, 3), name='image')
    predictions, _, _ = model({'image': image}, None)

    raw_image = Image.open(args.path).convert('RGB')
    # PIL's resize takes (width, height), so this yields a 32x100 array below.
    raw_image = raw_image.resize((100, 32), Image.BILINEAR)
    # Prepend the batch dimension expected by the placeholder.
    raw_image = np.array(raw_image)[None, :]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        outputs = sess.run(predictions, feed_dict={image: raw_image})
    text = outputs['predicted_text'][0]
    print('%s: %s' % (args.path, text))
示例9: __init__
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def __init__(self, classifier_data):
    """Initialise the parent with the port/bufsize and load the classifier.

    Opens a TensorFlow session (kept on ``self.sess``). Creates ``Model`` and
    initialises it from ``classifier_data.graph_path`` and
    ``classifier_data.checkpoint_path`` in that session. Stores the result of
    ``getLib()`` on ``self.lib``.
    """
    super().__init__(classifier_data.port, classifier_data.bufsize)
    self.sess = tf.Session()
    self.nn = Model()
    self.nn.init(
        classifier_data.graph_path,
        classifier_data.checkpoint_path,
        self.sess,
    )
    self.lib = getLib()
示例10: main
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def main(args, defaults):
    """Set up logging, then build and launch the Model in a TF session.

    The root logger is configured at DEBUG level with output to
    ``parameters.log_path`` (``logging.basicConfig`` does nothing if the root
    logger already has handlers). A console handler additionally echoes
    INFO and above. All Model options come from ``process_args(args,
    defaults)``, except ``valid_target_length``, which is fixed to infinity.
    """
    parameters = process_args(args, defaults)

    # One format string shared by the file and console handlers.
    log_format = '%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s'
    logging.basicConfig(level=logging.DEBUG,
                        format=log_format,
                        filename=parameters.log_path)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(log_format))
    logging.getLogger('').addHandler(console)

    # Soft placement lets ops without a kernel on the requested device run elsewhere.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        model = Model(
            phase=parameters.phase,
            visualize=parameters.visualize,
            data_path=parameters.data_path,
            data_base_dir=parameters.data_base_dir,
            output_dir=parameters.output_dir,
            batch_size=parameters.batch_size,
            initial_learning_rate=parameters.initial_learning_rate,
            num_epoch=parameters.num_epoch,
            steps_per_checkpoint=parameters.steps_per_checkpoint,
            target_vocab_size=parameters.target_vocab_size,
            model_dir=parameters.model_dir,
            target_embedding_size=parameters.target_embedding_size,
            attn_num_hidden=parameters.attn_num_hidden,
            attn_num_layers=parameters.attn_num_layers,
            clip_gradients=parameters.clip_gradients,
            max_gradient_norm=parameters.max_gradient_norm,
            load_model=parameters.load_model,
            valid_target_length=float('inf'),
            gpu_id=parameters.gpu_id,
            use_gru=parameters.use_gru,
            session=sess,
        )
        model.launch()
示例11: decode
# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def decode():
tfrecords_list, num_batches = read_list(FLAGS.lists_dir, FLAGS.data_type, FLAGS.batch_size)
with tf.Graph().as_default():
with tf.device('/cpu:0'):
with tf.name_scope('input'):
cmvn = np.load(FLAGS.inputs_cmvn)
cmvn_aux = np.load(FLAGS.inputs_cmvn.replace('cmvn', 'cmvn_aux'))
if FLAGS.with_labels:
inputs, inputs_cmvn, inputs_cmvn_aux, labels, lengths, lengths_aux = paddedFIFO_batch(tfrecords_list, FLAGS.batch_size,
FLAGS.input_size, FLAGS.output_size, cmvn=cmvn, cmvn_aux=cmvn_aux, with_labels=FLAGS.with_labels,
num_enqueuing_threads=1, num_epochs=1, shuffle=False)
else:
inputs, inputs_cmvn, inputs_cmvn_aux, lengths, lengths_aux = paddedFIFO_batch(tfrecords_list, FLAGS.batch_size,
FLAGS.input_size, FLAGS.output_size, cmvn=cmvn, cmvn_aux=cmvn_aux, with_labels=FLAGS.with_labels,
num_enqueuing_threads=1, num_epochs=1, shuffle=False)
labels = None
with tf.name_scope('model'):
model = Model(FLAGS, inputs, inputs_cmvn, inputs_cmvn_aux, labels, lengths, lengths_aux, infer=True)
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess = tf.Session()
sess.run(init)
checkpoint = tf.train.get_checkpoint_state(FLAGS.save_model_dir)
if checkpoint and checkpoint.model_checkpoint_path:
tf.logging.info("Restore best model from " + checkpoint.model_checkpoint_path)
model.saver.restore(sess, checkpoint.model_checkpoint_path)
else:
tf.logging.fatal("Checkpoint is not found, please check the best model save path is correct.")
sys.exit(-1)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
for batch in xrange(num_batches):
if coord.should_stop():
break
sep, mag_lengths = sess.run([model._sep, model._lengths])
for i in xrange(FLAGS.batch_size):
filename = tfrecords_list[FLAGS.batch_size*batch+i]
(_, name) = os.path.split(filename)
(uttid, _) = os.path.splitext(name)
noisy_file = os.path.join(FLAGS.noisy_dir, uttid + '.wav')
enhan_sig, rate = reconstruct(np.squeeze(sep[i,:mag_lengths[i],:]), noisy_file)
savepath = os.path.join(FLAGS.rec_dir, uttid + '.wav')
wav.write(savepath, rate, enhan_sig)
if (batch+1) % 100 == 0:
tf.logging.info("Number of batch processed: %d." % (batch+1))
except Exception, e:
coord.request_stop(e)
finally: