This page collects typical usage examples of `config.ModelConfig` in Python. If you are unsure what `config.ModelConfig` is for, how to call it, or what real-world usage looks like, the curated code samples below may help. You can also explore further usage examples from the `config` module, where `ModelConfig` is defined.
The following 15 code examples of `config.ModelConfig` are sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
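Most of the examples below follow the same pattern: build a `ModelConfig` from a project name, then read model hyperparameters from its attributes. Here is a minimal sketch of that pattern, assuming the captcha-trainer style constructor used in Examples 1, 5, 6 and 14; the project name "demo_project" is a placeholder, and the attributes shown are only those read by the examples on this page.

from config import ModelConfig

# "demo_project" is a placeholder; the constructor signature may differ between projects.
model_conf = ModelConfig(project_name="demo_project")

print(model_conf.category_num)     # number of output categories (Examples 2, 15)
print(model_conf.max_label_num)    # maximum label length (Example 2)
print(model_conf.loss_func)        # configured loss function (Examples 7, 9, 11, 12)
print(model_conf.model_root_path)  # model storage folder (Example 6)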
Example 1: training_task
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def training_task(self):
    model_conf = ModelConfig(project_name=self.current_project)
    self.current_task = Trains(model_conf)
    try:
        self.button_state(self.btn_training, tk.DISABLED)
        self.button_state(self.btn_stop, tk.NORMAL)
        self.is_task_running = True
        self.current_task.train_process()
        status = 'Training completed'
    except Exception as e:
        traceback.print_exc()
        messagebox.showerror(
            e.__class__.__name__, json.dumps(e.args, ensure_ascii=False)
        )
        status = 'Training failure'
    self.button_state(self.btn_training, tk.NORMAL)
    self.button_state(self.btn_stop, tk.DISABLED)
    self.comb_project_name['state'] = tk.NORMAL
    self.is_task_running = False
    tk.messagebox.showinfo('Training Status', status)
Example 2: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig, outputs):
    self.model_conf = model_conf
    self.max_label_num = self.model_conf.max_label_num
    if self.max_label_num == -1:
        exception(text="The scene must set the maximum number of label (MaxLabelNum)", code=-998)
    self.category_num = self.model_conf.category_num
    flatten = tf.keras.layers.Flatten()(outputs)
    shape_list = flatten.get_shape().as_list()
    # print(shape_list[1], self.max_label_num)
    outputs = tf.keras.layers.Reshape([self.max_label_num, int(shape_list[1] / self.max_label_num)])(flatten)
    self.outputs = tf.keras.layers.Dense(
        input_shape=outputs.shape,
        units=self.category_num,
    )(inputs=outputs)
    print("output to reshape ----------- ", self.outputs.shape)
    self.outputs = tf.keras.layers.Reshape([self.max_label_num, self.category_num])(self.outputs)
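The reshaping in Example 2 splits a flattened feature vector into one slice per label position before classifying each slice: with hypothetical sizes, a flattened width of 1024 and max_label_num = 4 gives a [4, 256] tensor, and the dense layer then maps each slice to category_num logits. A minimal standalone sketch of just that tensor plumbing, with made-up sizes:

import tensorflow as tf

# Hypothetical sizes: 4 label positions, 36 categories, flattened width 1024.
max_label_num, category_num, flat_width = 4, 36, 1024

flatten = tf.keras.layers.Input(shape=(flat_width,))  # stands in for the flattened CNN output
per_label = tf.keras.layers.Reshape(
    [max_label_num, flat_width // max_label_num])(flatten)     # -> (None, 4, 256)
logits = tf.keras.layers.Dense(units=category_num)(per_label)  # -> (None, 4, 36)
print(per_label.shape, logits.shape)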
Example 3: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig, mode: RunMode, ran_captcha=None):
    """
    :param model_conf: project configuration
    :param mode: run mode (training / validation)
    """
    self.model_conf = model_conf
    self.mode = mode
    self.path_map = {
        RunMode.Trains: self.model_conf.trains_path[DatasetType.TFRecords],
        RunMode.Validation: self.model_conf.validation_path[DatasetType.TFRecords]
    }
    self.batch_map = {
        RunMode.Trains: self.model_conf.batch_size,
        RunMode.Validation: self.model_conf.validation_batch_size
    }
    self.data_dir = self.path_map[mode]
    self.next_element = None
    self.image_path = []
    self.label_list = []
    self._label_list = []
    self._size = 0
    self.encoder = Encoder(self.model_conf, self.mode)
    self.ran_captcha = ran_captcha
Example 4: eval
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def eval(ckpt, use_emb=False):
    # Recommended hyperparameters
    args = ModelConfig(batch_size=64, ckpt=ckpt, dropout=0.5,
                       use_emb=use_emb)
    m = DictionaryModel(dict_field.vocab, output_size=att_crit.input_size, embed_input=args.use_emb,
                        dropout_rate=args.dropout)
    m.load_state_dict(torch.load(args.ckpt)['m_state_dict'])
    if torch.cuda.is_available():
        m.cuda()
        att_crit.cuda()
        train_data.atts_matrix.cuda()
        val_data.atts_matrix.cuda()
        test_data.atts_matrix.cuda()
    m.eval()
    # Don't take the mean until the end
    preds_ = []
    labels_ = []
    for val_b, (atts, words, defns, perm) in enumerate(tqdm(test_iter)):
        preds_.append(att_crit.predict(m(defns, words))[perm])
        labels_.append(atts.data.cpu().numpy()[perm])
    preds = np.concatenate(preds_, 0)
    labels = np.concatenate(labels_, 0)
    acc_table = evaluate_accuracy(preds, labels).T.squeeze()
    return acc_table, preds
Example 5: make_dataset
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def make_dataset(self):
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
        return
    if self.is_task_running:
        messagebox.showerror(
            "Error!", "Please terminate the current training first or wait for the training to end."
        )
        return
    self.save_conf()
    self.button_state(self.btn_make_dataset, tk.DISABLED)
    model_conf = ModelConfig(self.current_project)
    train_path = self.dataset_value(DatasetType.Directory, RunMode.Trains)
    validation_path = self.dataset_value(DatasetType.Directory, RunMode.Validation)
    if len(train_path) < 1:
        messagebox.showerror(
            "Error!", "{} Sample set has not been added.".format(RunMode.Trains.value)
        )
        self.button_state(self.btn_make_dataset, tk.NORMAL)
        return
    self.threading_exec(
        lambda: DataSets(model_conf).make_dataset(
            trains_path=train_path,
            validation_path=validation_path,
            is_add=False,
            callback=lambda: self.button_state(self.btn_make_dataset, tk.NORMAL),
            msg=lambda x: tk.messagebox.showinfo('Make Dataset Status', x)
        )
    )
Example 6: compile_task
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def compile_task(self):
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
        return
    model_conf = ModelConfig(project_name=self.current_project)
    if not os.path.exists(model_conf.model_root_path):
        messagebox.showerror(
            "Error", "Model storage folder does not exist."
        )
        return
    if len(os.listdir(model_conf.model_root_path)) < 3:
        messagebox.showerror(
            "Error", "There is no training model record, please train before compiling."
        )
        return
    try:
        if not self.current_task:
            self.current_task = Trains(model_conf)
        self.current_task.compile_graph(0)
        status = 'Compile completed'
    except Exception as e:
        messagebox.showerror(
            e.__class__.__name__, json.dumps(e.args, ensure_ascii=False)
        )
        status = 'Compile failure'
    tk.messagebox.showinfo('Compile Status', status)
Example 7: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils):
    """
    :param model_conf: model configuration loaded from the config file
    :param inputs: output of the previous network layer (tf.keras.layers.Input / tf.Tensor)
    :param utils: network utility helper
    """
    self.model_conf = model_conf
    self.inputs = inputs
    self.utils = utils
    self.loss_func = self.model_conf.loss_func
Example 8: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils):
    """
    :param model_conf: model configuration
    :param inputs: output of the previous network layer (tf.keras.layers.Input / tf.Tensor)
    :param utils: network utility helper
    """
    self.model_conf = model_conf
    self.inputs = inputs
    self.utils = utils
    self.layer = None
Example 9: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils):
    self.model_conf = model_conf
    self.inputs = inputs
    self.utils = utils
    self.loss_func = self.model_conf.loss_func
    # Dense-block repetition counts keyed by network depth (DenseNet-121/169/201)
    self.type = {
        '121': [6, 12, 24, 16],
        '169': [6, 12, 32, 32],
        '201': [6, 12, 48, 32]
    }
    self.blocks = self.type['121']
    self.padding = "SAME"
Example 10: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils):
    self.model_conf = model_conf
    self.inputs = inputs
    self.utils = utils
    self.layer = None
Example 11: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils):
    self.model_conf = model_conf
    self.inputs = inputs
    self.utils = utils
    self.loss_func = self.model_conf.loss_func
Example 12: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils):
    self.model_conf = model_conf
    self.inputs = inputs
    self.utils = utils
    self.loss_func = self.model_conf.loss_func
    self.last_block_filters = 1280
    self.padding = "SAME"
Example 13: parse_model
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def parse_model(source_bytes: bytes, key=None):
    split_tag = b'-#||#-'
    if not key:
        key = [b"_____" + i.encode("utf8") + b"_____" for i in "&coriander"]
    if isinstance(key, str):
        key = [b"_____" + i.encode("utf8") + b"_____" for i in key]
    key_len_int = len(key)
    model_bytes_list = []
    graph_bytes_list = []
    slice_index = source_bytes.index(key[0])
    split_tag_len = len(split_tag)
    slice_0 = source_bytes[0: slice_index].split(split_tag)
    model_slice_len = len(slice_0[1])
    graph_slice_len = len(slice_0[0])
    slice_len = split_tag_len + model_slice_len + graph_slice_len
    for i in range(key_len_int-1):
        slice_index = source_bytes.index(key[i])
        print(slice_index, slice_index - slice_len)
        slices = source_bytes[slice_index - slice_len: slice_index].split(split_tag)
        model_bytes_list.append(slices[1])
        graph_bytes_list.append(slices[0])
    slices = source_bytes.split(key[-2])[1][:-len(key[-1])].split(split_tag)
    model_bytes_list.append(slices[1])
    graph_bytes_list.append(slices[0])
    model_bytes = b"".join(model_bytes_list)
    model_conf: ModelConfig = pickle.loads(model_bytes)
    graph_bytes: bytes = b"".join(graph_bytes_list)
    return model_conf, graph_bytes
Example 14: output_model
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def output_model(project_name: str, model_type: ModelType, key=None):
    model_conf = ModelConfig(project_name, is_dev=False)
    graph_parent_path = model_conf.compile_model_path
    model_suffix = COMPILE_MODEL_MAP[model_type]
    model_bytes = pickle.dumps(model_conf.conf)
    graph_path = os.path.join(graph_parent_path, "{}{}".format(model_conf.model_name, model_suffix))
    with open(graph_path, "rb") as f:
        graph_bytes = f.read()
    output_path = graph_path.replace(".pb", ".pl").replace(".onnx", ".pl").replace(".tflite", ".pl")
    concat_model(output_path, model_bytes, graph_bytes, key)
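Examples 13 and 14 are two halves of one round trip: output_model pickles the configuration and packs it together with the compiled graph into a .pl file (via concat_model), while parse_model splits such a package back into the configuration and the raw graph bytes. A minimal sketch of the read side, assuming a package produced by output_model with the default key; the path below is a placeholder.

# "out/demo_project.pl" is a hypothetical path to a package written by output_model().
with open("out/demo_project.pl", "rb") as f:
    packed = f.read()

model_conf, graph_bytes = parse_model(packed)
print(type(model_conf))   # unpickled model configuration
print(len(graph_bytes))   # size of the embedded compiled graph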
Example 15: __init__
# Required import: import config [as alias]
# Or: from config import ModelConfig [as alias]
def __init__(self, model_conf: ModelConfig):
    self.model_conf = model_conf
    self.category_num = self.model_conf.category_num