本文整理匯總了Python中theano.compile.sharedvalue.SharedVariable類的典型用法代碼示例。如果您正苦於以下問題：Python sharedvalue.SharedVariable的具體用法？Python sharedvalue.SharedVariable怎麼用？Python sharedvalue.SharedVariable使用的例子？那麼，這裡精選的代碼示例或許可以為您提供幫助。您也可以進一步了解該類所在模塊theano.compile.sharedvalue的用法示例。
在下文中一共展示了sharedvalue.SharedVariable的11個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: base_variables
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def base_variables(expression):
    """
    A helper to find the base SharedVariables in a given expression.

    The graph is walked iteratively through each node's ``owner.inputs``.
    Every node is processed at most once, so sub-expressions reachable from
    several parents are not re-explored, and deep graphs cannot exhaust the
    Python recursion limit. The walk stops at SharedVariables: their own
    ``owner`` (if any) is not followed.

    Parameters
    ----------
    expression : theano expression
        The computation graph to find the base SharedVariables

    Returns
    -------
    set(SharedVariable)
        The set of unique shared variables
    """
    variables = set()
    # Track identities rather than the nodes themselves so that node-level
    # __eq__/__hash__ overrides cannot merge distinct graph nodes.
    visited = set()
    pending = [expression]
    while pending:
        node = pending.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        if isinstance(node, SharedVariable):
            variables.add(node)
        elif hasattr(node, 'owner') and node.owner is not None:
            pending.extend(node.owner.inputs)
    return variables
示例2: __init__
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def __init__(self, *key, **kwargs):
    """
    Build the updates mapping, then check that every key is a SharedVariable.

    Arguments are forwarded unchanged to the parent constructor of
    ``OrderedUpdates``.

    A warning is emitted when the first positional argument is a plain
    (non-``OrderedDict``) dict with more than one entry, because its
    iteration order may make the resulting updates order non-deterministic.

    Raises
    ------
    TypeError
        If any key stored in the mapping is not a SharedVariable.
    """
    if (len(key) >= 1 and
            isinstance(key[0], dict) and
            len(key[0]) > 1 and
            not isinstance(key[0], OrderedDict)):
        # Warn when using as input a non-ordered dictionary.
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn('Initializing an `OrderedUpdates` from a '
                      'non-ordered dictionary with 2+ elements could '
                      'make your code non-deterministic. You can use '
                      'an OrderedDict that is available at '
                      'theano.compat.OrderedDict for python 2.6+.',
                      stacklevel=2)
    super(OrderedUpdates, self).__init__(*key, **kwargs)
    # Use a distinct name so the `key` varargs tuple is not shadowed.
    for shared_key in self:
        if not isinstance(shared_key, SharedVariable):
            raise TypeError(
                'OrderedUpdates keys must inherit from SharedVariable',
                shared_key)
示例3: __init__
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def __init__(self, nvis, beta, learn_beta=False, bias_from_marginals=None):
    """
    Set up a visible layer of ``nvis`` units with a shared ``beta``.

    Parameters
    ----------
    nvis : int
        Number of visible units (used as the VectorSpace dimension and as
        the length of the zero bias when no marginals are given).
    beta : SharedVariable
        Must be a theano shared variable.
    learn_beta : bool
        Stored as an attribute; not otherwise used in this constructor.
    bias_from_marginals : optional
        If given, passed to ``init_tanh_bias_from_marginals`` to build the
        initial bias. It is not kept as an attribute.

    Raises
    ------
    ValueError
        If ``beta`` is not a SharedVariable.
    """
    if not isinstance(beta, SharedVariable):
        raise ValueError("beta needs to be a theano shared variable.")
    # Copy every constructor argument onto the instance. This must stay the
    # first local-binding statement: any local bound before it would also
    # be captured by locals() and become an attribute.
    self.__dict__.update(locals())
    del self.self
    # Don't serialize the dataset
    del self.bias_from_marginals
    self.space = VectorSpace(nvis)
    self.input_space = self.space
    if bias_from_marginals is None:
        init_bias = np.zeros((nvis,))
    else:
        init_bias = init_tanh_bias_from_marginals(bias_from_marginals)
    self.bias = sharedX(init_bias, 'visible_bias')
