Skip to content

Commit

Permalink
update error handling to throw exception when post processing functio…
Browse files Browse the repository at this point in the history
…n recieve empty result from a model.
  • Loading branch information
tkykenmt committed Dec 15, 2024
1 parent c2a40c1 commit 91abc83
Showing 1 changed file with 114 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,38 @@ result = predictor.predict(data={
]
})

print(json.dumps(sorted(result, key=lambda x: x['index']), indent=2))
print(json.dumps(result, indent=2))
```

The reranking results are as follows:
The reranking result is ordering by the highest score first:
```
[
{
"index": 2,
"score": 0.92879725
},
{
"index": 0,
"score": 0.013636836
},
{
"index": 1,
"score": 0.000593021
},
{
"index": 3,
"score": 0.00012148176
}
]
```

You can sort the result by index number.

```python
print(json.dumps(result, indent=2))
```

The results are as follows:

```
[
Expand Down Expand Up @@ -121,9 +149,46 @@ POST /_plugins/_ml/connectors/_create
"headers": {
"content-type": "application/json"
},
"pre_process_function": """
def query_text = params.query_text;
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('\"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('\"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';
return '{\"parameters\": ' + parameters + '}';
""",
"request_body": "{ \"query\": \"${parameters.query}\", \"texts\": ${parameters.texts} }",
"pre_process_function": "\n def query_text = params.query_text;\n def text_docs = params.text_docs;\n def textDocsBuilder = new StringBuilder('[');\n for (int i=0; i<text_docs.length; i++) {\n textDocsBuilder.append('\"');\n textDocsBuilder.append(text_docs[i]);\n textDocsBuilder.append('\"');\n if (i<text_docs.length - 1) {\n textDocsBuilder.append(',');\n }\n }\n textDocsBuilder.append(']');\n def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';\n return '{\"parameters\": ' + parameters + '}';\n",
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n def sorted_outputs = outputs;\n for (int i=0; i<outputs.length; i++) {\n def idx = new BigDecimal(outputs[i].index.toString()).intValue();\n sorted_outputs[idx] = outputs[i];\n }\n def resultBuilder = new StringBuilder('[');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
"post_process_function": """
if (params.result == null || params.result.length > 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.result;
def scores = new Double[outputs.length];
for (int i=0; i<outputs.length; i++) {
def index = new BigDecimal(outputs[i].index.toString()).intValue();
scores[index] = outputs[i].score;
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<scores.length; i++) {
resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');
resultBuilder.append('\"data\": [');
resultBuilder.append(scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
resultBuilder.append(',');
}
}
resultBuilder.append(']');
return resultBuilder.toString();
"""
}
]
}
Expand Down Expand Up @@ -152,9 +217,46 @@ POST /_plugins/_ml/connectors/_create
"headers": {
"content-type": "application/json"
},
"pre_process_function": """
def query_text = params.query_text;
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('\"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('\"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';
return '{\"parameters\": ' + parameters + '}';
""",
"request_body": "{ \"query\": \"${parameters.query}\", \"texts\": ${parameters.texts} }",
"pre_process_function": "\n def query_text = params.query_text;\n def text_docs = params.text_docs;\n def textDocsBuilder = new StringBuilder('[');\n for (int i=0; i<text_docs.length; i++) {\n textDocsBuilder.append('\"');\n textDocsBuilder.append(text_docs[i]);\n textDocsBuilder.append('\"');\n if (i<text_docs.length - 1) {\n textDocsBuilder.append(',');\n }\n }\n textDocsBuilder.append(']');\n def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';\n return '{\"parameters\": ' + parameters + '}';\n",
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n def sorted_outputs = outputs;\n for (int i=0; i<outputs.length; i++) {\n def idx = new BigDecimal(outputs[i].index.toString()).intValue();\n sorted_outputs[idx] = outputs[i];\n }\n def resultBuilder = new StringBuilder('[');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
"post_process_function": """
if (params.result == null || params.result.length > 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.result;
def scores = new Double[outputs.length];
for (int i=0; i<outputs.length; i++) {
def index = new BigDecimal(outputs[i].index.toString()).intValue();
scores[index] = outputs[i].score;
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<scores.length; i++) {
resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');
resultBuilder.append('\"data\": [');
resultBuilder.append(scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
resultBuilder.append(',');
}
}
resultBuilder.append(']');
return resultBuilder.toString();
"""
}
]
}
Expand Down Expand Up @@ -188,7 +290,7 @@ POST _plugins/_ml/models/your_model_id/_predict
}
```

Each item in the `inputs` array comprises a `query_text` and a `text_docs` string, separated by a ` . `
Each item in the array comprises a `query_text` and a `text_docs` string, separated by a ` . `

Alternatively, you can test the model as follows:
```json
Expand All @@ -209,6 +311,10 @@ The connector `pre_process_function` transforms the input into the format requir
By default, the SageMaker model output has the following format:
```json
[
{
"index": 2,
"score": 0.92879725
},
{
"index": 0,
"score": 0.013636836
Expand All @@ -217,18 +323,14 @@ By default, the SageMaker model output has the following format:
"index": 1,
"score": 0.000593021
},
{
"index": 2,
"score": 0.92879725
},
{
"index": 3,
"score": 0.00012148176
}
]
```

The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret. This adapted format is as follows:
The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpretm, and order result by index. This adapted format is as follows:
```json
{
"inference_results": [
Expand Down

0 comments on commit 91abc83

Please sign in to comment.