From 5790e1f24f64dd641f375797013ed0b549b015f0 Mon Sep 17 00:00:00 2001
From: Xiao-Yong Jin
Date: Fri, 15 Mar 2024 05:01:57 +0000
Subject: [PATCH] lib/transform: fix for the strict Keras 3 (TF 2.16)

---
 lib/transform.py | 72 ++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 64 insertions(+), 8 deletions(-)

diff --git a/lib/transform.py b/lib/transform.py
index 98bde9f..de057dc 100644
--- a/lib/transform.py
+++ b/lib/transform.py
@@ -77,7 +77,7 @@ def build(self, shape):
         # 1 optional leading batch dimension
         assert(isinstance(shape, list) and len(shape)==4)
         sh = shape[0]
-        if isinstance(sh, tf.TensorShape):
+        if isinstance(sh, tf.TensorShape) or isinstance(sh, tuple):    # TensorShape for TF 2.9.1; tuple for TF 2.16.1
             if len(sh)==6:
                 dim = sh[:4]
                 maskshape = dim+(1,1)
@@ -85,17 +85,17 @@ def build(self, shape):
                 dim = sh[1:5]
                 maskshape = (1,)+dim+(1,1)
             else:
-                raise ValueError(f'unsupported lattice shape {shape}')
+                raise ValueError(f'unsupported lattice shape {shape}:{type(shape)} with shape[0] {sh}:{type(sh)}')
             mask = tf.reshape(evenodd_mask(dim)%2, maskshape)
             if not self.is_odd:
                 mask = 1-mask
             self.filter = tf.cast(mask, tf.complex128)    # 1 for update
             self.parted = False
-        elif isinstance(sh, list) and len(sh)==16 and isinstance(sh[0], tf.TensorShape):
+        elif isinstance(sh, list) and len(sh)==16 and (isinstance(sh[0], tf.TensorShape) or isinstance(sh[0], tuple)):
             # hypercube partitioned
             self.parted = True
         else:
-            raise ValueError(f'unsupported lattice shape {shape}')
+            raise ValueError(f'unsupported lattice shape {shape}:{type(shape)} with shape[0] {sh}:{type(sh)}')
     def transform(self, x):
         f,l,b = self.compute_change(x)
         return self.slice_apply(f,x),l,b
@@ -196,6 +196,7 @@ def compute_change(self, xin, change_only=False):
                 if nu!=mu:
                     loop_para += power3_trace(stf[nu](stu[nu].adjoint()))    # already symmetrical about the link
             sym_field = self.stack_tensors(loop_para, loop_perp)    # 2*(144+9)=306 real numbers per site
+            # tf.print(f'self.coeff call shape {sym_field.shape}')
             coeff = self.coeff(sym_field)
         else:
             coeff = self.coeff
@@ -350,6 +351,10 @@ def __init__(self, transforms, name='TransformChain', **kwargs):
             if hasattr(t,'invMaxIter'):
                 if self.invMaxIter < t.invMaxIter:
                     self.invMaxIter = t.invMaxIter
+    def build(self, shape):
+        # Keras 3 requires explicit building the layers
+        for f in self.chain:
+            f.build(shape)
     def transform(self, x):
         y = x
         l = 0.0
@@ -393,6 +398,16 @@ class CoefficientNets(tl.Layer):
     def __init__(self, sequence, name='CoefficientNets', **kwargs):
         super().__init__(autocast=False, name=name, **kwargs)
         self.chain = sequence
+    def build(self, shape):
+        # Keras 3 requires explicit building the layers
+        for nn in self.chain:
+            # tf.print(f'CoefficientNets {type(nn)} shape {shape}')
+            nn.build(shape)
+            # if hasattr(nn, 'output_shape'):
+            #     tf.print(f'CoefficientNets {type(nn)} output_shape {nn.output_shape}')
+            # if hasattr(nn, 'compute_output_shape'):
+            #     tf.print(f'CoefficientNets {type(nn)} compute_output_shape {nn.compute_output_shape(shape)}')
+            shape = nn.compute_output_shape(shape)
     def call(self, x):
         y = x
         for nn in self.chain:
@@ -418,7 +433,13 @@ def __call__(self, x):
               for s in range(-self.symmetric_shifts,1+self.symmetric_shifts)
              for lat in xs]
         y = tf.stack(xs,axis=-2)
+        # tf.print(f'SymmetricShifts.__call__ return shape {y.shape}')
         return y
+    def build(self, shape):
+        # tf.print(f'SymmetricShifts.build shape {shape}')
+        pass
+    def compute_output_shape(self, shape):
+        return shape[:-1]+((1+2*self.symmetric_shifts)**4,shape[-1])
 
 class Normalization(tl.Layer):
     """
@@ -429,9 +450,12 @@ class Normalization(tl.Layer):
     def __init__(self, epsilon=1e-12, name='Normalization', **kwargs):
         super().__init__(autocast=False, name=name, **kwargs)
         self.epsilon = epsilon
-    def build(self, input_shape):
-        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:], initializer='ones', trainable=True)
-        self.beta = self.add_weight(name='beta', shape=input_shape[-1:], initializer='zeros', trainable=True)
+    def build(self, shape):
+        # tf.print(f'Normalization.build shape {shape}')
+        self.gamma = self.add_weight(name='gamma', shape=shape[-1:], initializer='ones', trainable=True)
+        self.beta = self.add_weight(name='beta', shape=shape[-1:], initializer='zeros', trainable=True)
+    def compute_output_shape(self, shape):
+        return shape
     def call(self, x):
         # assuming the batch_dim always comes before TZYXC.
         lat_rank = 5
@@ -449,7 +473,17 @@ class LocalSelfAttention(tl.MultiHeadAttention):
     """
     def __init__(self, name='LocalSelfAttention', **kwargs):
         super().__init__(autocast=False, name=name, **kwargs)
+    def build(self, shape):
+        # tf.print(f'LocalSelfAttention.build shape {shape}')
+        n = 1
+        for i in shape[:-2]:
+            n *= i
+        flat_shape = (n,)+tuple(shape[-2:])
+        super().build(flat_shape, flat_shape, flat_shape)
+    def compute_output_shape(self, shape):
+        return shape
     def call(self, x):
+        # tf.print(f'LocalSelfAttention.call shape {x.shape}')
         xflat = tf.reshape(x, shape=(-1,)+tuple(x.shape[-2:]))
         yflat = super().call(query=xflat, value=xflat, key=xflat)
         return tf.reshape(yflat, x.shape)
@@ -459,6 +493,15 @@ class Residue(tl.Layer):
     def __init__(self, procedure, name='Residue', **kwargs):
         super().__init__(autocast=False, name=name, **kwargs)
         self.procedure = procedure
+    def build(self, shape):
+        # Keras 3 requires explicit building the layers
+        if isinstance(self.procedure, (list,tuple)):
+            for p in self.procedure:
+                p.build(shape)
+        else:
+            self.procedure.build(shape)
+    def compute_output_shape(self, shape):
+        return shape
     def call(self, x):
         if isinstance(self.procedure, (list,tuple)):
             y = x
@@ -475,8 +518,11 @@ def __init__(self, inner_size, inner_activation='swish', name='LocalFeedForward'
         self.inner_size = inner_size
         self.inner_activation = inner_activation
     def build(self, shape):
+        # tf.print(f'LocalFeedForward.build shape {shape}')
        self.dense_in = tl.Dense(units=self.inner_size, activation=self.inner_activation)
         self.dense_out = tl.Dense(units=shape[-1], activation=None)
+    def compute_output_shape(self, shape):
+        return shape
     def call(self, x):
         y = self.dense_in(x)
         z = self.dense_out(y)
@@ -486,7 +532,17 @@ class FlattenSiteLocal:
     def __init__(self, input_local_rank):
         self.input_local_rank = input_local_rank
     def __call__(self, x):
-        return tf.reshape(x, tuple(x.shape[:-self.input_local_rank])+(-1,))
+        y = tf.reshape(x, tuple(x.shape[:-self.input_local_rank])+(-1,))
+        # tf.print(f'FlattenSiteLocal.__call__ return shape {y.shape}')
+        return y
+    def build(self, shape):
+        # tf.print(f'FlattenSiteLocal.build shape {shape}')
+        pass
+    def compute_output_shape(self, shape):
+        n = 1
+        for i in shape[-self.input_local_rank:]:
+            n *= i
+        return shape[:-self.input_local_rank]+(n,)
 
 if __name__ == '__main__':
     import sys, os
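
A minimal standalone sketch of the pattern every hunk above applies, assuming
Keras 3 semantics under TF 2.16 (where tf.keras is Keras 3): a container layer
builds its sublayers explicitly and propagates shapes through
compute_output_shape. SequentialCoeffs and the Dense sizes below are
illustrative stand-ins, not code from lib/transform.py.

import tensorflow as tf

tl = tf.keras.layers

class SequentialCoeffs(tl.Layer):
    # Hypothetical container layer; mirrors the pattern this patch adds to
    # TransformChain, CoefficientNets, and Residue.
    def __init__(self, sequence, **kwargs):
        super().__init__(**kwargs)
        self.chain = sequence
    def build(self, shape):
        # Keras 3 requires the sublayers to be built explicitly; thread
        # each layer's output shape into the next layer's build.
        for nn in self.chain:
            nn.build(shape)
            shape = nn.compute_output_shape(shape)
    def compute_output_shape(self, shape):
        for nn in self.chain:
            shape = nn.compute_output_shape(shape)
        return shape
    def call(self, x):
        for nn in self.chain:
            x = nn(x)
        return x

layer = SequentialCoeffs([tl.Dense(8), tl.Dense(2)])
layer.build((None, 306))                  # explicit build, as the patch does
print(layer(tf.zeros([3, 306])).shape)    # (3, 2)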