This article collects typical usage examples of the Python attribute utils.INF. If you are wondering what utils.INF does and how to use it, the curated examples below may help. You can also explore other usage of the utils module in which this attribute is defined.
The following shows 4 code examples of the utils.INF attribute, sorted by popularity by default.
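For context, in projects like these utils.INF is typically just positive floating-point infinity. A minimal sketch of what such a utils module might define (an assumption for illustration; the real modules may carry more helpers):

# utils.py -- hypothetical minimal version
INF = float('inf')  # compares greater than any finite float; -INF is its negative counterpart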
Example 1: wastar_search
# Required module: import utils [as alias]
# Or: from utils import INF [as alias]
import time
from heapq import heappop, heappush

# Node, retrace_path, elapsed_time, unit_cost_fn, and zero_heuristic_fn
# are defined elsewhere in the same module.
def wastar_search(start_v, end_v, neighbors_fn, cost_fn=unit_cost_fn,
                  heuristic_fn=zero_heuristic_fn, w=1, max_cost=INF, max_time=INF):
    # TODO: lazy wastar to get different paths
    #heuristic_fn = lambda v: cost_fn(v, end_v)
    priority_fn = lambda g, h: g + w*h  # weighted A*: f = g + w*h
    goal_test = lambda v: v == end_v
    start_time = time.time()
    start_g, start_h = 0, heuristic_fn(start_v)
    visited = {start_v: Node(start_g, None)}
    queue = [(priority_fn(start_g, start_h), start_g, start_v)]
    while queue and (elapsed_time(start_time) < max_time):
        _, current_g, current_v = heappop(queue)
        if visited[current_v].g < current_g:
            continue  # stale queue entry; a cheaper path to current_v was already found
        if goal_test(current_v):
            return retrace_path(visited, current_v)
        for next_v in neighbors_fn(current_v):
            next_g = current_g + cost_fn(current_v, next_v)
            if (next_v not in visited) or (next_g < visited[next_v].g):
                visited[next_v] = Node(next_g, current_v)
                next_h = heuristic_fn(next_v)
                if priority_fn(next_g, next_h) < max_cost:
                    heappush(queue, (priority_fn(next_g, next_h), next_g, next_v))
    return None
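For illustration, here is how wastar_search might be invoked on a toy adjacency-list graph. Node, retrace_path, and elapsed_time belong to the surrounding module; the stand-ins below are hypothetical reconstructions consistent with how the function uses them, not the project's actual code:

import time
from collections import namedtuple

INF = float('inf')
Node = namedtuple('Node', ['g', 'parent'])  # assumed layout: cost-to-come and parent vertex

def elapsed_time(start_time):
    return time.time() - start_time

def retrace_path(visited, v):
    # follow parent pointers from v back to the start, then reverse
    path = []
    while v is not None:
        path.append(v)
        v = visited[v].parent
    return path[::-1]

edges = {'A': ['B', 'C'], 'B': ['D'], 'C': ['D'], 'D': []}  # toy graph with unit-cost edges
path = wastar_search('A', 'D', neighbors_fn=lambda v: edges[v],
                     cost_fn=lambda u, v: 1,
                     heuristic_fn=lambda v: 0,  # zero heuristic reduces weighted A* to Dijkstra
                     w=2)
print(path)  # ['A', 'B', 'D']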
Example 2: __init__
# Required module: import utils [as alias]
# Or: from utils import INF [as alias]
def __init__(self):
    # start from an inverted (empty) bbox so the first child added sets the real extent
    LTContainer.__init__(self, (+INF, +INF, -INF, -INF))
    return
Example 3: __init__
# Required module: import utils [as alias]
# Or: from utils import INF [as alias]
def __init__(self):
    LTContainer.__init__(self, (+INF, +INF, -INF, -INF))
    return
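In Examples 2 and 3 the container starts from the inverted box (+INF, +INF, -INF, -INF): it acts as the identity element for bounding-box union, since min against +INF and max against -INF always yield the other operand. A short sketch of the idiom (bbox_union is a hypothetical helper name, not necessarily what LTContainer calls it):

INF = float('inf')

def bbox_union(a, b):
    # each box is (x0, y0, x1, y1)
    return (min(a[0], b[0]), min(a[1], b[1]),
            max(a[2], b[2]), max(a[3], b[3]))

empty = (+INF, +INF, -INF, -INF)
assert bbox_union(empty, (10, 20, 30, 40)) == (10, 20, 30, 40)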
Example 4: forward
# Required module: import utils [as alias]
# Or: from utils import INF [as alias]
def forward(self, seq, mask):
    in_c = seq.size()[1]
    seq = torch.transpose(seq, 1, 2)  # (N, L, C)
    queries = seq
    keys = seq
    num_heads = self.num_heads
    # T_q = T_k = L
    Q = F.relu(self.linear_q(seq))  # (N, T_q, C)
    K = F.relu(self.linear_k(seq))  # (N, T_k, C)
    V = F.relu(self.linear_v(seq))  # (N, T_k, C)
    # Split into heads and concatenate along the batch dimension
    Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
    K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
    V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
    # Multiplication
    outputs = torch.matmul(Q_, K_.transpose(1, 2))  # (h*N, T_q, T_k)
    # Scale
    outputs = outputs / (K_.size()[-1] ** 0.5)
    # Key masking: set padded keys to -inf so they get zero weight after softmax
    key_masks = mask.repeat(num_heads, 1)  # (h*N, T_k)
    key_masks = torch.unsqueeze(key_masks, 1)  # (h*N, 1, T_k)
    key_masks = key_masks.repeat(1, queries.size()[1], 1)  # (h*N, T_q, T_k)
    paddings = torch.ones_like(outputs) * (-INF)
    outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)
    # Query masking: zero out attention rows that belong to padded queries
    query_masks = mask.repeat(num_heads, 1)  # (h*N, T_q)
    query_masks = torch.unsqueeze(query_masks, -1)  # (h*N, T_q, 1)
    query_masks = query_masks.repeat(1, 1, keys.size()[1]).float()  # (h*N, T_q, T_k)
    att_scores = F.softmax(outputs, dim=-1) * query_masks  # (h*N, T_q, T_k)
    att_scores = self.dropout(att_scores)
    # Weighted sum
    x_outputs = torch.matmul(att_scores, V_)  # (h*N, T_q, C/h)
    # Restore shape
    x_outputs = torch.cat(
        torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
        dim=2)  # (N, T_q, C)
    x = torch.transpose(x_outputs, 1, 2)  # (N, C, L)
    x = self.bn(x, mask)
    return x
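To exercise the forward pass above, the owning module needs linear_q/linear_k/linear_v, dropout, bn, and num_heads attributes. A hedged usage sketch, assuming the class is named SelfAttention and importable from model (both names hypothetical):

import torch

from model import SelfAttention  # hypothetical module path and class name

N, C, L = 2, 16, 10
seq = torch.randn(N, C, L)   # channels-first input batch
mask = torch.ones(N, L)      # 1 = real position, 0 = padding
mask[:, 8:] = 0              # pretend the last two positions are padding

attn = SelfAttention(in_channels=C, num_heads=4)  # hypothetical constructor signature
out = attn(seq, mask)
print(out.shape)             # expected: torch.Size([2, 16, 10])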