示例4: prepare_updates_dict
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def prepare_updates_dict(updates):
    """
    Prepare a Theano `updates` dictionary.

    Ensure that both keys and values are valid entries.
    NB, this function is heavily coupled with its clients, and not intended for
    general use.

    Each key is normalised to a SharedVariable:

    * a SharedVariable key is kept as is;
    * ``Update(shared)`` -> ``shared`` (first owner input is shared);
    * ``Update(HostFromGpu(shared))`` -> ``shared``;
    * an ``IfElse`` key is resolved by recursing into its second owner input
      (assumed to be the 'true' branch).

    Values are passed through untouched.

    Raises
    ------
    ValueError
        If a key cannot be resolved to a SharedVariable.
    """
    def prepare_key(key, val):
        if isinstance(key, SharedVariable):
            return key
        owner = getattr(key, 'owner', None)
        if owner is not None:
            first_input = owner.inputs[0]
            if isinstance(first_input, SharedVariable):
                # Extract shared from Update(shared)
                return first_input
            inner_owner = first_input.owner
            # Guard against leaf inputs (owner is None), which previously
            # raised AttributeError instead of reaching the checks below.
            if (inner_owner is not None and
                    inner_owner.op.__class__ is HostFromGpu):
                if isinstance(inner_owner.inputs[0], SharedVariable):
                    # Extract shared from Update(HostFromGpu(shared))
                    return inner_owner.inputs[0]
            elif owner.op.__class__ is ifelse.IfElse:
                # Assume that 'true' condition of ifelse involves the intended
                # shared variable.
                return prepare_key(owner.inputs[1], val)
        raise ValueError("Invalid updates dict key/value: %s / %s"
                         % (key, val))

    # .items() works on both Python 2 and 3 (iteritems is Python-2-only).
    return {prepare_key(key, val): val for key, val in updates.items()}
示例5: _pfunc_param_to_in
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
    """
    Normalise one entry of a function's parameter list to an ``In``.

    Constants are rejected. Any other Variable (SharedVariable included) is
    wrapped in a new ``In`` carrying ``strict`` and ``allow_downcast``.
    An existing ``In`` is returned unchanged.

    Raises
    ------
    TypeError
        For a Constant, or for any object that is neither a Variable nor
        an ``In``.
    """
    is_constant = isinstance(param, Constant)
    if is_constant:
        raise TypeError('Constants not allowed in param list', param)

    if not isinstance(param, Variable):
        if isinstance(param, In):
            return param
        raise TypeError('Unknown parameter type: %s' % type(param))

    # Plain and shared variables alike get wrapped here.
    wrapped = In(variable=param, strict=strict, allow_downcast=allow_downcast)
    return wrapped
示例6: test_create_numpy_strict_false
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def test_create_numpy_strict_false(self):
    """
    With ``strict=False``, SharedVariable creation accepts values that are
    already float64 arrays or castable to one, and rejects (TypeError) a
    value that cannot be interpreted as an array at all.
    """
    # here the value is perfect, and we're not strict about it,
    # so creation should work
    SharedVariable(
        name='u',
        type=Tensor(broadcastable=[False], dtype='float64'),
        value=numpy.asarray([1., 2.]),
        strict=False)

    # here the value is castable, and we're not strict about it,
    # so creation should work
    SharedVariable(
        name='u',
        type=Tensor(broadcastable=[False], dtype='float64'),
        value=[1., 2.],
        strict=False)

    # here the value is castable, and we're not strict about it,
    # so creation should work
    SharedVariable(
        name='u',
        type=Tensor(broadcastable=[False], dtype='float64'),
        value=[1, 2],  # different dtype and not a numpy array
        strict=False)

    # here the value is not castable, and we're not strict about it,
    # this is beyond strictness, it must fail
    try:
        SharedVariable(
            name='u',
            type=Tensor(broadcastable=[False], dtype='float64'),
            value=dict(),  # not an array by any stretch
            strict=False)
    except TypeError:
        pass
    else:
        # Explicit failure with a message instead of a bare `assert 0`.
        raise AssertionError(
            "SharedVariable accepted a dict value with strict=False")
