Skip to content

Commit

Permalink
Support sampling on other state.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 27, 2023
1 parent e7e584a commit b375a8a
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions tetragono/tetragono/sampling_lattice/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def gradient_descent(
direct_sampling_cut_dimension=4,
sampling_configurations=np.zeros(0, dtype=np.int64),
sweep_hopping_hamiltonians=None,
sampling_state=None,
# About subspace
restrict_subspace=None,
# About gradient method
Expand Down Expand Up @@ -324,6 +325,9 @@ def gradient_descent(
classical_energy=classical_energy,
)

if sampling_state is None:
sampling_state = state

# Main loop
with SignalHandler(signal.SIGINT) as sigint_handler:
for grad_step in range(grad_total_step):
Expand All @@ -338,15 +342,16 @@ def gradient_descent(
"hopping_hamiltonians")(state)
else:
hopping_hamiltonians = None
sampling = SweepSampling(state, configuration_cut_dimension, restrict, hopping_hamiltonians)
sampling = SweepSampling(sampling_state, configuration_cut_dimension, restrict,
hopping_hamiltonians)
sampling_total_step = sampling_total_step
# Initial sweep configuration
sampling.configuration.import_configuration(sampling_configurations)
elif sampling_method == "ergodic":
sampling = ErgodicSampling(state, configuration_cut_dimension, restrict)
sampling = ErgodicSampling(sampling_state, configuration_cut_dimension, restrict)
sampling_total_step = sampling.total_step
elif sampling_method == "direct":
sampling = DirectSampling(state, configuration_cut_dimension, restrict,
sampling = DirectSampling(sampling_state, configuration_cut_dimension, restrict,
direct_sampling_cut_dimension)
sampling_total_step = sampling_total_step
else:
Expand All @@ -355,6 +360,11 @@ def gradient_descent(
for sampling_step in range(sampling_total_step):
if sampling_step % mpi_size == mpi_rank:
possibility, configuration = sampling()
if sampling_state is not state:
old_configuration = configuration
configuration = Configuration(state)
for [l1, l2, orbit], _ in state.physics_edges:
configuration[l1, l2, orbit] = old_configuration[l1, l2, orbit]
observer(possibility, configuration)
if need_energy_observer:
configuration_pool.append((possibility, configuration))
Expand Down

0 comments on commit b375a8a

Please sign in to comment.