Skip to content

Commit

Permalink
Support for categorical-set features in the TF-DF signature of the YD…
Browse files Browse the repository at this point in the history
…F JS API.

PiperOrigin-RevId: 502880133
  • Loading branch information
achoum authored and copybara-github committed Jan 18, 2023
1 parent a3668a0 commit c7e86e0
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 47 deletions.
12 changes: 8 additions & 4 deletions yggdrasil_decision_forests/port/javascript/externs.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* @typedef {{
* size: function(): number,
* get: function(number) : (string|boolean|number|!Object),
* push_back: function(?),
* }}
*/
let CCVector;
Expand Down Expand Up @@ -56,6 +57,7 @@ let InputFeature;
* @typedef {{
* size: function(): number,
* get: function(number) : !InputFeature,
* push_back: function(!InputFeature),
* }}
*/
let InternalInputFeatures;
Expand All @@ -66,6 +68,7 @@ let InternalInputFeatures;
* @typedef {{
* size: function(): number,
* get: function(number) : number,
* push_back: function(number),
* }}
*/
let InternalPredictions;
Expand Down Expand Up @@ -106,9 +109,9 @@ let TFDFOutputPrediction;
* numericalFeatures: !Array<!Array<number>>,
* booleanFeatures: !Array<!Array<number>>,
* categoricalIntFeatures: !Array<!Array<number>>,
* categoricalSetIntFeaturesValues: !Array<!Array<number>>,
* categoricalSetIntFeaturesRowSplitsDim1: !Array<!Array<number>>,
* categoricalSetIntFeaturesRowSplitsDim2: !Array<!Array<number>>,
* categoricalSetIntFeaturesValues: !Array<number>,
* categoricalSetIntFeaturesRowSplitsDim1: !Array<number>,
* categoricalSetIntFeaturesRowSplitsDim2: !Array<number>,
* denseOutputDim: number,
* }}
*/
Expand All @@ -125,7 +128,8 @@ let TFDFInput;
* setBoolean: function(number,number,number),
* setCategoricalInt: function(number,number,number),
* setCategoricalString: function(number,number,string),
* setCategoricalSetString: function(number,number,!Array<string>),
* setCategoricalSetString: function(number,number,!CCVector),
* setCategoricalSetInt: function(number,number,!CCVector),
* getInputFeatures: function(): !InternalInputFeatures,
* getProtoInputFeatures: function(): !InternalInputFeatures,
* delete: function(),
Expand Down
28 changes: 28 additions & 0 deletions yggdrasil_decision_forests/port/javascript/inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,17 @@ class Model {
engine_->features());
}

// Sets the value of a categorical set feature.
void SetCategoricalSetInt(int example_idx, int feature_id,
std::vector<int> value) {
if (example_idx >= num_examples_) {
LOG(WARNING) << "example_idx should be less than the number of examples";
return;
}
examples_->SetCategoricalSet(example_idx, {feature_id}, value,
engine_->features());
}