示例7: test_use_numpy_strict_false
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def test_use_numpy_strict_false(self):
    """
    A non-strict float64 SharedVariable casts assigned values, rejects a
    non-numeric string (ValueError), and with ``borrow=True`` keeps a
    perfectly-typed array by identity.
    """
    # here the value is perfect, and we're not strict about it,
    # so creation should work
    u = SharedVariable(
        name='u',
        type=Tensor(broadcastable=[False], dtype='float64'),
        value=numpy.asarray([1., 2.]),
        strict=False)

    # check that assignments to value are cast properly
    u.set_value([3, 4])
    assert type(u.get_value()) is numpy.ndarray
    assert str(u.get_value(borrow=True).dtype) == 'float64'
    assert numpy.all(u.get_value() == [3, 4])

    # check that assignments of nonsense fail
    try:
        u.set_value('adsf')
    except ValueError:
        pass
    else:
        # Explicit failure with a message instead of a bare `assert 0`.
        raise AssertionError("set_value accepted the string 'adsf'")

    # check that an assignment of a perfect value results in no copying
    uval = theano._asarray([5, 6, 7, 8], dtype='float64')
    u.set_value(uval, borrow=True)
    assert u.get_value(borrow=True) is uval
示例8: __call__
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def __call__(self, obj):
    """
    Record the name of a SharedVariable's storage, then delegate to the
    parent ``__call__``.

    For a named SharedVariable, its name is stored in ``self.ndarray_names``,
    keyed by ``id`` of the object in ``obj.container.storage[0]``. Other
    objects are passed straight to the parent.

    Raises
    ------
    ValueError
        If the shared variable is named ``'pkl'``, or if it is unnamed and
        ``self.allow_unnamed`` is false.
    """
    if isinstance(obj, SharedVariable):
        if obj.name:
            if obj.name == 'pkl':
                # The exception must actually be raised; merely constructing
                # it (as before) silently let the reserved name through.
                raise ValueError("can't pickle shared variable with name `pkl`")
            self.ndarray_names[id(obj.container.storage[0])] = obj.name
        elif not self.allow_unnamed:
            raise ValueError("unnamed shared variable, {0}".format(obj))
    return super(PersistentSharedVariableID, self).__call__(obj)
示例9: __setitem__
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def __setitem__(self, key, value):
    """
    Store ``value`` under ``key``, accepting only SharedVariable keys.

    Returns whatever the parent ``__setitem__`` returns.

    Raises
    ------
    TypeError
        If ``key`` is not a SharedVariable.
    """
    if not isinstance(key, SharedVariable):
        raise TypeError('OrderedUpdates keys must inherit from '
                        'SharedVariable', key)

    # TODO: consider doing error-checking on value.
    # insist that it is a Theano variable? Have the right type?
    # This could have weird consequences - for example a
    # GPU SharedVariable is customarily associated with a TensorType
    # value. Should it be cast to a GPU value right away? Should
    # literals be transformed into constants immediately?
    parent = super(OrderedUpdates, self)
    return parent.__setitem__(key, value)
