Sorry about the lack of documentation; that's something I was hoping to get to as part of overall improvements to the use of pretrained checkpoints. For now, the best example is in the GPT2 model, which has the writeCheckpoint(to:name:) method as a prototype for what we'd like to extend to other models.
The CheckpointWriter itself has a reasonably simple interface: it takes in a dictionary that maps String names to the Float Tensors corresponding to those names, and it can then write out a checkpoint from that dictionary.
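As a rough sketch of the shape of that interface (names in, tensors in, file out), here is a stand-in that runs without TensorFlow. Everything in it is an assumption for illustration: FakeTensor and StandInCheckpointWriter are hypothetical types, not the swift-models classes, the tensor names are made up, and the text format it writes is invented and is not the real checkpoint format.

```swift
import Foundation

// Hypothetical stand-in for Tensor<Float>, so the sketch runs without TensorFlow.
struct FakeTensor {
    var shape: [Int]
    var scalars: [Float]
}

// Hypothetical stand-in modeled on the description above: it is built from a
// name-to-tensor dictionary and can write itself out under a given name.
// The on-disk format here is invented and is NOT the TensorFlow checkpoint format.
struct StandInCheckpointWriter {
    let tensors: [String: FakeTensor]

    func write(to directory: URL, name: String) throws {
        try FileManager.default.createDirectory(
            at: directory, withIntermediateDirectories: true)
        // One line per tensor, sorted by name so the output is deterministic.
        let lines = tensors.keys.sorted().map { key -> String in
            let tensor = tensors[key]!
            let values = tensor.scalars.map { String($0) }.joined(separator: ",")
            return "\(key)\t\(tensor.shape)\t\(values)"
        }
        let file = directory.appendingPathComponent("\(name).txt")
        try lines.joined(separator: "\n").write(to: file, atomically: true, encoding: .utf8)
    }
}

// Illustrative names only; a real model supplies its own names and tensors.
let tensors: [String: FakeTensor] = [
    "model/wte": FakeTensor(shape: [2, 2], scalars: [0.1, 0.2, 0.3, 0.4]),
    "model/ln_f/b": FakeTensor(shape: [2], scalars: [0, 0]),
]
let directory = FileManager.default.temporaryDirectory
    .appendingPathComponent("checkpoint-sketch")
let writer = StandInCheckpointWriter(tensors: tensors)
try writer.write(to: directory, name: "step-100")
```

Inside a training loop the same pattern would apply: rebuild the dictionary from the current model every so many steps and write it out under a step-specific name.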
The process of getting names and tensors from within a model can vary, which is what we're trying to make more consistent and easier to use. The method utilized for the GPT2 model is contained here. The ExportableLayer protocol maps the names of properties within the model to their names within the checkpoint, and the recursivelyObtainTensors() function uses Mirror to iterate over the structure of the model and sublayers to apply this name mapping to the Tensors within. This generates the dictionary that is then passed to the CheckpointWriter.
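To make the Mirror-based traversal a bit more concrete, here is a minimal sketch of the idea using only the standard library. It is not the GPT2 code: FakeTensor, NameMapped, collectTensors, and the example layers are all hypothetical stand-ins, and the real ExportableLayer and recursivelyObtainTensors() may differ in their details.

```swift
// Hypothetical stand-in for Tensor<Float>, so the sketch runs without TensorFlow.
struct FakeTensor: Equatable {
    var scalars: [Float]
}

// Hypothetical protocol modeled on the description of ExportableLayer:
// it maps property names in the model to names in the checkpoint.
protocol NameMapped {
    var nameMappings: [String: String] { get }
}

struct Dense: NameMapped {
    var weight = FakeTensor(scalars: [1, 2])
    var bias = FakeTensor(scalars: [0])
    var nameMappings: [String: String] { ["weight": "w", "bias": "b"] }
}

struct TinyModel: NameMapped {
    var embedding = FakeTensor(scalars: [3])
    var layer1 = Dense()
    var nameMappings: [String: String] { ["embedding": "wte", "layer1": "h0"] }
}

// Walks the stored properties of `value` with Mirror, applies the name
// mapping at each level, and records every tensor under its mapped path.
func collectTensors(
    from value: Any, scope: String? = nil, into result: inout [String: FakeTensor]
) {
    // A tensor is a leaf: store it under the path built so far.
    if let tensor = value as? FakeTensor {
        if let scope = scope { result[scope] = tensor }
        return
    }
    // Only descend into values that declare a name mapping.
    guard let mapped = value as? NameMapped else { return }
    for child in Mirror(reflecting: value).children {
        // Properties without a mapping are skipped.
        guard let label = child.label,
              let mappedName = mapped.nameMappings[label] else { continue }
        let path = scope.map { "\($0)/\(mappedName)" } ?? mappedName
        collectTensors(from: child.value, scope: path, into: &result)
    }
}

var collected: [String: FakeTensor] = [:]
collectTensors(from: TinyModel(), into: &collected)
print(collected.keys.sorted())  // ["h0/b", "h0/w", "wte"]
```

The resulting dictionary has the same shape as the one the writer expects, with names on one side and tensors on the other.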
A similar system could be configured for other models, so we're looking at building a generalized implementation of something like this to make serialization easy. Sorry it's undocumented and a little barebones right now.
BradLarson changed the title from "Notebook" to "Is there any documentation on how to use the CheckpointWriter API?" on May 18, 2020
Is there any resource/notebook on using the CheckpointWriter API inside a training loop for any model in Swift?