// Runs the model on the previously set features.
std::vector<float> Predict() {
if (num_examples_ == -1) {
Expand Down Expand Up @@ -311,6 +322,18 @@ std::shared_ptr<Model> LoadModel(std::string path,
created_tfdf_signature);
}

std::vector<std::string> CreateVectorString(size_t reserved) {
std::vector<std::string> v;
v.reserve(reserved);
return v;
}

std::vector<int> CreateVectorInt(size_t reserved) {
std::vector<int> v;
v.reserve(reserved);
return v;
}

// Expose some of the class/functions to JS.
//
// Keep this list in sync with the corresponding @typedef in wrapper.js.
Expand All @@ -325,6 +348,7 @@ EMSCRIPTEN_BINDINGS(my_module) {
.function("setCategoricalInt", &Model::SetCategoricalInt)
.function("setCategoricalString", &Model::SetCategoricalString)
.function("setCategoricalSetString", &Model::SetCategoricalSetString)
.function("setCategoricalSetInt", &Model::SetCategoricalSetInt)
.function("getInputFeatures", &Model::GetInputFeatures)
.function("getProtoInputFeatures", &Model::GetProtoInputFeatures);

Expand All @@ -343,6 +367,10 @@ EMSCRIPTEN_BINDINGS(my_module) {

emscripten::register_vector<InputFeature>("vector<InputFeature>");
emscripten::register_vector<float>("vector<float>");
emscripten::register_vector<int>("vector<int>");
emscripten::register_vector<std::vector<float>>("vector<vector<float>>");
emscripten::register_vector<std::string>("vector<string>");

emscripten::function("CreateVectorString", &CreateVectorString);
emscripten::function("CreateVectorInt", &CreateVectorInt);
}
14 changes: 14 additions & 0 deletions yggdrasil_decision_forests/port/javascript/karma.conf.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@ module.exports = function(config) {
nocache: false,
included: false,
});
config.files.push({
pattern: basePath + 'test_data/model_2.zip',
watched: false,
served: true,
nocache: false,
included: false,
});
config.files.push({
pattern: basePath + 'test_data/model_small_sst.zip',
watched: false,
served: true,
nocache: false,
included: false,
});
config.files.push({
pattern: 'third_party/javascript/node_modules/jszip/jszip.min.js',
watched: false,
Expand Down
9 changes: 9 additions & 0 deletions yggdrasil_decision_forests/port/javascript/test_data/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package(
default_visibility = ["//yggdrasil_decision_forests/port/javascript:users"],
licenses = ["notice"],
)

filegroup(
name = "models",
srcs = glob(["*.zip"]),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- model_2.zip: A model trained on a 8 hand crafted examples.
- model_sst_small: A random forest with 10 trees (and some other limits)
trained on the SST2 dataset.
Binary file not shown.
Binary file not shown.
125 changes: 98 additions & 27 deletions yggdrasil_decision_forests/port/javascript/wrapper.js
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,32 @@ function ccVectorToJSVector(src) {
.map((unused, index) => src.get(index));
}

/**
* Converts a JS array into a std::vector<std::string> (C++).
* @param {!Array} src JS Vector.
* @return {!CCVector} CC Vector.
*/
function jsStrVectorToCCStrVector(src) {
const vector = Module.CreateVectorString(src.length);
for (const value of src) {
vector.push_back(value);
}
return vector;
}

/**
* Converts a JS array into a std::vector<int> (C++).
* @param {!Array} src JS Vector.
* @return {!CCVector} CC Vector.
*/
function jsStrVectorToCCIntVector(src) {
const vector = Module.CreateVectorInt(src.length);
for (const value of src) {
vector.push_back(value);
}
return vector;
}

/**
* Converts a std::vector<std::vector<T>> (C++) into a JS array or array.
* @param {!CCVectorVector} src CC Matrix.
Expand Down Expand Up @@ -178,6 +204,12 @@ class Model {
*/
this.categoricalIntFeaturesIndex = null;

/**
* Index of the categorical-set input features for the TF-DF signature.
* @private @type {?Array<number>}
*/
this.categoricalSetIntFeaturesIndex = null;

if (this.createdTFDFSignature) {
this.createdTFDFSignature_();
}
Expand All @@ -198,6 +230,8 @@ class Model {
indexTFDFFeatures(protoInputFeatures, this.inputFeatures, ['BOOLEAN']);
this.categoricalIntFeaturesIndex = indexTFDFFeatures(
protoInputFeatures, this.inputFeatures, ['CATEGORICAL']);
this.categoricalSetIntFeaturesIndex = indexTFDFFeatures(
protoInputFeatures, this.inputFeatures, ['CATEGORICAL_SET']);
}


Expand Down Expand Up @@ -235,25 +269,25 @@ class Model {
*/
predict(examples) {
if (typeof examples !== 'object') {
throw Error('argument should be an array or an object');
throw new Error('argument should be an array or an object');
}

// Detect the number of examples and ensure that all the fields (i.e.
// features) are arrays with the same number of items.
let numExamples = undefined;
for (const values of Object.values(examples)) {
if (!Array.isArray(values)) {
throw Error('features should be arrays');
throw new Error('features should be arrays');
}
if (numExamples === undefined) {
numExamples = values.length;
} else if (numExamples !== values.length) {
throw Error('features have a different number of values');
throw new Error('features have a different number of values');
}
}
if (numExamples === undefined) {
// The example does not contain any features.
throw Error('not features');
throw new Error('not features');
}

// Fill the examples
Expand Down Expand Up @@ -281,10 +315,11 @@ class Model {
for (const [exampleIdx, value] of values.entries()) {
if (value === null) continue;
this.internalModel.setCategoricalSetString(
exampleIdx, featureDef.internalIdx, value);
exampleIdx, featureDef.internalIdx,
jsStrVectorToCCStrVector(value));
}
} else {
throw Error(`Non supported feature type ${featureDef}`);
throw new Error(`Non supported feature type ${featureDef}`);
}
}

Expand All @@ -309,53 +344,61 @@ class Model {
* @return {!TFDFOutputPrediction} Predictions of the model.
*/
predictTFDFSignature(inputs) {
// TODO: Add support for categorical-set features.

if (!this.createdTFDFSignature) {
throw Error('Model not loaded with options.createdTFDFSignature=true');
}

if (inputs.categoricalSetIntFeaturesRowSplitsDim1.length != 1 ||
inputs.categoricalSetIntFeaturesRowSplitsDim1[0] != 0) {
throw Error(
'Categorical-set features are currently not supported with this ' +
'interface (predictTensorFlowDecisionForestSignature). Use ' +
'"predict" instead.');
throw new Error(
'Model not loaded with options.createdTFDFSignature=true');
}

// Detect the number of examples.
//
// For each type of given feature, extract the number of examples. Ensure
// that the feature shapes are consistant.
let numExamples = 0;
if (inputs.numericalFeatures.length != 0) {
if (numExamples != 0 && numExamples != inputs.numericalFeatures.length) {
throw Error('features have a different number of values');
throw new Error('features have a different number of values');
}
if (this.numericalFeaturesIndex.length !=
inputs.numericalFeatures[0].length) {
throw Error('Unexpected numerical input feature shape');
throw new Error('Unexpected numerical input feature shape');
}
numExamples = inputs.numericalFeatures.length;
}
if (inputs.booleanFeatures.length != 0) {
if (numExamples != 0 && numExamples != inputs.booleanFeatures.length) {
throw Error('features have a different number of values');
throw new Error('features have a different number of values');
}
if (this.booleanFeaturesIndex.length !=
inputs.booleanFeatures[0].length) {
throw Error('Unexpected boolean input feature shape');
throw new Error('Unexpected boolean input feature shape');
}
numExamples = inputs.booleanFeatures.length;
}
if (inputs.categoricalIntFeatures.length != 0) {
if (numExamples != 0 &&
numExamples != inputs.categoricalIntFeatures.length) {
throw Error('features have a different number of values');
throw new Error('features have a different number of values');
}
if (this.categoricalIntFeaturesIndex.length !=
inputs.categoricalIntFeatures[0].length) {
throw Error('Unexpected categorical int input feature shape');
throw new Error('Unexpected categorical int input feature shape');
}
numExamples = inputs.categoricalIntFeatures.length;
}
if (inputs.categoricalSetIntFeaturesRowSplitsDim2.length > 1) {
if (this.categoricalSetIntFeaturesIndex.length == null ||
this.categoricalSetIntFeaturesIndex.length <= 0) {
throw new Error('Invalid categoricalSetIntFeaturesIndex');
}
const detectedNumExamples =
(inputs.categoricalSetIntFeaturesRowSplitsDim2
[inputs.categoricalSetIntFeaturesRowSplitsDim2.length - 1] /
this.categoricalSetIntFeaturesIndex.length);
if (numExamples != 0 && numExamples != detectedNumExamples) {
throw new Error('Invalid categorical-set feature shape');
}
numExamples = detectedNumExamples;
}

// Allocate the examples
this.internalModel.newBatchOfExamples(numExamples);
Expand Down Expand Up @@ -409,6 +452,35 @@ class Model {
}
}

for (let localIdx = 0;
localIdx < this.categoricalSetIntFeaturesIndex.length; localIdx++) {
const internIdx = this.categoricalSetIntFeaturesIndex[localIdx];
if (internIdx == -1) {
continue;
}
for (let exampleIdx = 0; exampleIdx < numExamples; exampleIdx++) {
const d1Cell =
exampleIdx * this.categoricalSetIntFeaturesIndex.length + localIdx;
const beginIdx = inputs.categoricalSetIntFeaturesRowSplitsDim1[d1Cell];
const endIdx =
inputs.categoricalSetIntFeaturesRowSplitsDim1[d1Cell + 1];
const ccValues = Module.CreateVectorInt(endIdx - beginIdx);

if (endIdx > beginIdx &&
inputs.categoricalSetIntFeaturesValues[beginIdx] == -1) {
// This is a missing value.
continue;
}

for (let itemIdx = beginIdx; itemIdx < endIdx; itemIdx++) {
ccValues.push_back(inputs.categoricalSetIntFeaturesValues[itemIdx]);
}
this.internalModel.setCategoricalSetInt(
exampleIdx, internIdx, ccValues);
}
}


// Generate predictions.
const rawPredictions =
this.internalModel.predictTFDFSignature(inputs.denseOutputDim);
Expand Down Expand Up @@ -477,10 +549,9 @@ Module['loadModelFromZipBlob'] =

zippedModel.forEach((filename, file) => {
promiseUncompressed.push(
file.async('blob').then((data) => blobToArrayBuffer(data))
.then((data) => {
file.async('blob').then((data) => blobToArrayBuffer(data)).then((data) => {
if (filename.endsWith('/')) {
throw Error(
throw new Error(
'The model zipfile is expected to be a flat zip file, but it contains a sub-directory. If zipping the model manually with the `zip` tool, make sure to use the `-j` option.');
}
Module.FS.writeFile(
Expand All @@ -504,7 +575,7 @@ Module['loadModelFromZipBlob'] =
Module.FS.rmdir(modelPath);

if (modelWasm == null) {
throw Error('Cannot parse model');
throw new Error('Cannot parse model');
}

return new Model(modelWasm, createdTFDFSignature);
Expand Down
Loading

0 comments on commit c7e86e0

Please sign in to comment.