Merge pull request #8 from infrub/master
Python 3 support, restructuring, and speedup
smorita authored Jul 1, 2019
2 parents 126c7de + f72ed97 commit 707f67c
Showing 2 changed files with 122 additions and 120 deletions.
237 changes: 118 additions & 119 deletions netcon.py
@@ -11,138 +11,137 @@
import logging
import time
import config
import itertools

BOND_DIMS = []

class Tensor:
class TensorFrame:
"""Tensor class for netcon.
Attributes:
rpn: contraction sequence with reverse polish notation.
bit: bit representation of contracted tensors.
bonds: list of bonds connecting with the outside.
bits: bits representation of contracted tensors.
bonds: list of uncontracted bonds.
is_new: a flag.
"""

def __init__(self,rpn=[],bit=0,bonds=[],cost=0.0,is_new=True):
def __init__(self,rpn=[],bits=0,bonds=[],cost=0.0,is_new=True):
self.rpn = rpn[:]
self.bit = bit
self.bits = bits
self.bonds = bonds
self.cost = cost
self.is_new = is_new

def __str__(self):
return "{0} : bond={1} cost={2:.6e} bit={3} new={4}".format(
self.rpn, list(self.bonds), self.cost, self.bit, self.is_new)


def netcon(tn, bond_dims):
"""Find optimal contraction sequence.
def __repr__(self):
return "TensorFrame({0}, bonds={1}, cost={2:.6e}, bits={3}, is_new={4})".format(
self.rpn, self.bonds, self.cost, self.bits, self.is_new)

Args:
tn: TensorNetwork in tdt.py
bond_dims: List of bond dimensions.
Return:
rpn: Optimal contraction sequence with reverse polish notation.
cost: Total contraction cost.
"""
BOND_DIMS[0:0] = bond_dims
tensor_set = _init(tn)

n = len(tensor_set[0])
xi_min = float(min(BOND_DIMS))
mu_cap = 1.0
mu_old = 0.0

while len(tensor_set[-1])<1:
logging.info("netcon: searching with mu_cap={0:.6e}".format(mu_cap))
mu_next = sys.float_info.max
for c in range(1,n):
for d1 in range((c+1)/2):
d2 = c-d1-1
n1 = len(tensor_set[d1])
n2 = len(tensor_set[d2])
for i1 in range(n1):
i2_start = i1+1 if d1==d2 else 0
for i2 in range(i2_start, n2):
t1 = tensor_set[d1][i1]
t2 = tensor_set[d2][i2]

if _is_disjoint(t1,t2): continue
if _is_overlap(t1,t2): continue

mu = _get_cost(t1,t2)
mu_0 = 0.0 if (t1.is_new or t2.is_new) else mu_old

if (mu > mu_cap) and (mu < mu_next): mu_next = mu
if (mu > mu_0) and (mu <= mu_cap):
t_new = _contract(t1,t2)
is_find = False
for i,t_old in enumerate(tensor_set[c]):
if t_new.bit == t_old.bit:
if t_new.cost < t_old.cost:
tensor_set[c][i] = t_new
is_find = True
break
if not is_find: tensor_set[c].append(t_new)
mu_old = mu_cap
mu_cap = max(mu_next, mu_cap*xi_min)
for s in tensor_set:
for t in s: t.is_new = False

logging.debug("netcon: tensor_num=" + str([ len(s) for s in tensor_set]))

t_final = tensor_set[-1][0]
return t_final.rpn, t_final.cost


def _init(tn):
"""Initialize a set of tensors from tdt tensor-network."""
tensor_set = [[] for t in tn.tensors]
for t in tn.tensors:
rpn = t.name
bit = 0
for i in rpn:
if i>=0: bit += (1<<i)
bonds = frozenset(t.bonds)
cost = 0.0
tensor_set[0].append(Tensor(rpn,bit,bonds,cost))
return tensor_set


def _get_cost(t1,t2):
"""Get the cost of contraction of two tensors."""
cost = 1.0
for b in (t1.bonds | t2.bonds):
cost *= BOND_DIMS[b]
cost += t1.cost + t2.cost
return cost


def _contract(t1,t2):
"""Return a contracted tensor"""
assert (not _is_disjoint(t1,t2))
rpn = t1.rpn + t2.rpn + [-1]
bit = t1.bit ^ t2.bit # XOR
bonds = frozenset(t1.bonds ^ t2.bonds)
cost = _get_cost(t1,t2)
return Tensor(rpn,bit,bonds,cost)


def _is_disjoint(t1,t2):
"""Check if two tensors are disjoint."""
return (t1.bonds).isdisjoint(t2.bonds)


def _is_overlap(t1,t2):
"""Check if two tensors have the same basic tensor."""
return (t1.bit & t2.bit)>0


def _print_tset(tensor_set):
"""Print tensor_set. (for debug)"""
for level in range(len(tensor_set)):
for i,t in enumerate(tensor_set[level]):
print level,i,t
def __str__(self):
return "{0} : bonds={1} cost={2:.6e} bits={3} new={4}".format(
self.rpn, self.bonds, self.cost, self.bits, self.is_new)


class NetconOptimizer:
def __init__(self, prime_tensors, bond_dims):
self.prime_tensors = prime_tensors
self.BOND_DIMS = bond_dims[:]

def optimize(self):
"""Find optimal contraction sequence.
Args:
tn: TensorNetwork in tdt.py
bond_dims: List of bond dimensions.
Return:
rpn: Optimal contraction sequence with reverse polish notation.
cost: Total contraction cost.
"""
tensordict_of_size = self.init_tensordict_of_size()

n = len(self.prime_tensors)
xi_min = float(min(self.BOND_DIMS))
mu_cap = 1.0
prev_mu_cap = 0.0 #>=0

while len(tensordict_of_size[-1])<1:
logging.info("netcon: searching with mu_cap={0:.6e}".format(mu_cap))
next_mu_cap = sys.float_info.max
for c in range(2,n+1):
for d1 in range(1,c//2+1):
d2 = c-d1
t1_t2_iterator = itertools.combinations(tensordict_of_size[d1].values(), 2) if d1==d2 else itertools.product(tensordict_of_size[d1].values(), tensordict_of_size[d2].values())
for t1, t2 in t1_t2_iterator:
if self.are_overlap(t1,t2): continue
if self.are_direct_product(t1,t2): continue

cost = self.get_contracting_cost(t1,t2)
bits = t1.bits ^ t2.bits

if next_mu_cap <= cost:
pass
elif mu_cap < cost:
next_mu_cap = cost
elif t1.is_new or t2.is_new or prev_mu_cap < cost:
t_old = tensordict_of_size[c].get(bits)
if t_old is None or cost < t_old.cost:
tensordict_of_size[c][bits] = self.contract(t1,t2)
prev_mu_cap = mu_cap
mu_cap = max(next_mu_cap, mu_cap*xi_min)
for s in tensordict_of_size:
for t in s.values(): t.is_new = False

logging.debug("netcon: tensor_num=" + str([ len(s) for s in tensordict_of_size]))

t_final = tensordict_of_size[-1][(1<<n)-1]
return t_final.rpn, t_final.cost


def init_tensordict_of_size(self):
"""tensordict_of_size[k][bits] == calculated lowest-cost tensor which is contraction of k+1 prime tensors and whose bits == bits"""
tensordict_of_size = [{} for size in range(len(self.prime_tensors)+1)]
for t in self.prime_tensors:
rpn = t.name
bits = 0
for i in rpn:
if i>=0: bits += (1<<i)
bonds = frozenset(t.bonds)
cost = 0.0
tensordict_of_size[1].update({bits:TensorFrame(rpn,bits,bonds,cost)})
return tensordict_of_size


def get_contracting_cost(self,t1,t2):
"""Get the cost of contraction of two tensors."""
cost = 1.0
for b in (t1.bonds | t2.bonds):
cost *= self.BOND_DIMS[b]
cost += t1.cost + t2.cost
return cost


def contract(self,t1,t2):
"""Return a contracted tensor"""
assert (not self.are_direct_product(t1,t2))
rpn = t1.rpn + t2.rpn + [-1]
bits = t1.bits ^ t2.bits # XOR
bonds = frozenset(t1.bonds ^ t2.bonds)
cost = self.get_contracting_cost(t1,t2)
return TensorFrame(rpn,bits,bonds,cost)


def are_direct_product(self,t1,t2):
"""Check if two tensors are disjoint."""
return (t1.bonds).isdisjoint(t2.bonds)


def are_overlap(self,t1,t2):
"""Check if two tensors have the same basic tensor."""
return (t1.bits & t2.bits)>0


def print_tset(self,tensors_of_size):
"""Print tensors_of_size. (for debug)"""
for level in range(len(tensors_of_size)):
for i,t in enumerate(tensors_of_size[level]):
print(level,i,t)

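As an aside, not part of the commit: the bookkeeping behind TensorFrame is that every prime tensor owns one bit, the bits field of a composite tensor is the union of its constituents' bits, bonds holds only the still-open bond indices, and get_contracting_cost multiplies the dimensions of every bond touching either operand and adds the costs already paid for its operands. Below is a minimal, self-contained sketch of that bookkeeping with plain dicts standing in for TensorFrame; the three-tensor network, bond indices, and dimensions are invented for illustration.

# Sketch only; mirrors get_contracting_cost / contract above, not taken from the commit.
# Hypothetical network: bond 0 joins A-B, bond 1 joins B-C, bonds 2 and 3 are open legs.
BOND_DIMS = [10, 20, 5, 5]
A = {"bits": 1 << 0, "bonds": frozenset({0, 2}), "cost": 0.0, "rpn": [0]}
B = {"bits": 1 << 1, "bonds": frozenset({0, 1}), "cost": 0.0, "rpn": [1]}
C = {"bits": 1 << 2, "bonds": frozenset({1, 3}), "cost": 0.0, "rpn": [2]}

def contracting_cost(t1, t2):
    # Product of the dimensions of every bond touching either tensor,
    # plus the costs already paid for building t1 and t2.
    cost = 1.0
    for b in t1["bonds"] | t2["bonds"]:
        cost *= BOND_DIMS[b]
    return cost + t1["cost"] + t2["cost"]

def contract(t1, t2):
    assert t1["bits"] & t2["bits"] == 0             # no shared prime tensors (are_overlap)
    assert not t1["bonds"].isdisjoint(t2["bonds"])  # skip direct products (are_direct_product)
    return {
        "bits": t1["bits"] ^ t2["bits"],            # disjoint union of prime tensors
        "bonds": t1["bonds"] ^ t2["bonds"],         # shared bonds are summed over
        "cost": contracting_cost(t1, t2),
        "rpn": t1["rpn"] + t2["rpn"] + [-1],        # -1 marks one pairwise contraction
    }

AB = contract(A, B)
print(AB["bits"], sorted(AB["bonds"]), AB["cost"])  # 3 [1, 2] 1000.0
ABC = contract(AB, C)
print(ABC["rpn"], ABC["cost"])                      # [0, 1, -1, 2, -1] 1500.0
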
5 changes: 4 additions & 1 deletion tdt.py
@@ -24,6 +24,9 @@ def __init__(self,name=None,bonds=[]):
self.name = [name]
self.bonds = bonds[:]

def __repr__(self):
return "Tensor(" + str(self.name) + ", " + str(self.bonds) +")"

def __str__(self):
return str(self.name) + ", " + str(self.bonds)

@@ -449,7 +452,7 @@ def parse_args():
logging.basicConfig(format="%(levelname)s:%(message)s", level=config.LOGGING_LEVEL)

tn.output_log("input")
rpn, cpu = netcon.netcon(tn, BOND_DIMS)
rpn, cpu = netcon.NetconOptimizer(tn.tensors, BOND_DIMS).optimize()
mem = get_memory(tn, rpn)

TENSOR_MATH_NAMES = TENSOR_NAMES[:]
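A hedged usage sketch of the new entry point shown above: NetconOptimizer takes the list of prime tensors and the bond dimensions, and optimize() returns the contraction sequence in reverse polish notation together with its total cost. The PrimeTensor stand-in, the two-tensor network, and the bond dimensions below are invented; only NetconOptimizer and optimize come from the diff, and netcon.py is assumed to be importable on its own (it imports the repository's config module).

import netcon

class PrimeTensor:
    """Minimal stand-in for tdt.Tensor; only the fields netcon reads."""
    def __init__(self, idx, bonds):
        self.name = [idx]    # rpn leaf: the tensor's own index
        self.bonds = bonds   # bond indices attached to this tensor

# Invented example: tensors 0 and 1 share bond 0; bonds 1 and 2 stay open.
tensors = [PrimeTensor(0, [0, 1]), PrimeTensor(1, [0, 2])]
bond_dims = [10, 5, 5]

rpn, cost = netcon.NetconOptimizer(tensors, bond_dims).optimize()
print(rpn, cost)  # expected: [0, 1, -1] 250.0
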
