lib/transform: fix for the strict Keras 3 (TF 2.16)
jxy committed Mar 15, 2024
1 parent ba4ba46 commit 5790e1f
Showing 1 changed file with 64 additions and 8 deletions.
72 changes: 64 additions & 8 deletions lib/transform.py
@@ -77,25 +77,25 @@ def build(self, shape):
         # 1 optional leading batch dimension
         assert(isinstance(shape, list) and len(shape)==4)
         sh = shape[0]
-        if isinstance(sh, tf.TensorShape):
+        if isinstance(sh, tf.TensorShape) or isinstance(sh, tuple):  # TensorShape for TF 2.9.1; tuple for TF 2.16.1
             if len(sh)==6:
                 dim = sh[:4]
                 maskshape = dim+(1,1)
             elif len(sh)==7:
                 dim = sh[1:5]
                 maskshape = (1,)+dim+(1,1)
             else:
-                raise ValueError(f'unsupported lattice shape {shape}')
+                raise ValueError(f'unsupported lattice shape {shape}:{type(shape)} with shape[0] {sh}:{type(sh)}')
             mask = tf.reshape(evenodd_mask(dim)%2, maskshape)
             if not self.is_odd:
                 mask = 1-mask
             self.filter = tf.cast(mask, tf.complex128)  # 1 for update
             self.parted = False
-        elif isinstance(sh, list) and len(sh)==16 and isinstance(sh[0], tf.TensorShape):
+        elif isinstance(sh, list) and len(sh)==16 and (isinstance(sh[0], tf.TensorShape) or isinstance(sh[0], tuple)):
             # hypercube partitioned
             self.parted = True
         else:
-            raise ValueError(f'unsupported lattice shape {shape}')
+            raise ValueError(f'unsupported lattice shape {shape}:{type(shape)} with shape[0] {sh}:{type(sh)}')
     def transform(self, x):
         f,l,b = self.compute_change(x)
         return self.slice_apply(f,x),l,b
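
The hunk above adapts to a behavior change in what build() receives: under TF 2.9.1's tf.keras the nested shapes arrive as tf.TensorShape objects, while Keras 3 (the default in TF 2.16) passes plain tuples. A minimal probe, assuming only stock Layer semantics (the class name is made up), shows the difference:

    import tensorflow as tf

    class ShapeProbe(tf.keras.layers.Layer):
        def build(self, shape):
            # tf.keras (TF 2.9): tf.TensorShape
            # Keras 3 (TF 2.16): plain tuple
            print(type(shape), shape)

    ShapeProbe()(tf.zeros((2, 4, 4, 4, 4, 3)))

Accepting both types, as the new isinstance checks do, keeps the layer working under either runtime.
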
@@ -196,6 +196,7 @@ def compute_change(self, xin, change_only=False):
                 if nu!=mu:
                     loop_para += power3_trace(stf[nu](stu[nu].adjoint()))  # already symmetrical about the link
             sym_field = self.stack_tensors(loop_para, loop_perp)  # 2*(144+9)=306 real numbers per site
+            # tf.print(f'self.coeff call shape {sym_field.shape}')
             coeff = self.coeff(sym_field)
         else:
             coeff = self.coeff
@@ -350,6 +351,10 @@ def __init__(self, transforms, name='TransformChain', **kwargs):
             if hasattr(t,'invMaxIter'):
                 if self.invMaxIter < t.invMaxIter:
                     self.invMaxIter = t.invMaxIter
+    def build(self, shape):
+        # Keras 3 requires explicitly building the layers
+        for f in self.chain:
+            f.build(shape)
     def transform(self, x):
         y = x
         l = 0.0
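
TransformChain drives its children through f.transform(x) rather than f(x), so the children's __call__ never runs and Keras never builds them implicitly; Keras 3 is additionally strict about weights being created outside build(). The new build() therefore builds every child up front. A self-contained sketch of the pattern, with hypothetical Scale and Chain layers:

    import tensorflow as tf

    class Scale(tf.keras.layers.Layer):
        def build(self, shape):
            self.w = self.add_weight(shape=(), initializer='ones')
        def transform(self, x):    # bypasses __call__, so no auto-build
            return self.w * x

    class Chain(tf.keras.layers.Layer):
        def __init__(self, layers, **kwargs):
            super().__init__(**kwargs)
            self.chain = layers
        def build(self, shape):
            # children are never invoked via __call__, so build them here
            for f in self.chain:
                f.build(shape)
        def call(self, x):
            for f in self.chain:
                x = f.transform(x)
            return x

    y = Chain([Scale(), Scale()])(tf.ones((2, 3)))
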
@@ -393,6 +398,16 @@ class CoefficientNets(tl.Layer):
     def __init__(self, sequence, name='CoefficientNets', **kwargs):
         super().__init__(autocast=False, name=name, **kwargs)
         self.chain = sequence
+    def build(self, shape):
+        # Keras 3 requires explicitly building the layers
+        for nn in self.chain:
+            # tf.print(f'CoefficientNets {type(nn)} shape {shape}')
+            nn.build(shape)
+            # if hasattr(nn, 'output_shape'):
+            #     tf.print(f'CoefficientNets {type(nn)} output_shape {nn.output_shape}')
+            # if hasattr(nn, 'compute_output_shape'):
+            #     tf.print(f'CoefficientNets {type(nn)} compute_output_shape {nn.compute_output_shape(shape)}')
+            shape = nn.compute_output_shape(shape)
     def call(self, x):
         y = x
         for nn in self.chain:
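
Unlike TransformChain, the layers chained here change the per-site shape (shifting stacks a new axis, flattening merges axes), so build() threads the shape through the sequence via compute_output_shape before building each successor. The same idea in isolation, using stock Keras layers and a made-up feature width:

    import tensorflow as tf

    layers = [tf.keras.layers.Dense(8), tf.keras.layers.Dense(4)]
    shape = (None, 306)    # e.g. the 306 real numbers per site above
    for nn in layers:
        nn.build(shape)                           # build against the incoming shape
        shape = nn.compute_output_shape(shape)    # becomes the next layer's input shape
    print(shape)                                  # (None, 4)
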
@@ -418,7 +433,13 @@ def __call__(self, x):
                   for s in range(-self.symmetric_shifts,1+self.symmetric_shifts)
                   for lat in xs]
         y = tf.stack(xs,axis=-2)
+        # tf.print(f'SymmetricShifts.__call__ return shape {y.shape}')
         return y
+    def build(self, shape):
+        # tf.print(f'SymmetricShifts.build shape {shape}')
+        pass
+    def compute_output_shape(self, shape):
+        return shape[:-1]+((1+2*self.symmetric_shifts)**4,shape[-1])
 
 class Normalization(tl.Layer):
     """
@@ -429,9 +450,12 @@ class Normalization(tl.Layer):
     def __init__(self, epsilon=1e-12, name='Normalization', **kwargs):
         super().__init__(autocast=False, name=name, **kwargs)
         self.epsilon = epsilon
-    def build(self, input_shape):
-        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:], initializer='ones', trainable=True)
-        self.beta = self.add_weight(name='beta', shape=input_shape[-1:], initializer='zeros', trainable=True)
+    def build(self, shape):
+        # tf.print(f'Normalization.build shape {shape}')
+        self.gamma = self.add_weight(name='gamma', shape=shape[-1:], initializer='ones', trainable=True)
+        self.beta = self.add_weight(name='beta', shape=shape[-1:], initializer='zeros', trainable=True)
+    def compute_output_shape(self, shape):
+        return shape
     def call(self, x):
         # assuming the batch_dim always comes before TZYXC.
         lat_rank = 5
@@ -449,7 +473,17 @@ class LocalSelfAttention(tl.MultiHeadAttention):
"""
def __init__(self, name='LocalSelfAttention', **kwargs):
super().__init__(autocast=False, name=name, **kwargs)
def build(self, shape):
# tf.print(f'LocalSelfAttention.build shape {shape}')
n = 1
for i in shape[:-2]:
n *= i
flat_shape = (n,)+tuple(shape[-2:])
super().build(flat_shape, flat_shape, flat_shape)
def compute_output_shape(self, shape):
return shape
def call(self, x):
# tf.print(f'LocalSelfAttention.call shape {x.shape}')
xflat = tf.reshape(x, shape=(-1,)+tuple(x.shape[-2:]))
yflat = super().call(query=xflat, value=xflat, key=xflat)
return tf.reshape(yflat, x.shape)
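
LocalSelfAttention applies attention within each site by collapsing every leading lattice dimension into a single batch axis. In Keras 3, MultiHeadAttention.build takes separate query, value, and key shapes, which is why the override computes one flattened shape and passes it three times for self-attention. The flatten-attend-restore pattern on its own, with assumed sizes:

    import tensorflow as tf

    mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
    x = tf.random.normal((2, 4, 4, 4, 4, 81, 8))         # ..., sequence, features
    xflat = tf.reshape(x, (-1,) + tuple(x.shape[-2:]))   # (batch*sites, 81, 8)
    yflat = mha(query=xflat, value=xflat, key=xflat)     # attention per site
    y = tf.reshape(yflat, x.shape)                       # restore lattice layout
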
@@ -459,6 +493,15 @@ class Residue(tl.Layer):
     def __init__(self, procedure, name='Residue', **kwargs):
         super().__init__(autocast=False, name=name, **kwargs)
         self.procedure = procedure
+    def build(self, shape):
+        # Keras 3 requires explicitly building the layers
+        if isinstance(self.procedure, (list,tuple)):
+            for p in self.procedure:
+                p.build(shape)
+        else:
+            self.procedure.build(shape)
+    def compute_output_shape(self, shape):
+        return shape
     def call(self, x):
         if isinstance(self.procedure, (list,tuple)):
             y = x
@@ -475,8 +518,11 @@ def __init__(self, inner_size, inner_activation='swish', name='LocalFeedForward'
         self.inner_size = inner_size
         self.inner_activation = inner_activation
     def build(self, shape):
+        # tf.print(f'LocalFeedForward.build shape {shape}')
         self.dense_in = tl.Dense(units=self.inner_size, activation=self.inner_activation)
         self.dense_out = tl.Dense(units=shape[-1], activation=None)
+    def compute_output_shape(self, shape):
+        return shape
     def call(self, x):
         y = self.dense_in(x)
         z = self.dense_out(y)
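
LocalFeedForward can only size its second Dense layer once the input arrives, since the output width must match shape[-1]; creating both sublayers inside build() defers that decision, and the Dense layers then build themselves on first use in call(). A minimal sketch of the same pattern, with a hypothetical class name:

    import tensorflow as tf

    class FeedForward(tf.keras.layers.Layer):
        def __init__(self, inner_size, **kwargs):
            super().__init__(**kwargs)
            self.inner_size = inner_size
        def build(self, shape):
            self.dense_in = tf.keras.layers.Dense(self.inner_size, activation='swish')
            self.dense_out = tf.keras.layers.Dense(shape[-1])   # match input width
        def call(self, x):
            return self.dense_out(self.dense_in(x))

    print(FeedForward(16)(tf.ones((2, 8))).shape)   # (2, 8)
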
@@ -486,7 +532,17 @@ class FlattenSiteLocal:
     def __init__(self, input_local_rank):
         self.input_local_rank = input_local_rank
     def __call__(self, x):
-        return tf.reshape(x, tuple(x.shape[:-self.input_local_rank])+(-1,))
+        y = tf.reshape(x, tuple(x.shape[:-self.input_local_rank])+(-1,))
+        # tf.print(f'FlattenSiteLocal.__call__ return shape {y.shape}')
+        return y
+    def build(self, shape):
+        # tf.print(f'FlattenSiteLocal.build shape {shape}')
+        pass
+    def compute_output_shape(self, shape):
+        n = 1
+        for i in shape[-self.input_local_rank:]:
+            n *= i
+        return shape[:-self.input_local_rank]+(n,)
 
 if __name__ == '__main__':
     import sys, os
