本文整理汇总了Python中tensorflow.python.ops.linalg_ops.norm函数的典型用法代码示例。如果您正苦于以下问题:Python norm函数的具体用法?Python norm怎么用?Python norm使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了norm函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testShapesValues
def testShapesValues(self):
    """Check output shape and norm scaling of delta-orthogonal conv kernels.

    For 1-D, 2-D and 3-D kernel sizes, a random-normal input is convolved
    with a kernel created by `init_ops.convolutional_delta_orthogonal`.
    The test asserts the output shape and that the ratio of the output
    norm to the input norm is close to `sqrt(gain)`.
    """
    gain = 3.14
    tol = 1e-2
    for dtype in [dtypes.float32]:
        for kernel_size in [[3], [8], [3, 5], [2, 4], [3, 3, 3], [2, 2, 2]]:
            # The number of kernel dimensions selects both the input layout
            # and the convolution op.
            rank = len(kernel_size)
            if rank == 1:
                shape, convolution = [4, 32, 64], convolutional.conv1d
            elif rank == 2:
                shape, convolution = [4, 32, 32, 64], convolutional.conv2d
            else:
                shape, convolution = [4, 16, 16, 16, 64], convolutional.conv3d

            inputs = random_ops.random_normal(shape, dtype=dtype)
            inputs_2norm = linalg_ops.norm(inputs)
            outputs = convolution(
                inputs,
                filters=128,
                kernel_size=kernel_size,
                padding="same",
                use_bias=False,
                kernel_initializer=init_ops.convolutional_delta_orthogonal(
                    gain=gain))
            # Same leading dimensions as the input, with 128 output filters.
            expected_shape = shape[:-1] + [128]
            # Orthogonality is checked through the ratio between the
            # 2-norms of the outputs and the inputs.
            ratio = linalg_ops.norm(outputs) / inputs_2norm
            init_op = variables.global_variables_initializer()
            with self.test_session(use_gpu=True) as sess:
                sess.run(init_op)
                self.assertAllEqual(outputs.eval().shape, expected_shape)
                # Isometry of the delta-orthogonal kernel.
                self.assertAllClose(
                    sess.run(ratio), np.sqrt(gain), rtol=tol, atol=tol)
示例2: testInvalidAxis
def testInvalidAxis(self):
    """Verify that `linalg_ops.norm` raises ValueError for malformed `axis`."""
    matrix = [[0., 1.], [2., 3.]]
    expected_message = ("'axis' must be None, an integer, or a tuple of 2 unique "
                        "integers")
    # Empty, too long, nested, non-integer and duplicated axis specs.
    bad_axes = ([], [1, 2, 3], [[1]], [[1], [2]], [3.1415], [1, 1])
    for bad_axis in bad_axes:
        with self.assertRaisesRegexp(ValueError, expected_message):
            linalg_ops.norm(matrix, axis=bad_axis)
示例3: _CompareNorm
def _CompareNorm(self, matrix):
    """Compare `linalg_ops.norm` with `np.linalg.norm` on `matrix`.

    `ord_`, `axis_`, `keep_dims_`, `dtype_` and `use_static_shape_` are not
    defined in this block; they are read from the enclosing scope.

    Args:
      matrix: input handed to both the NumPy and the TensorFlow norm.
    """
    norm_kwargs = dict(ord=ord_, axis=axis_, keepdims=keep_dims_)
    expected = np.linalg.norm(matrix, **norm_kwargs)
    with self.cached_session(use_gpu=True) as sess:
        if not use_static_shape_:
            # The placeholder is created without a shape, so the value is
            # supplied only at run time through the feed dict.
            fed_matrix = array_ops.placeholder(dtype_)
            norm_op = linalg_ops.norm(fed_matrix, **norm_kwargs)
            actual = sess.run(norm_op, feed_dict={fed_matrix: matrix})
        else:
            const_matrix = constant_op.constant(matrix)
            norm_op = linalg_ops.norm(const_matrix, **norm_kwargs)
            actual = self.evaluate(norm_op)
    self.assertAllClose(expected, actual, rtol=1e-5, atol=1e-5)
