diff --git a/RecommenderSystems/pnn/README.md b/RecommenderSystems/pnn/README.md new file mode 100644 index 000000000..a6d77f113 --- /dev/null +++ b/RecommenderSystems/pnn/README.md @@ -0,0 +1,169 @@ +# PNN +[PNN](https://arxiv.org/pdf/1611.00144.pdf) is a Neural Network with a product layer to capture interactive patterns between interfield categories, and further fully connected layers to explore high-order feature interactions for CTR prediction. Its model structure is as follows. Based on this structure, this project uses OneFlow distributed deep learning framework to realize training the model in graph mode on the Criteo data set. +

+ image +

+ +## Directory description + +```txt +. +├── pnn_train_eval.py # OneFlow PNN train/val/test scripts with OneEmbedding module +├── README.md # Documentation +├── tools +│ ├── pnn_parquet.scala # Read Criteo Kaggle data and export it as parquet data format +│ └── launch_spark.sh # Spark launching shell script +│ └── split_criteo_kaggle.py # Split criteo kaggle dataset to train\val\test set +├── train_pnn.sh # PNN training shell script +``` + +## Arguments description +| Argument Name | Argument Explanation | Default Value | +| -------------------------- | ------------------------------------------------------------ | ------------------------ | +| data_dir | the data file directory | *Required Argument* | +| num_train_samples | the number of train samples | *Required Argument* | +| num_val_samples | the number of validation samples | *Required Argument* | +| num_test_samples | the number of test samples | *Required Argument* | +| model_load_dir | model loading directory | None | +| model_save_dir | model saving directory | None | +| save_initial_model | save initial model parameters or not | False | +| save_model_after_each_eval | save model after each eval or not | False | +| disable_fusedmlp | disable fused MLP or not | True | +| embedding_vec_size | embedding vector size | 16 | +| dnn | dnn hidden units number | 1000,1000 | +| net_dropout | number of minibatch training interations | 0.2 | +| embedding_vec_size | embedding vector size | 16 | +| embedding_regularizer | embedding layer regularization rate | 1.0e-05 | +| net_regularizer | net regularization rate | 0.0 | +| max_gradient_norm | max norm of the gradients | 10.0 | +| learning_rate | initial learning rate | 0.001 | +| batch_size | training/evaluation batch size | 10000 | +| train_batches | the maximum number of training batches | 35000 | +| loss_print_interval | interval of printing loss | 100 | +| patience | Number of epochs with no improvement after which learning rate will be reduced | 2 | +| min_delta | threshold for measuring the new optimum, to only focus on significant changes | 1.0e-6 | +| table_size_array | embedding table size array for sparse fields | *Required Argument* | +| persistent_path | path for persistent kv store of embedding | *Required Argument* | +| store_type | OneEmbeddig persistent kv store type: `device_mem`, `cached_host_mem` or `cached_ssd` | `cached_host_mem` | +| cache_memory_budget_mb | size of cache memory budget on each device in megabytes when `store_type` is `cached_host_mem` or `cached_ssd` | 1024 | +| amp | enable Automatic Mixed Precision(AMP) training or not | False | +| loss_scale_policy | loss scale policy for AMP training: `static` or `dynamic` | `static` | +| use_inner | use inner product or not | True | +| use_outter | use outter product or not | False | +| disable_early_stop | disable early stop or not | False | + +#### Early Stop Schema + +The model is evaluated at the end of every epoch. At the end of each epoch, if the early stopping criterion is met, the training process will be stopped. + +The monitor used for the early stop is `val_auc - val_log_loss`. The mode of the early stop is `max`. You could tune `patience` and `min_delta` as needed. + +If you want to disable early stopping, simply add `--disable_early_stop` in the [train_pnn.sh](https://github.com/Oneflow-Inc/models/blob/dev_pnn_pr/RecommenderSystems/pnn/train_pnn.sh). + +## Getting Started + +A hands-on guide to train a PNN model. + +### Environment + +1. Install OneFlow by following the steps in [OneFlow Installation Guide](https://github.com/Oneflow-Inc/oneflow#install-oneflow) or use the command line below. + + ```shell + python3 -m pip install --pre oneflow -f https://staging.oneflow.info/branch/master/cu102 + ``` + +2. Install all other dependencies listed below. + + ```json + psutil + petastorm + pandas + sklearn + ``` + +### Dataset + +**Note**: + +According to [the PNN paper](https://arxiv.org/pdf/1611.00144.pdf), we treat both categorical and continuous features as sparse features. + +> χ may include categorical fields (e.g., gender, location) and continuous fields (e.g., age). Each categorical field is represented as a vec- tor of one-hot encoding, and each continuous field is repre- sented as the value itself, or a vector of one-hot encoding after discretization. + +1. Download the [Criteo Kaggle dataset](https://www.kaggle.com/c/criteo-display-ad-challenge) and then split it using [split_criteo_kaggle.py](https://github.com/Oneflow-Inc/models/blob/dev_pnn_pr/RecommenderSystems/pnn/tools/split_criteo_kaggle.py). + + Note: Same as [the PNN_Criteo_x4_001 experiment](https://github.com/openbenchmark/BARS/tree/master/ctr_prediction/benchmarks/PNN/PNN_criteo_x4_001) in FuxiCTR, only train.txt is used. Also, the dataset is randomly spllitted into 8:1:1 as training set, validation set and test set. The dataset is splitted using StratifiedKFold in sklearn. + + ```shell + python3 split_criteo_kaggle.py --input_dir=/path/to/your/criteo_kaggle --output_dir=/path/to/your/output/dir + ``` + +2. Download spark from https://spark.apache.org/downloads.html and then uncompress the tar file into the directory where you want to install Spark. Ensure the `SPARK_HOME` environment variable points to the directory where the spark is. + +3. launch a spark shell using [launch_spark.sh](https://github.com/Oneflow-Inc/models/blob/dev_pnn_pr/RecommenderSystems/pnn/tools/launch_spark.sh). + + - Modify the SPARK_LOCAL_DIRS as needed + + ```shell + export SPARK_LOCAL_DIRS=/path/to/your/spark/ + ``` + + - Run `bash launch_spark.sh` + +4. load [pnn_parquet.scala](https://github.com/Oneflow-Inc/models/blob/dev_pnn_pr/RecommenderSystems/pnn/tools/pnn_parquet.scala) to your spark shell by `:load pnn_parquet.scala`. + +5. call the `makePNNDataset(srcDir: String, dstDir:String)` function to generate the dataset. + + ```shell + makePNNDataset("/path/to/your/src_dir", "/path/to/your/dst_dir") + ``` + + After generating parquet dataset, dataset information will also be printed. It contains the information about the number of samples and table size array, which is needed when training. + + ```txt + train samples = 36672493 + validation samples = 4584062 + test samples = 4584062 + table size array: + 649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376 + 1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572 + ``` + + +### Start Training by Oneflow + +1. Modify the [train_pnn.sh](https://github.com/Oneflow-Inc/models/blob/dev_pnn_pr/RecommenderSystems/pnn/train_pnn.sh) as needed. + + ```shell + #!/bin/bash + DEVICE_NUM_PER_NODE=1 + DATA_DIR=/path/to/pnn_parquet + PERSISTENT_PATH=/path/to/persistent + MODEL_SAVE_DIR=/path/to/model/save/dir + + python3 -m oneflow.distributed.launch \ + --nproc_per_node $DEVICE_NUM_PER_NODE \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 127.0.0.1 \ + pnn_train_eval.py \ + --data_dir $DATA_DIR \ + --persistent_path $PERSISTENT_PATH \ + --table_size_array "43, 98, 121, 41, 219, 112, 79, 68, 91, 5, 26, 36, 70, 1447, 554, 157461, 117683, 305, 17, 11878, 629, 4, 39504, 5128, 156729, 3175, 27, 11070, 149083, 11, 4542, 1996, 4, 154737, 17, 16, 52989, 81, 40882" \ + --store_type 'cached_host_mem' \ + --cache_memory_budget_mb 1024 \ + --batch_size 10000 \ + --train_batches 75000 \ + --loss_print_interval 100 \ + --dnn "1000,1000" \ + --net_dropout 0.2 \ + --learning_rate 0.001 \ + --embedding_vec_size 16 \ + --num_train_samples 36672493 \ + --num_val_samples 4584062 \ + --num_test_samples 4584062 \ + --model_save_dir $MODEL_SAVE_DIR \ + --save_best_model + ``` + +2. train a PNN model by `bash train_pnn.sh`. + diff --git a/RecommenderSystems/pnn/pnn_train_eval.py b/RecommenderSystems/pnn/pnn_train_eval.py new file mode 100644 index 000000000..86084d29d --- /dev/null +++ b/RecommenderSystems/pnn/pnn_train_eval.py @@ -0,0 +1,730 @@ +import argparse +import os +import sys +import glob +import time +import math +import numpy as np +import psutil +import oneflow as flow +import oneflow.nn as nn +from petastorm.reader import make_batch_reader + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + + +def get_args(print_args=True): + def int_list(x): + return list(map(int, x.split(","))) + + def str_list(x): + return list(map(str, x.split(","))) + + parser = argparse.ArgumentParser() + + parser.add_argument("--data_dir", type=str, required=True) + parser.add_argument( + "--num_train_samples", type=int, required=True, help="the number of training samples", + ) + parser.add_argument( + "--num_val_samples", type=int, required=True, help="the number of validation samples", + ) + parser.add_argument( + "--num_test_samples", type=int, required=True, help="the number of test samples" + ) + parser.add_argument("--model_load_dir", type=str, default=None, help="model loading directory") + parser.add_argument("--model_save_dir", type=str, default=None, help="model saving directory") + parser.add_argument( + "--save_initial_model", action="store_true", help="save initial model parameters or not.", + ) + parser.add_argument( + "--save_model_after_each_eval", action="store_true", help="save model after each eval.", + ) + + parser.add_argument("--embedding_vec_size", type=int, default=16) + parser.add_argument("--dnn", type=int_list, default="1000,1000,1000,1000,1000") + parser.add_argument("--net_dropout", type=float, default=0.2) + + parser.add_argument("--lr_factor", type=float, default=0.1) + parser.add_argument("--min_lr", type=float, default=1.0e-6) + parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate") + + parser.add_argument( + "--batch_size", type=int, default=10000, help="training/evaluation batch size" + ) + parser.add_argument( + "--train_batches", type=int, default=75000, help="the maximum number of training batches", + ) + parser.add_argument("--loss_print_interval", type=int, default=100, help="") + + parser.add_argument( + "--patience", + type=int, + default=2, + help="number of epochs with no improvement after which learning rate will be reduced", + ) + parser.add_argument( + "--min_delta", + type=float, + default=1.0e-6, + help="threshold for measuring the new optimum, to only focus on significant changes", + ) + + parser.add_argument( + "--table_size_array", + default="649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376,1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572", + type=int_list, + help="Embedding table size array for sparse fields", + required=False, + ) + parser.add_argument( + "--persistent_path", type=str, required=True, help="path for persistent kv store", + ) + parser.add_argument( + "--store_type", + type=str, + default="cached_host_mem", + help="OneEmbeddig persistent kv store type: device_mem, cached_host_mem, cached_ssd", + ) + parser.add_argument( + "--cache_memory_budget_mb", + type=int, + default=1024, + help="size of cache memory budget on each device in megabytes when store_type is cached_host_mem or cached_ssd", + ) + + parser.add_argument( + "--amp", action="store_true", help="enable Automatic Mixed Precision(AMP) training or not", + ) + parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic") + + parser.add_argument( + "--disable_early_stop", action="store_true", help="enable early stop or not" + ) + parser.add_argument("--save_best_model", action="store_true", help="save best model or not") + parser.add_argument("--use_inner", type=bool, default=True, help="Use inner_product_layer") + parser.add_argument("--use_outter", type=bool, default=False, help="Use outter_product_layer") + + args = parser.parse_args() + + if print_args and flow.env.get_rank() == 0: + _print_args(args) + return args + + +def _print_args(args): + """Print arguments.""" + print("------------------------ arguments ------------------------", flush=True) + str_list = [] + for arg in vars(args): + dots = "." * (48 - len(arg)) + str_list.append(" {} {} {}".format(arg, dots, getattr(args, arg))) + for arg in sorted(str_list, key=lambda x: x.lower()): + print(arg, flush=True) + print("-------------------- end of arguments ---------------------", flush=True) + + +num_dense_fields = 13 +num_sparse_fields = 26 + + +class PNNDataReader(object): + """A context manager that manages the creation and termination of a + :class:`petastorm.Reader`. + """ + + def __init__( + self, + parquet_file_url_list, + batch_size, + num_epochs=1, + shuffle_row_groups=True, + shard_seed=2020, + shard_count=1, + cur_shard=0, + ): + self.parquet_file_url_list = parquet_file_url_list + self.batch_size = batch_size + self.num_epochs = num_epochs + self.shuffle_row_groups = shuffle_row_groups + self.shard_seed = shard_seed + self.shard_count = shard_count + self.cur_shard = cur_shard + + fields = ["Label"] + fields += [f"I{i+1}" for i in range(num_dense_fields)] + fields += [f"C{i+1}" for i in range(num_sparse_fields)] + self.fields = fields + self.num_fields = len(fields) + + def __enter__(self): + self.reader = make_batch_reader( + self.parquet_file_url_list, + workers_count=2, + shuffle_row_groups=self.shuffle_row_groups, + num_epochs=self.num_epochs, + shard_seed=self.shard_seed, + shard_count=self.shard_count, + cur_shard=self.cur_shard, + ) + self.loader = self.get_batches(self.reader) + return self.loader + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.reader.stop() + self.reader.join() + + def get_batches(self, reader, batch_size=None): + if batch_size is None: + batch_size = self.batch_size + + tail = None + + for rg in reader: + rgdict = rg._asdict() + rglist = [rgdict[field] for field in self.fields] + pos = 0 + if tail is not None: + pos = batch_size - len(tail[0]) + tail = list( + [ + np.concatenate((tail[i], rglist[i][0 : (batch_size - len(tail[i]))])) + for i in range(self.num_fields) + ] + ) + if len(tail[0]) == batch_size: + label = tail[0] + features = tail[1:] + tail = None + yield label, np.stack(features, axis=-1) + else: + pos = 0 + continue + + while (pos + batch_size) <= len(rglist[0]): + label = rglist[0][pos : pos + batch_size] + features = [rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields)] + pos += batch_size + yield label, np.stack(features, axis=-1) + + if pos != len(rglist[0]): + tail = [rglist[i][pos:] for i in range(self.num_fields)] + + +def make_criteo_dataloader(data_path, batch_size, shuffle=True): + """Make a Criteo Parquet DataLoader. + :return: a context manager when exit the returned context manager, the reader will be closed. + """ + files = ["file://" + name for name in glob.glob(f"{data_path}/*.parquet")] + files.sort() + + world_size = flow.env.get_world_size() + batch_size_per_proc = batch_size // world_size + + return PNNDataReader( + files, + batch_size_per_proc, + None, # TODO: iterate over all eval dataset + shuffle_row_groups=shuffle, + shard_seed=2020, + shard_count=world_size, + cur_shard=flow.env.get_rank(), + ) + + +class OneEmbedding(nn.Module): + def __init__( + self, + table_name, + embedding_vec_size, + persistent_path, + table_size_array, + store_type, + cache_memory_budget_mb, + size_factor, + ): + assert table_size_array is not None + vocab_size = sum(table_size_array) + tables = [ + flow.one_embedding.make_table( + flow.one_embedding.make_normal_initializer(mean=0.0, std=1e-4) + ) + for _ in range(len(table_size_array)) + ] + + if store_type == "device_mem": + store_options = flow.one_embedding.make_device_mem_store_options( + persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor, + ) + elif store_type == "cached_host_mem": + assert cache_memory_budget_mb > 0 + store_options = flow.one_embedding.make_cached_host_mem_store_options( + cache_budget_mb=cache_memory_budget_mb, + persistent_path=persistent_path, + capacity=vocab_size, + size_factor=size_factor, + ) + elif store_type == "cached_ssd": + assert cache_memory_budget_mb > 0 + store_options = flow.one_embedding.make_cached_ssd_store_options( + cache_budget_mb=cache_memory_budget_mb, + persistent_path=persistent_path, + capacity=vocab_size, + size_factor=size_factor, + ) + else: + raise NotImplementedError("not support", store_type) + + super(OneEmbedding, self).__init__() + self.one_embedding = flow.one_embedding.MultiTableEmbedding( + name=table_name, + embedding_dim=embedding_vec_size, + dtype=flow.float, + key_type=flow.int64, + tables=tables, + store_options=store_options, + ) + + def forward(self, ids): + return self.one_embedding.forward(ids) + + +class DNN(nn.Module): + def __init__( + self, + in_features: int, + hidden_units, + out_features, + skip_final_activation=False, + dropout=0.0, + ) -> None: + super(DNN, self).__init__() + + denses = [] + dropout_rates = [dropout] * len(hidden_units) + [0.0] + use_relu = [True] * len(hidden_units) + [not skip_final_activation] + hidden_units = [in_features] + hidden_units + [out_features] + for idx in range(len(hidden_units) - 1): + denses.append(nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True)) + if use_relu[idx]: + denses.append(nn.ReLU()) + if dropout_rates[idx] > 0: + denses.append(nn.Dropout(p=dropout_rates[idx])) + self.linear_layers = nn.Sequential(*denses) + + for name, param in self.linear_layers.named_parameters(): + if "weight" in name: + nn.init.xavier_normal_(param) + elif "bias" in name: + param.data.fill_(0.0) + + def forward(self, x: flow.Tensor) -> flow.Tensor: + return self.linear_layers(x) + + +class Interaction(nn.Module): + def __init__( + self, num_embedding_fields, interaction_itself=False, interaction_padding=False, + ): + super(Interaction, self).__init__() + self.interaction_itself = interaction_itself + n_cols = num_embedding_fields + 2 if self.interaction_itself else num_embedding_fields + 1 + output_size = sum(range(n_cols)) + self.output_size = ((output_size + 8 - 1) // 8 * 8) if interaction_padding else output_size + self.output_padding = self.output_size - output_size + + def forward(self, x: flow.Tensor) -> flow.Tensor: + return flow._C.fused_dot_feature_interaction( + [x], + output_concat=None, + self_interaction=self.interaction_itself, + output_padding=self.output_padding, + ) + + +class OutterProductLayer(nn.Module): + def __init__(self, field_size, embedding_size): + super(OutterProductLayer, self).__init__() + num_inputs = field_size + num_pairs = int(num_inputs * (num_inputs - 1) / 2) + embed_size = embedding_size + self.kernel = nn.Parameter(flow.Tensor(embed_size, num_pairs, embed_size)) + nn.init.xavier_uniform_(self.kernel) + + def forward(self, inputs): + embed_list = [field_emb for field_emb in inputs] + row = [] + col = [] + num_inputs = inputs.shape[0] + for i in range(num_inputs - 1): + for j in range(i + 1, num_inputs): + row.append(i) + col.append(j) + p = flow.cat([embed_list[idx] for idx in row], dim=1) # batch num_pairs k + q = flow.cat([embed_list[idx] for idx in col], dim=1) + res = flow.mul(p.unsqueeze(dim=1), self.kernel) + res = flow.sum(res, dim=-1) + res = flow.transpose(res, 2, 1) + res = flow.mul(res, q) + res = flow.sum(res, dim=-1) + return res + + +class PNNModule(nn.Module): + def __init__( + self, + embedding_vec_size=128, + dnn=[1024, 1024, 512, 256], + persistent_path=None, + table_size_array=None, + one_embedding_store_type="cached_host_mem", + cache_memory_budget_mb=8192, + dropout=0.2, + use_inner=True, + use_outter=False, + ): + super(PNNModule, self).__init__() + self.embedding_vec_size = embedding_vec_size + self.embedding_layer = OneEmbedding( + table_name="sparse_embedding", + embedding_vec_size=embedding_vec_size, + persistent_path=persistent_path, + table_size_array=table_size_array, + store_type=one_embedding_store_type, + cache_memory_budget_mb=cache_memory_budget_mb, + size_factor=3, + ) + self.use_inner = use_inner + self.use_outter = use_outter + self.fields = num_sparse_fields + num_dense_fields + self.input_dim = embedding_vec_size * self.fields + if self.use_inner: + self.input_dim += sum(range(self.fields)) + self.inner_product_layer = Interaction(self.fields) + if self.use_outter: + self.input_dim += sum(range(self.fields)) + self.outter_product_layer = OutterProductLayer(self.fields, embedding_vec_size) + self.dnn_layer = DNN( + in_features=self.input_dim, + hidden_units=dnn, + out_features=1, + skip_final_activation=True, + dropout=dropout, + ) + + def forward(self, inputs) -> flow.Tensor: + E = self.embedding_layer(inputs) + if self.use_inner: + I = self.inner_product_layer(E) + if self.use_outter: + O = self.outter_product_layer(E.reshape(self.fields, -1, 1, self.embedding_vec_size)) + + if self.use_inner and self.use_outter: + dense_input = flow.cat([E.flatten(start_dim=1), I, O], dim=1) + elif self.use_inner: + dense_input = flow.cat([E.flatten(start_dim=1), I], dim=1) + elif self.use_outter: + dense_input = flow.cat([E.flatten(start_dim=1), O], dim=1) + else: + dense_input = flow.cat([E.flatten(start_dim=1)], dim=1) + dnn_pred = self.dnn_layer(dense_input) + return dnn_pred + + +def make_pnn_module(args): + model = PNNModule( + embedding_vec_size=args.embedding_vec_size, + dnn=args.dnn, + persistent_path=args.persistent_path, + table_size_array=args.table_size_array, + one_embedding_store_type=args.store_type, + cache_memory_budget_mb=args.cache_memory_budget_mb, + dropout=args.net_dropout, + use_inner=args.use_inner, + use_outter=args.use_outter, + ) + return model + + +class PNNValGraph(flow.nn.Graph): + def __init__(self, pnn_module, amp=False): + super(PNNValGraph, self).__init__() + self.module = pnn_module + if amp: + self.config.enable_amp(True) + + def build(self, features): + predicts = self.module(features.to("cuda")) + return predicts.sigmoid() + + +class PNNTrainGraph(flow.nn.Graph): + def __init__( + self, pnn_module, loss, optimizer, grad_scaler=None, amp=False, lr_scheduler=None, + ): + super(PNNTrainGraph, self).__init__() + self.module = pnn_module + self.loss = loss + # self.max_norm = max_norm + self.add_optimizer(optimizer, lr_sch=lr_scheduler) + self.config.allow_fuse_model_update_ops(True) + self.config.allow_fuse_add_to_output(True) + self.config.allow_fuse_cast_scale(True) + if amp: + self.config.enable_amp(True) + self.set_grad_scaler(grad_scaler) + + def build(self, labels, features): + logits = self.module(features.to("cuda")) + loss = self.loss(logits, labels.to("cuda")) + loss.backward() + return loss.to("cpu") + + +def make_lr_scheduler(args, optimizer): + batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size) + milestones = [ + batches_per_epoch * (i + 1) + for i in range(math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor))) + ] + multistep_lr = flow.optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, gamma=args.lr_factor, milestones=milestones, + ) + return multistep_lr + + +def get_metrics(logs): + kv = {"auc": 1, "logloss": -1} + monitor_value = 0 + for k, v in kv.items(): + monitor_value += logs.get(k, 0) * v + return monitor_value + + +def early_stop(epoch, monitor_value, best_metric, stopping_steps, patience=2, min_delta=1e-6): + rank = flow.env.get_rank() + stop_training = False + save_best = False + if monitor_value < best_metric + min_delta: + stopping_steps += 1 + if rank == 0: + print("Monitor(max) STOP: {:.6f}!".format(monitor_value)) + else: + stopping_steps = 0 + best_metric = monitor_value + save_best = True + if stopping_steps >= patience: + stop_training = True + if rank == 0: + print(f"Early stopping at epoch={epoch}!") + return stop_training, best_metric, stopping_steps, save_best + + +def train(args): + rank = flow.env.get_rank() + + pnn_module = make_pnn_module(args) + pnn_module.to_global(flow.env.all_device_placement("cuda"), flow.sbp.broadcast) + + def load_model(dir): + if rank == 0: + print(f"Loading model from {dir}") + if os.path.exists(dir): + state_dict = flow.load(dir, global_src_rank=0) + pnn_module.load_state_dict(state_dict, strict=False) + else: + if rank == 0: + print(f"Loading model from {dir} failed: invalid path") + + if args.model_load_dir: + load_model(args.model_load_dir) + + def save_model(subdir): + if not args.model_save_dir: + return + save_path = os.path.join(args.model_save_dir, subdir) + if rank == 0: + print(f"Saving model to {save_path}") + state_dict = pnn_module.state_dict() + flow.save(state_dict, save_path, global_dst_rank=0) + + if args.save_initial_model: + save_model("initial_checkpoint") + + opt = flow.optim.Adam(pnn_module.parameters(), lr=args.learning_rate) + lr_scheduler = make_lr_scheduler(args, opt) + loss = flow.nn.BCEWithLogitsLoss(reduction="mean").to("cuda") + + if args.loss_scale_policy == "static": + grad_scaler = flow.amp.StaticGradScaler(1024) + else: + grad_scaler = flow.amp.GradScaler( + init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, + ) + + eval_graph = PNNValGraph(pnn_module, args.amp) + train_graph = PNNTrainGraph( + pnn_module, loss, opt, grad_scaler, args.amp, lr_scheduler=lr_scheduler + ) + + batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size) + + # will be updated by rank 0 only + best_metric = -np.inf + stopping_steps = 0 + save_best = False + stop_training = False + + cached_eval_batches = prefetch_eval_batches( + f"{args.data_dir}/val", args.batch_size, math.ceil(args.num_val_samples / args.batch_size), + ) + + pnn_module.train() + epoch = 0 + with make_criteo_dataloader(f"{args.data_dir}/train", args.batch_size) as loader: + step, last_step, last_time = -1, 0, time.time() + for step in range(1, args.train_batches + 1): + labels, features = batch_to_global(*next(loader)) + loss = train_graph(labels, features) + if step % args.loss_print_interval == 0: + loss = loss.numpy() + if rank == 0: + latency = (time.time() - last_time) / (step - last_step) + throughput = args.batch_size / latency + last_step, last_time = step, time.time() + strtime = time.strftime("%Y-%m-%d %H:%M:%S") + print( + f"Rank[{rank}], Step {step}, Loss {loss:0.4f}, " + + f"Latency {(latency * 1000):0.3f} ms, Throughput {throughput:0.1f}, {strtime}" + ) + + if step % batches_per_epoch == 0: + epoch += 1 + auc, logloss = eval( + args, + eval_graph, + tag="val", + cur_step=step, + epoch=epoch, + cached_eval_batches=cached_eval_batches, + ) + if args.save_model_after_each_eval: + save_model(f"step_{step}_val_auc_{auc:0.5f}") + + monitor_value = get_metrics(logs={"auc": auc, "logloss": logloss}) + + stop_training, best_metric, stopping_steps, save_best = early_stop( + epoch, + monitor_value, + best_metric=best_metric, + stopping_steps=stopping_steps, + patience=args.patience, + min_delta=args.min_delta, + ) + + if args.save_best_model and save_best: + if rank == 0: + print(f"Save best model: monitor(max): {best_metric:.6f}") + save_model("best_checkpoint") + if not args.disable_early_stop and stop_training: + break + + pnn_module.train() + last_time = time.time() + + load_model(f"{args.model_save_dir}/best_checkpoint") + if rank == 0: + print("================ Test Evaluation ================") + eval(args, eval_graph, tag="test", cur_step=step, epoch=epoch) + + if step % batches_per_epoch != 0: + auc, logloss = eval(args, eval_graph, step) + if args.save_model_after_each_eval: + save_model(f"step_{step}_val_auc_{auc:0.5f}") + + +def np_to_global(np): + t = flow.from_numpy(np) + return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + + +def batch_to_global(np_label, np_features, is_train=True): + labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) + features = np_to_global(np_features) + return labels, features + + +def prefetch_eval_batches(data_dir, batch_size, num_batches): + cached_eval_batches = [] + with make_criteo_dataloader(data_dir, batch_size, shuffle=False) as loader: + for _ in range(num_batches): + label, features = batch_to_global(*next(loader), is_train=False) + cached_eval_batches.append((label, features)) + return cached_eval_batches + + +def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=None): + if tag == "val": + batches_per_epoch = math.ceil(args.num_val_samples / args.batch_size) + else: + batches_per_epoch = math.ceil(args.num_test_samples / args.batch_size) + eval_graph.module.eval() + labels, preds = [], [] + eval_start_time = time.time() + if cached_eval_batches == None: + with make_criteo_dataloader( + f"{args.data_dir}/{tag}", args.batch_size, shuffle=False + ) as loader: + eval_start_time = time.time() + for i in range(batches_per_epoch): + label, features = batch_to_global(*next(loader), is_train=False) + pred = eval_graph(features) + labels.append(label) + preds.append(pred.to_local()) + else: + for i in range(batches_per_epoch): + label, features = cached_eval_batches[i] + pred = eval_graph(features) + labels.append(label) + preds.append(pred.to_local()) + + labels = ( + np_to_global(np.concatenate(labels, axis=0)).to_global(sbp=flow.sbp.broadcast()).to_local() + ) + preds = ( + flow.cat(preds, dim=0) + .to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + .to_global(sbp=flow.sbp.broadcast()) + .to_local() + ) + + flow.comm.barrier() + eval_time = time.time() - eval_start_time + + rank = flow.env.get_rank() + + metrics_start_time = time.time() + auc = flow.roc_auc_score(labels, preds).numpy()[0] + logloss = flow._C.binary_cross_entropy_loss(preds, labels, weight=None, reduction="mean") + metrics_time = time.time() - metrics_start_time + + if rank == 0: + host_mem_mb = psutil.Process().memory_info().rss // (1024 * 1024) + stream = os.popen("nvidia-smi --query-gpu=memory.used --format=csv") + device_mem_str = stream.read().split("\n")[rank + 1] + + strtime = time.strftime("%Y-%m-%d %H:%M:%S") + print( + f"Rank[{rank}], Epoch {epoch}, Step {cur_step}, AUC {auc:0.6f}, LogLoss {logloss:0.6f}, " + + f"Eval_time {eval_time:0.2f} s, Metrics_time {metrics_time:0.2f} s, Eval_samples {labels.shape[0]}, " + + f"GPU_Memory {device_mem_str}, Host_Memory {host_mem_mb} MiB, {strtime}" + ) + + return auc, logloss + + +if __name__ == "__main__": + os.system(sys.executable + " -m oneflow --doctor") + flow.boxing.nccl.enable_all_to_all(True) + args = get_args() + train(args) diff --git a/RecommenderSystems/pnn/tools/launch_spark.sh b/RecommenderSystems/pnn/tools/launch_spark.sh new file mode 100644 index 000000000..cb804260f --- /dev/null +++ b/RecommenderSystems/pnn/tools/launch_spark.sh @@ -0,0 +1,5 @@ +export SPARK_LOCAL_DIRS=/tmp/tmp_spark +spark-shell \ + --master "local[*]" \ + --conf spark.driver.maxResultSize=0 \ + --driver-memory 360G diff --git a/RecommenderSystems/pnn/tools/pnn_parquet.scala b/RecommenderSystems/pnn/tools/pnn_parquet.scala new file mode 100644 index 000000000..b23c8704e --- /dev/null +++ b/RecommenderSystems/pnn/tools/pnn_parquet.scala @@ -0,0 +1,35 @@ +import org.apache.spark.sql.functions.udf + +def makePNNDataset(srcDir: String, dstDir:String) = { + val train_csv = s"${srcDir}/train.csv" + val test_csv = s"${srcDir}/test.csv" + val val_csv = s"${srcDir}/valid.csv" + + val make_label = udf((str:String) => str.toFloat) + val label_cols = Seq(make_label($"Label").as("Label")) + + val dense_cols = 1.to(13).map{i=>xxhash64(lit(i), col(s"I$i")).as(s"I${i}")} + + var sparse_cols = 1.to(26).map{i=>xxhash64(lit(i), col(s"C$i")).as(s"C${i}")} + + val cols = label_cols ++ dense_cols ++ sparse_cols + + spark.read.option("header","true").csv(test_csv).select(cols:_*).repartition(32).write.parquet(s"${dstDir}/test") + spark.read.option("header","true").csv(val_csv).select(cols:_*).repartition(32).write.parquet(s"${dstDir}/val") + + spark.read.option("header","true").csv(train_csv).select(cols:_*).orderBy(rand()).repartition(256).write.parquet(s"${dstDir}/train") + + // print the number of samples + val train_samples = spark.read.parquet(s"${dstDir}/train").count() + println(s"train samples = $train_samples") + val val_samples = spark.read.parquet(s"${dstDir}/val").count() + println(s"validation samples = $val_samples") + val test_samples = spark.read.parquet(s"${dstDir}/test").count() + println(s"test samples = $test_samples") + + // print table size array + val df = spark.read.parquet(s"${dstDir}/train", s"${dstDir}/val", s"${dstDir}/test") + println("table size array: ") + println(1.to(13).map{i=>df.select(s"I$i").as[Long].distinct.count}.mkString(",")) + println(1.to(26).map{i=>df.select(s"C$i").as[Long].distinct.count}.mkString(",")) +} diff --git a/RecommenderSystems/pnn/tools/split_criteo_kaggle.py b/RecommenderSystems/pnn/tools/split_criteo_kaggle.py new file mode 100644 index 000000000..62e2918ba --- /dev/null +++ b/RecommenderSystems/pnn/tools/split_criteo_kaggle.py @@ -0,0 +1,55 @@ +import numpy as np +import pandas as pd +import argparse +from sklearn.model_selection import StratifiedKFold + +RANDOM_SEED = 2018 # Fix seed for reproduction + + +def split_train_val_test(input_dir, output_dir): + num_dense_fields = 13 + num_sparse_fields = 26 + + fields = ["Label"] + fields += [f"I{i+1}" for i in range(num_dense_fields)] + fields += [f"C{i+1}" for i in range(num_sparse_fields)] + + ddf = pd.read_csv( + f"{input_dir}/train.txt", + sep="\t", + header=None, + names=fields, + encoding="utf-8", + dtype=object, + ) + X = ddf.values + y = ddf["Label"].map(lambda x: float(x)).values + print(f"{len(X)} samples in total") + + folds = StratifiedKFold(n_splits=10, shuffle=True, random_state=RANDOM_SEED) + + fold_indexes = [valid_idx for _, valid_idx in folds.split(X, y)] + test_index = fold_indexes[0] + valid_index = fold_indexes[1] + train_index = np.concatenate(fold_indexes[2:]) + + ddf.loc[test_index, :].to_csv(f"{output_dir}/test.csv", index=False, encoding="utf-8") + ddf.loc[valid_index, :].to_csv(f"{output_dir}/valid.csv", index=False, encoding="utf-8") + ddf.loc[train_index, :].to_csv(f"{output_dir}/train.csv", index=False, encoding="utf-8") + + print("Train lines:", len(train_index)) + print("Validation lines:", len(valid_index)) + print("Test lines:", len(test_index)) + print("Postive ratio:", np.sum(y) / len(y)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", type=str, required=True, help="Path to downloaded criteo kaggle dataset", + ) + parser.add_argument( + "--output_dir", type=str, required=True, help="Path to splitted criteo kaggle dataset", + ) + args = parser.parse_args() + split_train_val_test(args.input_dir, args.output_dir) diff --git a/RecommenderSystems/pnn/train_pnn.sh b/RecommenderSystems/pnn/train_pnn.sh new file mode 100644 index 000000000..2527b66e0 --- /dev/null +++ b/RecommenderSystems/pnn/train_pnn.sh @@ -0,0 +1,29 @@ +#!/bin/bash +DEVICE_NUM_PER_NODE=1 +DATA_DIR=/path/to/deepfm_parquet +PERSISTENT_PATH=/path/to/persistent +MODEL_SAVE_DIR=/path/to/model/save/dir + +python3 -m oneflow.distributed.launch \ + --nproc_per_node $DEVICE_NUM_PER_NODE \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 127.0.0.1 \ + pnn_train_eval.py \ + --data_dir $DATA_DIR \ + --persistent_path $PERSISTENT_PATH \ + --table_size_array "649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376,1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572" \ + --store_type 'cached_host_mem' \ + --cache_memory_budget_mb 1024 \ + --batch_size 10000 \ + --train_batches 75000 \ + --loss_print_interval 100 \ + --dnn "1000,1000" \ + --net_dropout 0.2 \ + --learning_rate 0.001 \ + --embedding_vec_size 16 \ + --num_train_samples 36672493 \ + --num_val_samples 4584062 \ + --num_test_samples 4584062 \ + --model_save_dir $MODEL_SAVE_DIR \ + --save_best_model