示例10: processTogglePlayQueue
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def processTogglePlayQueue(self):
    """
    Apply a batch of parameter updates to ``self.gccNMFProcessor``.

    A mapping of parameter name -> value is taken from
    ``self.togglePlayQueue.get()``. For each entry:

    * if the processor has no attribute of that name, it is created;
    * if the attribute is a theano SharedVariable, it is updated through
      ``set_value`` when its current value differs;
    * otherwise the attribute is reassigned when it differs.

    Each decision is logged at INFO level. If any name in the batch is one
    of the reset-requiring parameters, ``reset()`` is called on the
    processor once, after all updates have been applied.
    """
    # Local import kept from the original (presumably to avoid importing
    # theano at module load time -- confirm).
    from theano.compile.sharedvalue import SharedVariable

    parameters = self.togglePlayQueue.get()
    # Names whose presence in the batch triggers a reset() at the end.
    parametersRequiringReset = {
        'microphoneSeparationInMetres', 'numTDOAs', 'numSources', 'targetMode',
        'dictionarySize', 'dictionaryType', 'gccPHATNLEnabled',
    }
    processor = self.gccNMFProcessor
    resetGCCNMFProcessor = False
    for parameterName, parameterValue in parameters.items():
        if not hasattr(processor, parameterName):
            logging.info('GCCNMFProcessor: setting %s: %s',
                         parameterName, parameterValue)
            setattr(processor, parameterName, parameterValue)
        else:
            currentParam = getattr(processor, parameterName)
            if isinstance(currentParam, SharedVariable):
                # NOTE(review): if the shared value were a multi-element
                # array, `!=` would yield an array and this `if` would raise;
                # presumably only scalars are stored here -- confirm.
                if currentParam.get_value() != parameterValue:
                    logging.info('GCCNMFProcessor: setting %s: %s (shared)',
                                 parameterName, parameterValue)
                    currentParam.set_value(parameterValue)
                else:
                    logging.info('GCCNMFProcessor: %s unchanged: %s (shared)',
                                 parameterName, parameterValue)
            elif currentParam != parameterValue:
                logging.info('GCCNMFProcessor: setting %s: %s',
                             parameterName, parameterValue)
                setattr(processor, parameterName, parameterValue)
            else:
                logging.info('GCCNMFProcessor: %s unchanged: %s',
                             parameterName, parameterValue)
        # NOTE(review): a reset is requested even when a reset-requiring
        # parameter arrives unchanged; confirm this is intended.
        resetGCCNMFProcessor |= parameterName in parametersRequiringReset
    if resetGCCNMFProcessor:
        processor.reset()
示例11: _get_test_value
# 需要導入模塊: from theano.compile import sharedvalue [as 別名]
# 或者: from theano.compile.sharedvalue import SharedVariable [as 別名]
def _get_test_value(cls, v):
    """
    Return the test value attached to variable ``v``.

    * Constant: ``v.value``.
    * SharedVariable: the internal value, fetched with ``borrow=True`` and
      ``return_internal_type=True``.
    * Other Variable carrying ``v.tag.test_value``: that value passed through
      ``v.type.filter``. If filtering fails, the exception's args are
      collapsed into one message that also contains any creation backtraces
      found in ``v.tag.trace``, and the exception is re-raised.

    Raises
    ------
    AttributeError
        If ``v`` has no test value.
    """
    # Imported here instead of at module level to avoid a circular import.
    from theano.compile.sharedvalue import SharedVariable

    if isinstance(v, graph.Constant):
        return v.value
    if isinstance(v, SharedVariable):
        return v.get_value(borrow=True, return_internal_type=True)
    if not (isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value')):
        raise AttributeError('%s has no test value' % v)

    try:
        # Make sure the stored test value is valid for the variable's type.
        filtered = v.type.filter(v.tag.test_value)
    except Exception as err:
        message = ("For compute_test_value, one input test value does not"
                   " have the requested type.\n")
        creation_traces = getattr(v.tag, 'trace', [])
        if isinstance(creation_traces, list) and len(creation_traces) > 0:
            message += " \nBacktrace when that variable is created:\n"
            # Every recorded backtrace is printed into one buffer, which is
            # then appended once.
            buf = StringIO()
            for one_trace in creation_traces:
                traceback.print_list(one_trace, buf)
            message += str(buf.getvalue())
        message += ("\nThe error when converting the test value to that"
                    " variable type:")
        # A single string arg keeps the embedded newlines readable; a tuple
        # of args would be printed as a tuple repr.
        pieces = (message,) + tuple(str(piece) for piece in err.args)
        err.args = ("\n".join(pieces),)
        raise
    return filtered