示例4: compute_lr
def compute_lr(self, grad, var):
    """Return the learning rate for `var`, scaled by a norm-based trust ratio.

    Variables whose name contains any entry of `self._skip_list` keep the
    plain `self._learning_rate`.  For all others the rate is multiplied by

        eeta * |var| / (|grad| + weight_decay * |var| + epsilon)

    where the ratio falls back to 1.0 whenever either norm is not positive.

    Args:
      grad: gradient whose norm is the denominator's first term.
      var: variable being updated; its `name` is matched against the skip
        list.

    Returns:
      `self._learning_rate`, or that rate times the trust ratio.
    """
    skip_list = self._skip_list
    if skip_list is not None and any(v in var.name for v in skip_list):
        return self._learning_rate

    w_norm = linalg_ops.norm(var, ord=2)
    g_norm = linalg_ops.norm(grad, ord=2)
    w_is_positive = math_ops.greater(w_norm, 0)
    g_is_positive = math_ops.greater(g_norm, 0)
    raw_ratio = (self._eeta * w_norm /
                 (g_norm + self._weight_decay * w_norm + self._epsilon))
    # A zero norm on either side would make the ratio meaningless, so it is
    # replaced by 1.0 in those cases.
    ratio_if_w_positive = array_ops.where(g_is_positive, raw_ratio, 1.0)
    trust_ratio = array_ops.where(w_is_positive, ratio_if_w_positive, 1.0)
    return self._learning_rate * trust_ratio
