generate_following.py
import os

from keras_gpt_2 import load_trained_model_from_checkpoint, get_bpe_from_files, generate

epoch_number = 2
already_trained = True

# Paths to the pretrained 117M model files and to the fine-tuning checkpoints.
model_folder = 'models/117M'
checkpoint_dir = './training_checkpoints'
config_path = os.path.join(model_folder, 'hparams.json')
encoder_path = os.path.join(model_folder, 'encoder.json')
vocab_path = os.path.join(model_folder, 'vocab.bpe')

# Load the fine-tuned checkpoint when one exists; otherwise fall back to the
# original pretrained weights. (The original script assigned checkpoint_path
# twice, so the fine-tuned path was silently overwritten.)
if already_trained:
    checkpoint_path = os.path.join(checkpoint_dir, f"ckpt_{epoch_number}")
else:
    checkpoint_path = os.path.join(model_folder, 'model.ckpt')

print('Load model from checkpoint...')
model = load_trained_model_from_checkpoint(config_path, checkpoint_path, already_trained=already_trained)

print('Load BPE from files...')
bpe = get_bpe_from_files(encoder_path, vocab_path)

print('Generate text...')
output = generate(model, bpe, ['From the day forth, my arm'], length=20, top_k=40)

# If you are using the 117M model and top_k equals 1, the result would be:
# "From the day forth, my arm was broken, and I was in a state of pain. I was in a state of pain,"
print(output[0])
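
# A minimal follow-up sketch, reusing the `model` and `bpe` objects loaded
# above. It assumes the same generate() helper accepts a list of prompts and
# returns one continuation per prompt, as in the call earlier in this script.
# With top_k=1 decoding is greedy, so repeated runs give the same output,
# matching the deterministic example quoted in the comment above. The second
# prompt below is only an illustrative placeholder.
prompts = [
    'From the day forth, my arm',
    'The meaning of life is',
]
greedy_outputs = generate(model, bpe, prompts, length=20, top_k=1)
for text in greedy_outputs:
    print('---')
    print(text)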