From 8f7de0ecc7fb7cc4e643f4b44e860805ed29b63a Mon Sep 17 00:00:00 2001 From: gully Date: Fri, 6 Oct 2023 15:24:55 -0500 Subject: [PATCH] uncomitted changes --- src/blase/emulator.py | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/blase/emulator.py b/src/blase/emulator.py index 644c8df..46a5eda 100644 --- a/src/blase/emulator.py +++ b/src/blase/emulator.py @@ -9,6 +9,7 @@ from torch import nn import numpy as np from scipy.signal import find_peaks, peak_prominences, peak_widths +from scipy.special import voigt_profile import torch.optim as optim from tqdm import trange from torch.special import erfc @@ -329,15 +330,44 @@ def optimize(self): raise NotImplementedError - def animate(self, old_state_dict): + def animate(self, old_state_dict, size=1.5): """Animate the model from a previous state to the current state""" try: import manim except ImportError: print("Manim is required for the .animate() feature, but it is not installed. Please install manim and try again.") - return - - raise NotImplementedError + return None + + ## Manim requires a scene class, so we'll make one here + class ArgMinExample(manim.Scene, size=0.9): + def construct(self, size): + ax = manim.Axes( + x_range=[-10, 10], y_range=[0, 1, 0.25], axis_config={"include_tip": False} + ) + labels = ax.get_axis_labels(x_label=manim.Tex(r"$\lambda$"), + y_label=manim.Tex(r"$f(\lambda)$")) + + graph1 = ax.plot(lambda x: 1 - 2*voigt_profile(x, 1.0, 0), color=manim.MAROON) + graph2 = ax.plot(lambda x: 1 - 2*voigt_profile(x, 1.0, size), color=manim.MAROON) + + # Plot noisy data with manim: + x = np.linspace(-10, 10, 100) + y = 1 - 2*voigt_profile(x, 1.0, size) + y += np.random.normal(0, 0.03, y.shape) + coords = np.vstack((x,y)).T + + dots = manim.VGroup(*[manim.Dot().move_to(ax.c2p(coord[0],coord[1])) for coord in coords]) + self.add(dots) + + + self.add(ax, labels) + self.play(manim.Create(graph1)) + self.play(manim.Transform(graph1, graph2)) + self.play(manim.FadeOut(graph2)) + + scene = ArgMinExample(size=size) + scene.render(preview=False) # That's it! + return scene.renderer.file_writer.movie_file_path