示例5: testTransform
def testTransform(self):
    """Check `_embedding_lookup_and_transform` with `max_norm` and a transform.

    Covers ids of rank 0, 1, 2 and 3 against unsharded params and against
    params split into 1, 2 and 3 shards.  `max_norm` is always applied, and
    the transform takes the norm of each looked-up embedding, so results
    are compared with the norms of the clipped parameters.
    """
    np.random.seed(8)
    l2_norm = 2.
    with self.test_session():
        # Param values are in [l2_norm, l2_norm+1) so it will always clip.
        params = np.random.rand(6, 3) + l2_norm
        params_norm = l2_norm * params / np.sqrt(
            np.sum(params * params, axis=1, keepdims=True))
        # Compute the norm of each embedding. This will change the embedding
        # rank to 0.
        params_norm = np.linalg.norm(params_norm, axis=1)
        transform = lambda x: linalg_ops.norm(x, axis=1)
        # The rank-1 case is written as a real tuple; `(3)` is just the int 3
        # (NumPy treats both the same, so the generated ids are unchanged).
        for ids_shape in (), (3,), (4, 3), (2, 3, 4):
            # Test ids rank 0, 1, 2, 3.
            ids = np.random.randint(
                params.shape[0], size=np.prod(ids_shape,
                                              dtype=np.int64)).reshape(ids_shape)
            # Compare nonsharded to gather.
            simple = embedding_ops._embedding_lookup_and_transform(
                params, ids, max_norm=l2_norm, transform_fn=transform).eval()
            self.assertAllClose(
                simple, array_ops.gather(params_norm, ids).eval())
            # Run a few different sharded versions.
            for procs in 1, 2, 3:
                stride = procs * math_ops.range(params.shape[0] // procs)
                split_params = [
                    array_ops.gather(params, stride + p) for p in xrange(procs)
                ]
                sharded = embedding_ops._embedding_lookup_and_transform(
                    split_params, ids, max_norm=l2_norm,
                    transform_fn=transform).eval()
                # The sharded and unsharded paths may use different square
                # root implementations, which are not guaranteed to agree
                # bit-for-bit, so compare with a tolerance rather than
                # exactly.
                self.assertAllClose(simple, sharded)
示例6: body
def body(i, prev_c, prev_h, actions, log_probs):
    """Run one decoding step and record the action chosen at step `i`.

    `self`, `mode`, `y`, `lstm`, `w_lstm`, `forget_bias`, `attn_mem`,
    `attn_w_2`, `attn_v`, `device_softmax`, `device_embeddings` and
    `device_go_embedding` are not defined here; they are read from the
    enclosing scope.

    Args:
      i: step index.  Step 0 feeds the tiled `device_go_embedding`; later
        steps feed the embedding of the action stored at `i - 1`.
      prev_c: state passed to `lstm` as its second argument.
      prev_h: state passed to `lstm` as its third argument.
      actions: object supporting `read(index)` and `write(index, value)`
        (presumably a TensorArray -- confirm against the caller).
      log_probs: running total to which this step's sparse softmax
        cross-entropy is added.

    Returns:
      The tuple `(i + 1, next_c, next_h, actions, log_probs)`.

    Raises:
      NotImplementedError: if `mode` is not "sample", "greedy" or "target".
    """
    # pylint: disable=g-long-lambda
    hparams = self.hparams
    num_children = hparams.num_children
    hidden_size = hparams.hidden_size

    step_input = control_flow_ops.cond(
        math_ops.equal(i, 0),
        lambda: array_ops.tile(device_go_embedding, [num_children, 1]),
        lambda: embedding_ops.embedding_lookup(device_embeddings,
                                               actions.read(i - 1)))
    if hparams.keep_prob is not None:
        step_input = nn_ops.dropout(step_input, hparams.keep_prob)
    next_c, next_h = lstm(step_input, prev_c, prev_h, w_lstm, forget_bias)

    # Attention: score every group against the new hidden state, normalize
    # with a softmax, then take the weighted sum of `attn_mem`.
    scores = math_ops.matmul(next_h, attn_w_2)
    scores = array_ops.reshape(scores, [num_children, 1, hidden_size])
    scores = math_ops.tanh(scores + attn_mem)
    scores = array_ops.reshape(
        scores, [num_children * self.num_groups, hidden_size])
    scores = math_ops.matmul(scores, attn_v)
    scores = array_ops.reshape(scores, [num_children, self.num_groups])
    weights = nn_ops.softmax(scores)
    weights = array_ops.reshape(weights, [num_children, self.num_groups, 1])
    context = math_ops.reduce_sum(attn_mem * weights, axis=1)
    features = array_ops.concat([next_h, context], axis=1)

    logits = math_ops.matmul(features, device_softmax)
    logits /= hparams.temperature
    if hparams.tanh_constant > 0:
        logits = math_ops.tanh(logits) * hparams.tanh_constant
    if hparams.logits_std_noise > 0:
        # Noise scale is proportional to norm(logits) / sqrt(#logits).
        num_in_logits = math_ops.cast(
            array_ops.size(logits), dtype=dtypes.float32)
        avg_norm = math_ops.divide(
            linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
        logits_noise = random_ops.random_normal(
            array_ops.shape(logits),
            stddev=hparams.logits_std_noise * avg_norm)
        noiseless_logits = logits
        # Noise is added only while global_step <= stop_noise_step.
        logits = control_flow_ops.cond(
            self.global_step > hparams.stop_noise_step,
            lambda: noiseless_logits,
            lambda: noiseless_logits + logits_noise)

    if mode == "sample":
        chosen = random_ops.multinomial(logits, 1, seed=hparams.seed)
    elif mode == "greedy":
        chosen = math_ops.argmax(logits, 1)
    elif mode == "target":
        chosen = array_ops.slice(y, [0, i], [-1, 1])
    else:
        raise NotImplementedError
    chosen = math_ops.to_int32(chosen)
    chosen = array_ops.reshape(chosen, [num_children])

    actions = actions.write(i, chosen)
    log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=chosen)
    return i + 1, next_c, next_h, actions, log_probs
示例7: _verifySolve
def _verifySolve(self,
                 x,
                 y,
                 dtype,
                 use_placeholder,
                 fast,
                 l2_regularizer,
                 batch_shape=()):
    """Check `linalg_ops.matrix_solve_ls` against a NumPy reference solution.

    Args:
      x: NumPy array holding the system matrix; cast to `dtype` before use.
      y: NumPy array holding the right-hand side; cast to `dtype` before use.
      dtype: NumPy dtype under test.  For complex dtypes the imaginary part
        of both operands is set equal to the real part.
      use_placeholder: if true, feed the operands through placeholders
        instead of passing the arrays directly.
      fast: forwarded as `fast=` to `matrix_solve_ls`; also selects
        `use_gpu` for the session.
      l2_regularizer: forwarded as `l2_regularizer=` to `matrix_solve_ls`.
        The residual check only runs when this is 0.
      batch_shape: tuple of leading batch dimensions.  When non-empty, the
        operands and the expected results are tiled to that batch shape.

    Returns:
      None.  Returns early without checking anything when `fast` is false
      and `l2_regularizer` is non-zero.
    """
    if not fast and l2_regularizer != 0:
        # The slow path does not support regularization.
        return
    maxdim = np.max(x.shape)
    if dtype == np.float32 or dtype == np.complex64:
        tol = maxdim * 5e-4
    else:
        tol = maxdim * 5e-7
    a = x.astype(dtype)
    b = y.astype(dtype)
    if dtype in [np.complex64, np.complex128]:
        a.imag = a.real
        b.imag = b.real
    # The NumPy reference solver does not batch, so we solve a single system
    # and replicate the solution and the residual norm.
    np_ans = _SolveWithNumpy(x, y, l2_regularizer=l2_regularizer)
    np_r = np.dot(np.conj(a.T), b - np.dot(a, np_ans))
    np_r_norm = np.sqrt(np.sum(np.conj(np_r) * np_r))
    # Compare by value: an identity test against the `()` literal depends on
    # interpreter tuple caching and triggers a SyntaxWarning on Python 3.8+.
    if batch_shape != ():
        a = np.tile(a, batch_shape + (1, 1))
        b = np.tile(b, batch_shape + (1, 1))
        np_ans = np.tile(np_ans, batch_shape + (1, 1))
        np_r_norm = np.tile(np_r_norm, batch_shape)
    with self.cached_session(use_gpu=fast) as sess:
        if use_placeholder:
            a_ph = array_ops.placeholder(dtypes.as_dtype(dtype))
            b_ph = array_ops.placeholder(dtypes.as_dtype(dtype))
            feed_dict = {a_ph: a, b_ph: b}
            tf_ans = linalg_ops.matrix_solve_ls(
                a_ph, b_ph, fast=fast, l2_regularizer=l2_regularizer)
        else:
            tf_ans = linalg_ops.matrix_solve_ls(
                a, b, fast=fast, l2_regularizer=l2_regularizer)
            feed_dict = {}
            # The static shape is only checked when the operands are passed
            # directly; the placeholders above are created without shapes.
            self.assertEqual(np_ans.shape, tf_ans.get_shape())
        if l2_regularizer == 0:
            # The least squares solution should satisfy A^H * (b - A*x) = 0.
            tf_r = b - math_ops.matmul(a, tf_ans)
            tf_r = math_ops.matmul(a, tf_r, adjoint_a=True)
            tf_r_norm = linalg_ops.norm(tf_r, ord="fro", axis=[-2, -1])
            tf_ans_val, tf_r_norm_val = sess.run(
                [tf_ans, tf_r_norm], feed_dict=feed_dict)
            self.assertAllClose(np_r_norm, tf_r_norm_val, atol=tol, rtol=tol)
        else:
            tf_ans_val = sess.run(tf_ans, feed_dict=feed_dict)
    self.assertEqual(np_ans.shape, tf_ans_val.shape)
    self.assertAllClose(np_ans, tf_ans_val, atol=2 * tol, rtol=2 * tol)
示例8: mean_only_frechet_classifier_distance_from_activations
def mean_only_frechet_classifier_distance_from_activations(
        real_activations, generated_activations):
    """Mean-only classifier distance between two sets of activations.

    Computes

        |m - m_w|^2

    where `m` and `m_w` are the per-dimension means of the real and the
    generated activations.  This is the mean term of the Frechet
    classifier distance; the covariance term is not computed.

    The computation is carried out in float64 and the result is cast back
    to the dtype of `real_activations` when that dtype is not float64.

    Args:
      real_activations: rank-2 tensor of activations of real images,
        documented by the caller-facing contract as [num_images, num_dims].
      generated_activations: rank-2 tensor of activations of generated
        images, same layout.

    Returns:
      The squared norm of the difference of the two mean vectors, with the
      same dtype as `real_activations`.
    """
    real_activations.shape.assert_has_rank(2)
    generated_activations.shape.assert_has_rank(2)

    original_dtype = real_activations.dtype
    needs_cast = original_dtype != dtypes.float64
    if needs_cast:
        real_activations = math_ops.to_double(real_activations)
        generated_activations = math_ops.to_double(generated_activations)

    # Means over axis 0 of each activation set.
    real_mean = math_ops.reduce_mean(real_activations, 0)
    generated_mean = math_ops.reduce_mean(generated_activations, 0)

    # `linalg_ops.norm` is called with its default order, then squared.
    mofid = math_ops.square(linalg_ops.norm(real_mean - generated_mean))
    if needs_cast:
        return math_ops.cast(mofid, original_dtype)
    return mofid
示例9: Test
def Test(self):
    """Compare theoretical and numerical gradients of the eigen ops.

    `shape_`, `dtype_` and `compute_v_` are not defined in this block; they
    are read from the enclosing scope.  Gradients of
    `linalg_ops.self_adjoint_eig` (when `compute_v_` is true) or
    `linalg_ops.self_adjoint_eigvals` are checked with
    `gradient_checker.compute_gradient` on a random self-adjoint input.
    """
    np.random.seed(1)
    n = shape_[-1]
    batch_shape = shape_[:-2]
    np_dtype = dtype_.as_numpy_dtype

    def random_self_adjoint_batch():
        # Random n x n matrix made self-adjoint by adding its conjugate
        # transpose, then tiled over the batch dimensions.
        mat = np.random.uniform(
            low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
        if dtype_.is_complex:
            mat += 1j * np.random.uniform(
                low=-1.0, high=1.0,
                size=n * n).reshape([n, n]).astype(np_dtype)
        mat += np.conj(mat.T)
        return np.tile(mat, batch_shape + (1, 1))

    a = random_self_adjoint_batch()
    # Optimal stepsize for central difference is O(epsilon^{1/3}).
    epsilon = np.finfo(np_dtype).eps
    delta = 0.1 * epsilon**(1.0 / 3.0)
    # Tolerance obtained by looking at actual differences using
    # np.linalg.norm(theoretical-numerical, np.inf) on -mavx build.
    single_precision = dtype_ in (dtypes_lib.float32, dtypes_lib.complex64)
    tol = 1e-2 if single_precision else 1e-7
    with self.test_session():
        tf_a = constant_op.constant(a)
        if compute_v_:
            tf_e, tf_v = linalg_ops.self_adjoint_eig(tf_a)
            # (Complex) eigenvectors are only unique up to an arbitrary
            # phase.  We normalize the vectors such that the first component
            # has phase 0.
            reference = tf_v / linalg_ops.norm(
                tf_v[..., 0:1, :], axis=-1, keep_dims=True)
            tf_v *= math_ops.conj(reference)
            outputs = [tf_e, tf_v]
        else:
            outputs = [linalg_ops.self_adjoint_eigvals(tf_a)]
        for output in outputs:
            # A fresh random starting point is drawn for every output.
            x_init = random_self_adjoint_batch()
            theoretical, numerical = gradient_checker.compute_gradient(
                tf_a,
                tf_a.get_shape().as_list(),
                output,
                output.get_shape().as_list(),
                x_init_value=x_init,
                delta=delta)
            self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
示例10: make_grouping_predictions
def make_grouping_predictions(self, input_layer, reuse=None):
    """Sample grouping actions and their log probabilities.

    Args:
      input_layer: group input layer; the in-code note gives its size as
        [1, num_ops, hidden_size].
      reuse: accepted for interface compatibility; the variable scope below
        is always opened with `reuse=True` and this argument is not read.

    Returns:
      grouping_actions: sampled actions reshaped to [batch_size, num_ops].
      grouping_log_probs: per-batch sum of the sparse softmax cross-entropy
        values of the sampled actions.
    """
    hparams = self.hparams
    with variable_scope.variable_scope(hparams.name, reuse=True):
        # input_layer: tensor of size [1, num_ops, hidden_size]
        w_grouping_ff = variable_scope.get_variable("w_grouping_ff")
        w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax")

    input_shape_0 = array_ops.shape(input_layer)[0]
    input_shape_2 = array_ops.shape(input_layer)[2]
    batch_size, embedding_dim = input_shape_0, input_shape_2

    # Flatten the batch and op dimensions before the two projections.
    flat_input = array_ops.reshape(
        input_layer, [batch_size * self.num_ops, embedding_dim])
    hidden = math_ops.matmul(flat_input, w_grouping_ff)
    logits = math_ops.matmul(hidden, w_grouping_softmax)

    if hparams.logits_std_noise > 0:
        # Noise scale is proportional to norm(logits) / sqrt(#logits).
        num_in_logits = math_ops.cast(
            array_ops.size(logits), dtype=dtypes.float32)
        avg_norm = math_ops.divide(
            linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
        logits_noise = random_ops.random_normal(
            array_ops.shape(logits),
            stddev=hparams.logits_std_noise * avg_norm)
        noiseless_logits = logits
        # Noise is added only while global_step <= stop_noise_step.
        logits = control_flow_ops.cond(
            self.global_step > hparams.stop_noise_step,
            lambda: noiseless_logits,
            lambda: noiseless_logits + logits_noise)

    logits = array_ops.reshape(
        logits, [batch_size * self.num_ops, self.num_groups])
    sampled = random_ops.multinomial(logits, 1, seed=hparams.seed)
    sampled = math_ops.to_int32(sampled)
    grouping_actions = array_ops.reshape(sampled, [batch_size, self.num_ops])

    flat_labels = array_ops.reshape(grouping_actions, [-1])
    per_op_log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=flat_labels)
    per_op_log_probs = array_ops.reshape(per_op_log_probs, [batch_size, -1])
    grouping_log_probs = math_ops.reduce_sum(per_op_log_probs, 1)
    return grouping_actions, grouping_log_probs
示例11: testTransform
def testTransform(self):
    """Check `_embedding_lookup_and_transform` with `max_norm` and a transform.

    Exercises every combination of ids of rank 0, 1, 2 and 3 with
    unsharded params and params split into 1, 2 and 3 shards.  `max_norm`
    is always applied.
    """
    np.random.seed(8)
    l2_norm = 2.

    def row_norm(x):
        return linalg_ops.norm(x, axis=1)

    with self.cached_session():
        # Param values are in [l2_norm, l2_norm+1) so clipping always kicks in.
        params = np.random.rand(6, 3) + l2_norm
        row_lengths = np.sqrt(np.sum(params * params, axis=1, keepdims=True))
        clipped = l2_norm * params / row_lengths
        # The transform reduces each embedding to its norm, so the expected
        # values are the norms of the clipped rows.
        expected_norms = np.linalg.norm(clipped, axis=1)
        num_rows = params.shape[0]

        for ids_shape in [(), (3,), (4, 3), (2, 3, 4)]:
            num_ids = np.prod(ids_shape, dtype=np.int64)
            ids = np.random.randint(num_rows, size=num_ids).reshape(ids_shape)

            # Unsharded lookup must agree with a plain gather.
            unsharded = embedding_ops._embedding_lookup_and_transform(
                params, ids, max_norm=l2_norm, transform_fn=row_norm).eval()
            self.assertAllClose(
                unsharded, array_ops.gather(expected_norms, ids).eval())

            # Sharded lookups must agree with the unsharded one.
            for num_shards in (1, 2, 3):
                stride = num_shards * math_ops.range(num_rows // num_shards)
                shards = [
                    array_ops.gather(params, stride + offset)
                    for offset in xrange(num_shards)
                ]
                sharded = embedding_ops._embedding_lookup_and_transform(
                    shards, ids, max_norm=l2_norm,
                    transform_fn=row_norm).eval()
                # assertAllClose is used because different sqrt
                # implementations may be involved in the two values (for
                # example Eigen's fast vectorized square root for doubles on
                # AVX512 builds), and these are not guaranteed to produce
                # exactly the same results.
                self.assertAllClose(unsharded, sharded)
示例12: testBadOrder
def testBadOrder(self):
    """Verify that `linalg_ops.norm` raises ValueError for unsupported `ord`.

    Vector-norm errors are expected without `axis` and with `axis=-1`;
    matrix-norm errors are expected with `axis=[-2, -1]`.
    """
    matrix = [[0., 1.], [2., 3.]]
    vector_message = "'ord' must be a supported vector norm"
    matrix_message = "'ord' must be a supported matrix norm"
    # (invalid ord values, expected message, extra keyword arguments)
    cases = (
        (("fro", -7, -1.1, 0), vector_message, {}),
        (("fro", -7, -1.1, 0), vector_message, {"axis": -1}),
        (("foo", -7, -1.1, 1.1), matrix_message, {"axis": [-2, -1]}),
    )
    for bad_orders, message, extra_kwargs in cases:
        for bad_order in bad_orders:
            with self.assertRaisesRegexp(ValueError, message):
                linalg_ops.norm(matrix, ord=bad_order, **extra_kwargs)
示例13: operator_and_matrix
def operator_and_matrix(
        self, build_info, dtype, use_placeholder,
        ensure_self_adjoint_and_pd=False):
    """Build a Householder operator and the dense matrix it should equal.

    Args:
      build_info: object whose `shape` gives the operator shape; all but
        the last dimension are used for the reflection axis.
      dtype: dtype of the random reflection axis.
      use_placeholder: if true, the operator is built from a
        `placeholder_with_default` with `shape=None` wrapping the axis.
      ensure_self_adjoint_and_pd: not read by this implementation.

    Returns:
      `(operator, matrix)`: the `LinearOperatorHouseholder` and the dense
      tensor `I - 2 * v * v^H` built from the same unit vector `v`.
    """
    axis_shape = list(build_info.shape)[:-1]
    raw_axis = linear_operator_test_util.random_sign_uniform(
        axis_shape, minval=1., maxval=2., dtype=dtype)
    # Scale to unit norm along the last dimension.
    unit_axis = raw_axis / linalg_ops.norm(raw_axis, axis=-1, keepdims=True)

    if use_placeholder:
        operator_input = array_ops.placeholder_with_default(
            unit_axis, shape=None)
    else:
        operator_input = unit_axis
    operator = householder.LinearOperatorHouseholder(operator_input)

    # -2 * v * v^H, then add 1 to the diagonal to obtain I - 2 * v * v^H.
    column = unit_axis[..., array_ops.newaxis]
    outer = -2 * math_ops.matmul(column, column, adjoint_b=True)
    matrix = array_ops.matrix_set_diag(
        outer, 1. + array_ops.matrix_diag_part(outer))
    return operator, matrix
示例14: frechet_classifier_distance_from_activations
def frechet_classifier_distance_from_activations(
        real_activations, generated_activations):
    """Classifier distance for evaluating a generative model.

    This is based on the Frechet Inception distance, but for an arbitrary
    classifier.

    This technique is described in detail in https://arxiv.org/abs/1706.08500.
    Given two Gaussian distributions with means m and m_v and covariance
    matrices C and C_v, this function calculates

        |m - m_v|^2 + Tr(C + C_v - 2(C * C_v)^(1/2))

    which captures how different the distributions of real images and
    generated images (or more accurately, their visual features) are. Note
    that unlike the Inception score, this is a true distance and utilizes
    information about real world images.

    Note that when computed using sample means and sample covariance
    matrices, Frechet distance is biased. It is more biased for small sample
    sizes (e.g. even if the two distributions are the same, for a small
    sample size, the expected Frechet distance is large). It is important to
    use the same sample size to compute Frechet classifier distance when
    comparing two generative models.

    The computation is carried out in float64 and the result is cast back
    to the dtype of `real_activations` when that dtype is not float64.

    Args:
      real_activations: rank-2 tensor of activations of real images; axis 0
        indexes examples.
      generated_activations: rank-2 tensor of activations of generated
        images; axis 0 indexes examples.

    Returns:
      The Frechet Inception distance. A floating-point scalar of the same
      type as the output of the activations.
    """
    real_activations.shape.assert_has_rank(2)
    generated_activations.shape.assert_has_rank(2)

    activations_dtype = real_activations.dtype
    if activations_dtype != dtypes.float64:
        real_activations = math_ops.to_double(real_activations)
        generated_activations = math_ops.to_double(generated_activations)

    # Compute mean and covariance matrices of activations.
    m = math_ops.reduce_mean(real_activations, 0)
    m_v = math_ops.reduce_mean(generated_activations, 0)

    # Each covariance is normalized by its own sample count, so the result
    # stays correct when the two inputs hold different numbers of examples.
    num_examples_real = math_ops.to_double(
        array_ops.shape(real_activations)[0])
    num_examples_generated = math_ops.to_double(
        array_ops.shape(generated_activations)[0])

    # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
    real_centered = real_activations - m
    sigma = math_ops.matmul(
        real_centered, real_centered,
        transpose_a=True) / (num_examples_real - 1)

    gen_centered = generated_activations - m_v
    sigma_v = math_ops.matmul(
        gen_centered, gen_centered,
        transpose_a=True) / (num_examples_generated - 1)

    # Find the Tr(sqrt(sigma sigma_v)) component of FID
    sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

    # Compute the two components of FID.

    # First the covariance component.
    # Here, note that trace(A + B) = trace(A) + trace(B)
    trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

    # Next the distance between means.
    mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
    fid = trace + mean
    if activations_dtype != dtypes.float64:
        fid = math_ops.cast(fid, activations_dtype)

    return fid
示例15: frechet_classifier_distance
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
    """Classifier distance for evaluating a conditional generative model.

    This is based on the Frechet Inception distance, but for an arbitrary
    classifier.

    This technique is described in detail in https://arxiv.org/abs/1706.08500.
    Given two Gaussian distributions with means m and m_v and covariance
    matrices C and C_v, this function calculates

        |m - m_v|^2 + Tr(C + C_v - 2(C * C_v)^(1/2))

    which captures how different the distributions of real images and
    generated images (or more accurately, their visual features) are. Note
    that unlike the Inception score, this is a true distance and utilizes
    information about real world images.

    Args:
      real_images: Real images to use to compute Frechet Inception distance.
      generated_images: Generated images to use to compute Frechet Inception
        distance.
      classifier_fn: A function that takes images and produces activations
        based on a classifier.
      num_batches: Number of batches to split images in to in order to
        efficiently run them through the classifier network.

    Returns:
      The Frechet Inception distance. A floating-point scalar.
    """
    real_images_list = array_ops.split(
        real_images, num_or_size_splits=num_batches)
    generated_images_list = array_ops.split(
        generated_images, num_or_size_splits=num_batches)

    # Real batches come first, generated batches second, so the activations
    # can be split back in the same order below.
    imgs = array_ops.stack(real_images_list + generated_images_list)

    # Compute the activations using the memory-efficient `map_fn`.
    activations = functional_ops.map_fn(
        fn=classifier_fn,
        elems=imgs,
        parallel_iterations=1,
        back_prop=False,
        swap_memory=True,
        name='RunClassifier')

    # Split the activations by the real and generated images.
    real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

    # Ensure the activations have the right shapes.
    real_a = array_ops.concat(array_ops.unstack(real_a), 0)
    gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
    real_a.shape.assert_has_rank(2)
    gen_a.shape.assert_has_rank(2)

    # Compute mean and covariance matrices of activations.
    m = math_ops.reduce_mean(real_a, 0)
    m_v = math_ops.reduce_mean(gen_a, 0)
    dim = math_ops.to_float(array_ops.shape(m)[0])
    # NOTE(review): with `transpose_b=True` these products are indexed by
    # axis 0 of the activations, while `dim` is the length of the mean over
    # axis 0 -- confirm this is the intended covariance estimate.
    sigma = math_ops.matmul(real_a - m, real_a - m, transpose_b=True) / dim
    # Generated activations are centered with their own mean `m_v`, not with
    # the mean of the real activations.
    sigma_v = math_ops.matmul(gen_a - m_v, gen_a - m_v, transpose_b=True) / dim

    # Take matrix square root of the product of covariance matrices.
    sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))

    # Compute the two components of FID.
    trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc)
    mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
    fid = trace + mean

    return fid