This is the official repository for our paper "MacLaSa: Multi-Aspect Controllable Text Generation via Efficient Sampling from Compact Latent Space".
An overview of MacLaSa. Left: Build the latent space for MacLaSa. We use the VAE framework with two additional losses to build a compact latent space. Top right: Formulate joint EBMs. We formulate latent-space EBMs over the latent representation and the attributes, which makes it easy to plug in multiple attribute-constraint classifiers. Bottom right: Sample with an ODE. We adopt a fast ODE-based sampler to draw samples efficiently from the EBMs, then feed the samples to the VAE decoder, which outputs sentences with the desired attributes across multiple aspects.
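The joint EBM idea in the overview can be illustrated with a toy sketch. It is not the repository's code and rests on stated assumptions: each attribute classifier contributes an energy equal to its negative log-probability of the target attribute, scaled by a hypothetical weight, and the prior term is a standard Gaussian. The names `attribute_energy`, `joint_energy`, `sentiment_clf`, and `topic_clf` are hypothetical.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical type: a classifier maps a latent vector to class probabilities.
Classifier = Callable[[Sequence[float]], Sequence[float]]


def attribute_energy(z: Sequence[float], classifiers: List[Classifier],
                     targets: Sequence[int], weights: Sequence[float]) -> float:
    """Assumed form: E(a | z) = -sum_i w_i * log p_i(a_i | z)."""
    energy = 0.0
    for clf, target, w in zip(classifiers, targets, weights):
        energy -= w * math.log(clf(z)[target])
    return energy


def joint_energy(z: Sequence[float], classifiers: List[Classifier],
                 targets: Sequence[int], weights: Sequence[float]) -> float:
    """Energy of (z, a) with an assumed N(0, I) prior: ||z||^2 / 2 + E(a | z)."""
    prior_energy = 0.5 * sum(x * x for x in z)
    return prior_energy + attribute_energy(z, classifiers, targets, weights)


def sentiment_clf(z: Sequence[float]) -> List[float]:
    """Toy 2-class classifier (hypothetical): logistic on the first coordinate."""
    p_pos = 1.0 / (1.0 + math.exp(-z[0]))
    return [1.0 - p_pos, p_pos]


def topic_clf(z: Sequence[float]) -> List[float]:
    """Toy 4-class classifier (hypothetical) that is uninformative."""
    return [0.25] * 4


print(joint_energy([0.0, 0.0], [sentiment_clf, topic_clf], [1, 2], [1.0, 1.0]))
```

Under this assumed form, adding another aspect only appends one more weighted term to the sum, which is what makes plugging in extra classifiers cheap.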
- Download our training data from this link. Unzip it and put it under the data directory.
- Download the discriminator checkpoints (Sentiment Discriminator, Topic Discriminator) used to evaluate multi-aspect control. Unzip them and put them under the model folder.
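Before training, a small check like the following can confirm that the unzipped files sit where the training commands expect them. This is a hypothetical helper, not part of the repository; the file list is taken from the paths used in the commands in this README, and the archive may contain additional files.

```python
from pathlib import Path
from typing import List

# File names taken from the training commands in this README.
EXPECTED_DATA_FILES = [
    "imdb/train.csv", "imdb/test.csv", "imdb/train1w.csv", "imdb/test1k.csv",
    "agnews/train.csv", "agnews/test.csv", "agnews/train1w.csv", "agnews/test1k.csv",
]


def missing_data_files(data_dir: str = "./data") -> List[str]:
    """Return the expected data files that are not present under data_dir."""
    root = Path(data_dir)
    return [name for name in EXPECTED_DATA_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_data_files()
    if missing:
        print("Missing data files:", *missing, sep="\n  ")
    else:
        print("All expected data files found.")
```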
- Train the VAE, which builds the compact latent space:
python main.py --checkpoint_dir ./model \
--train_senti_data_file ./data/imdb/train.csv \
--train_topic_data_file ./data/agnews/train.csv \
--eval_senti_data_file ./data/imdb/test.csv \
--eval_topic_data_file ./data/agnews/test.csv \
--latent_loss_weight 1.0 \
--gap_loss_weight 1.0 \
--learning_rate 1e-4 \
--num_train_epochs 50
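The flags --latent_loss_weight and --gap_loss_weight presumably scale the two additional losses that, per the overview, are added to the VAE objective. The sketch below shows one plausible way such a weighted objective is assembled. It is an assumption, not the repository's code: `latent_loss` and `gap_loss` are placeholders for whatever main.py computes, while the KL term is the standard closed form for a diagonal Gaussian posterior against a standard-normal prior.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) in closed form."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def total_vae_loss(recon_loss: float, mu: Sequence[float], logvar: Sequence[float],
                   latent_loss: float, gap_loss: float,
                   latent_loss_weight: float = 1.0,
                   gap_loss_weight: float = 1.0) -> float:
    """Hypothetical combined objective: ELBO terms plus two weighted extra losses.

    latent_loss and gap_loss are placeholders for the two additional losses;
    how main.py actually computes them is not shown here.
    """
    return (recon_loss
            + gaussian_kl(mu, logvar)
            + latent_loss_weight * latent_loss
            + gap_loss_weight * gap_loss)
```

With both weights at 1.0, as in the command above, the extra losses count as much as the reconstruction and KL terms; lowering either weight trades latent-space compactness for reconstruction quality.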
- Train the GAN, which models the prior distribution p(z) over the latent space:
python train_classifier.py --train_cls_gan gan \
--train_senti_data_file ./data/imdb/train1w.csv \
--train_topic_data_file ./data/agnews/train1w.csv \
--eval_senti_data_file ./data/imdb/test1k.csv \
--eval_topic_data_file ./data/agnews/test1k.csv \
--checkpoint_dir ./model/CKPT_NAME \
--output_dir ./model/CKPT_NAME \
--learning_rate 1e-5 \
--num_train_epochs 5
- Train the attribute classifiers, which guide multi-aspect control during sampling:
(1) Sentiment classifier:
python train_classifier.py --train_cls_gan cls \
--save_step 1 \
--n_classes 2 \
--train_senti_data_file ./data/imdb/train1w.csv \
--eval_senti_data_file ./data/imdb/test1k.csv \
--checkpoint_dir ./model/CKPT_NAME \
--output_dir ./model/CKPT_NAME \
--learning_rate 1e-5
(2) Topic classifier:
python train_classifier.py --train_cls_gan cls \
--save_step 2 \
--n_classes 4 \
--train_topic_data_file ./data/agnews/train1w.csv \
--eval_topic_data_file ./data/agnews/test1k.csv \
--checkpoint_dir ./model/CKPT_NAME \
--output_dir ./model/CKPT_NAME \
--learning_rate 1e-4
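The --n_classes values match the label sets of the two datasets: IMDb reviews are labeled positive or negative, and AG News has four topics (World, Sports, Business, Sci/Tech). Pairing one sentiment with one topic gives 2 × 4 = 8 possible multi-aspect targets. The label index order below is an assumption; check the CSV files for the actual encoding.

```python
from itertools import product
from typing import List, Tuple

# Assumed label order; verify against the label columns in the CSV files.
SENTIMENT_LABELS = ["negative", "positive"]                 # --n_classes 2 (IMDb)
TOPIC_LABELS = ["World", "Sports", "Business", "Sci/Tech"]  # --n_classes 4 (AG News)


def multi_aspect_targets() -> List[Tuple[int, int, str]]:
    """All (sentiment_id, topic_id, description) combinations."""
    return [
        (s, t, f"{SENTIMENT_LABELS[s]} + {TOPIC_LABELS[t]}")
        for s, t in product(range(len(SENTIMENT_LABELS)), range(len(TOPIC_LABELS)))
    ]


for s, t, desc in multi_aspect_targets():
    print(s, t, desc)
```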
- Generation. Run conditional generation to obtain sentences that carry the desired combination of attributes (multi-aspect sentences):
python generate.py --checkpoint_dir ./model/CKPT_NAME --output_dir ./outputs/
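The overview mentions a fast ODE-based sampler. The following is a minimal sketch, not the repository's implementation, of explicit Euler integration of a variance-preserving (VP) probability-flow ODE under a simplifying assumption: the density at every time t is taken to be a standard normal reweighted by exp(-E(a | z)). Under that assumption the score is -z - ∇E(a | z), so the drift -½β(t)[z + score] reduces to ½β(t)∇E(a | z). The linear β schedule (0.1 to 20), the step count, and the function names are hypothetical choices.

```python
import random
from typing import Callable, List

Vector = List[float]


def beta(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """Linear VP noise schedule (hypothetical values)."""
    return beta_min + t * (beta_max - beta_min)


def ode_sample(grad_energy: Callable[[Vector], Vector], z_init: Vector,
               num_steps: int = 100, t_end: float = 1e-3) -> Vector:
    """Euler-integrate dz/dt = 0.5 * beta(t) * grad_E(z) backward from t=1 to t_end.

    Moving backward in time means each step subtracts the drift, so z
    descends the attribute energy.
    """
    h = (1.0 - t_end) / num_steps
    z = list(z_init)
    t = 1.0
    for _ in range(num_steps):
        g = grad_energy(z)
        scale = 0.5 * beta(t) * h
        z = [zi - scale * gi for zi, gi in zip(z, g)]
        t -= h
    return z


def toy_grad(z: Vector, mu: Vector = [1.0, -2.0]) -> Vector:
    """Gradient of a toy energy E(z) = ||z - mu||^2 / 2 (hypothetical)."""
    return [zi - mi for zi, mi in zip(z, mu)]


if __name__ == "__main__":
    z0 = [random.gauss(0.0, 1.0) for _ in range(2)]  # start from the N(0, I) prior
    print(ode_sample(toy_grad, z0))  # should end close to mu = [1.0, -2.0]
```

In the full pipeline the resulting latent vectors would be passed to the VAE decoder; a deterministic ODE needs no per-step noise, which is one reason such samplers can use few steps.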
- Evaluation. After generation, run automatic evaluation on the generated texts, using the discriminator checkpoints downloaded above:
python evaluate.py --output_save_dir ./outputs/OUTPUT_NAME/ \
--sentiment_discriminator_save_dir ./model/ \
--topic_discriminator_save_dir ./model/
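The discriminators judge whether each generated sentence carries its target attributes. As an illustration only (the exact metrics and output format of evaluate.py are not described here), per-aspect accuracy and a joint accuracy, where both aspects must be correct, could be computed from predicted and target labels as follows. The function name is hypothetical.

```python
from typing import Dict, Sequence


def attribute_accuracy(pred_senti: Sequence[int], pred_topic: Sequence[int],
                       tgt_senti: Sequence[int], tgt_topic: Sequence[int]) -> Dict[str, float]:
    """Per-aspect and joint accuracy of discriminator predictions vs. targets."""
    n = len(tgt_senti)
    if n == 0 or not (len(pred_senti) == len(pred_topic) == len(tgt_topic) == n):
        raise ValueError("all label sequences must be non-empty and equally long")
    senti_ok = [p == t for p, t in zip(pred_senti, tgt_senti)]
    topic_ok = [p == t for p, t in zip(pred_topic, tgt_topic)]
    return {
        "sentiment_acc": sum(senti_ok) / n,
        "topic_acc": sum(topic_ok) / n,
        # A sentence counts only if both aspects match their targets.
        "joint_acc": sum(s and t for s, t in zip(senti_ok, topic_ok)) / n,
    }
```

Joint accuracy is never higher than either per-aspect accuracy, so it is the stricter measure of multi-aspect control.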
If you find the code helpful, please cite our paper:
@article{ding2023maclasa,
title={MacLaSa: Multi-Aspect Controllable Text Generation via Efficient Sampling from Compact Latent Space},
author={Ding, Hanxing and Pang, Liang and Wei, Zihao and Shen, Huawei and Cheng, Xueqi and Chua, Tat-Seng},
journal={arXiv preprint arXiv:2305.12785},
year={2023}
}