This article collects typical usage examples of the ld_mnist method from the Python module differential_privacy.multiple_teachers.input. If you are wondering how to use input.ld_mnist, or what calling it looks like in practice, the selected code below may help. You can also explore further usage examples of the containing module, differential_privacy.multiple_teachers.input.
A representative code example of the input.ld_mnist method is shown below.
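Before the full example, the call below illustrates the basic shape of the API. This is a minimal sketch and is not taken from the project itself: it assumes the multiple_teachers package from the tensorflow/models repository is importable and that ld_mnist returns four NumPy arrays (train images, train labels, test images, test labels), as the example below uses it; the printed shapes are only indicative.

from differential_privacy.multiple_teachers import input  # note: shadows the built-in input()

# Load MNIST as NumPy arrays; the download/cache location is normally
# controlled by the module's data_dir flag.
train_data, train_labels, test_data, test_labels = input.ld_mnist()

print(train_data.shape, train_labels.shape)  # roughly (60000, 28, 28, 1) (60000,)
print(test_data.shape, test_labels.shape)    # roughly (10000, 28, 28, 1) (10000,)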
Example 1: train_teacher
# Required import: from differential_privacy.multiple_teachers import input [as alias]
# or: from differential_privacy.multiple_teachers.input import ld_mnist [as alias]
def train_teacher(dataset, nb_teachers, teacher_id):
  """
  This function trains a teacher (teacher_id) among an ensemble of nb_teachers
  models for the dataset specified.
  :param dataset: string corresponding to dataset (svhn, cifar10, mnist)
  :param nb_teachers: total number of teachers in the ensemble
  :param teacher_id: id of the teacher being trained
  :return: True if everything went well
  """
  # If working directories do not exist, create them
  assert input.create_dir_if_needed(FLAGS.data_dir)
  assert input.create_dir_if_needed(FLAGS.train_dir)

  # Load the dataset
  if dataset == 'svhn':
    train_data, train_labels, test_data, test_labels = input.ld_svhn(extended=True)
  elif dataset == 'cifar10':
    train_data, train_labels, test_data, test_labels = input.ld_cifar10()
  elif dataset == 'mnist':
    train_data, train_labels, test_data, test_labels = input.ld_mnist()
  else:
    print("Check value of dataset flag")
    return False

  # Retrieve subset of data for this teacher
  data, labels = input.partition_dataset(train_data,
                                         train_labels,
                                         nb_teachers,
                                         teacher_id)

  print("Length of training data: " + str(len(labels)))

  # Define teacher checkpoint filename and full path
  if FLAGS.deeper:
    filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '_deep.ckpt'
  else:
    filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '.ckpt'
  ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + filename

  # Perform teacher training
  assert deep_cnn.train(data, labels, ckpt_path)

  # Append final step value to checkpoint for evaluation
  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

  # Retrieve teacher probability estimates on the test data
  teacher_preds = deep_cnn.softmax_preds(test_data, ckpt_path_final)

  # Compute teacher accuracy
  precision = metrics.accuracy(teacher_preds, test_labels)
  print('Precision of teacher after training: ' + str(precision))

  return True
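In the source project this function is typically driven by command-line flags and a small main routine. The sketch below is a hedged reconstruction, not the original train_teachers.py: the flag names mirror the ones referenced in the snippet above (dataset, nb_teachers, teacher_id, data_dir, train_dir, max_steps, deeper), but their exact definitions and defaults here are assumptions, and TF1-style tf.app.flags / tf.app.run is assumed to be available.

import tensorflow as tf

from differential_privacy.multiple_teachers import deep_cnn
from differential_privacy.multiple_teachers import input
from differential_privacy.multiple_teachers import metrics

# Flag definitions are illustrative; defaults are assumptions, not the project's values.
tf.app.flags.DEFINE_string('dataset', 'mnist', 'svhn, cifar10 or mnist')
tf.app.flags.DEFINE_integer('nb_teachers', 50, 'Number of teachers in the ensemble')
tf.app.flags.DEFINE_integer('teacher_id', 0, 'ID of the teacher being trained')
tf.app.flags.DEFINE_string('data_dir', '/tmp/data', 'Where datasets are cached')
tf.app.flags.DEFINE_string('train_dir', '/tmp/train_dir', 'Where checkpoints are written')
tf.app.flags.DEFINE_integer('max_steps', 3000, 'Number of training steps')
tf.app.flags.DEFINE_boolean('deeper', False, 'Use the deeper CNN variant')

FLAGS = tf.app.flags.FLAGS

def main(argv=None):  # pylint: disable=unused-argument
  # Train the requested teacher; train_teacher is the function shown above.
  assert train_teacher(FLAGS.dataset, FLAGS.nb_teachers, FLAGS.teacher_id)

if __name__ == '__main__':
  tf.app.run()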