Skip to content

Commit

Permalink
feat: add batch request support (#99)
Browse files Browse the repository at this point in the history
 Add batch support
  • Loading branch information
tarassh authored May 4, 2023
1 parent 1fbbed8 commit f007863
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 20 deletions.
58 changes: 48 additions & 10 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
cb(w)
}

var req request
// We read the entire request upfront in a buffer to be able to tell if the
// client sent more than maxRequestSize and report it back as an explicit error,
// instead of just silently truncating it and reporting a more vague parsing
Expand All @@ -205,11 +204,11 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
if err != nil {
// ReadFrom will discard EOF so any error here is unexpected and should
// be reported.
rpcError(wf, &req, rpcParseError, xerrors.Errorf("reading request: %w", err))
rpcError(wf, nil, rpcParseError, xerrors.Errorf("reading request: %w", err))
return
}
if reqSize > s.maxRequestSize {
rpcError(wf, &req, rpcParseError,
rpcError(wf, nil, rpcParseError,
// rpcParseError is the closest we have from the standard errors defined
// in [jsonrpc spec](https://www.jsonrpc.org/specification#error_object)
// to report the maximum limit.
Expand All @@ -218,17 +217,56 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
return
}

if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("unmarshaling request: %w", err))
return
}
// Trim spaces to avoid issues with batch request detection.
bufferedRequest = bytes.NewBuffer(bytes.TrimSpace(bufferedRequest.Bytes()))
reqSize = int64(bufferedRequest.Len())

if req.ID, err = normalizeID(req.ID); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
if reqSize == 0 {
rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request"))
return
}

s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
if bufferedRequest.Bytes()[0] == '[' && bufferedRequest.Bytes()[reqSize-1] == ']' {
var reqs []request

if err := json.NewDecoder(bufferedRequest).Decode(&reqs); err != nil {
rpcError(wf, nil, rpcParseError, xerrors.New("Parse error"))
return
}

if len(reqs) == 0 {
rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request"))
return
}

w.Write([]byte("["))
for idx, req := range reqs {
if req.ID, err = normalizeID(req.ID); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
return
}

s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)

if idx != len(reqs)-1 {
w.Write([]byte(","))
}
}
w.Write([]byte("]"))
} else {
var req request
if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.New("Parse error"))
return
}

if req.ID, err = normalizeID(req.ID); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
return
}

s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
}
}

func doCall(methodName string, f reflect.Value, params []reflect.Value) (out []reflect.Value, err error) {
Expand Down
42 changes: 35 additions & 7 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,22 @@ func TestRawRequests(t *testing.T) {
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()

tc := func(req, resp string, n int32) func(t *testing.T) {
removeSpaces := func(jsonStr string) (string, error) {
var jsonObj interface{}
err := json.Unmarshal([]byte(jsonStr), &jsonObj)
if err != nil {
return "", err
}

compactJSONBytes, err := json.Marshal(jsonObj)
if err != nil {
return "", err
}

return string(compactJSONBytes), nil
}

tc := func(req, resp string, n int32, statusCode int) func(t *testing.T) {
return func(t *testing.T) {
rpcHandler.n = 0

Expand All @@ -100,16 +115,29 @@ func TestRawRequests(t *testing.T) {
b, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, resp, strings.TrimSpace(string(b)))
expectedResp, err := removeSpaces(resp)
require.NoError(t, err)

responseBody, err := removeSpaces(string(b))
require.NoError(t, err)

assert.Equal(t, expectedResp, responseBody)
require.Equal(t, n, rpcHandler.n)
require.Equal(t, statusCode, res.StatusCode)
}
}

t.Run("inc", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": [], "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1))
t.Run("inc-null", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": null, "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1))
t.Run("inc-noparam", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "id": 2}`, `{"jsonrpc":"2.0","id":2}`, 1))
t.Run("add", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [10], "id": 4}`, `{"jsonrpc":"2.0","id":4}`, 10))

t.Run("inc", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": [], "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1, 200))
t.Run("inc-null", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": null, "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1, 200))
t.Run("inc-noparam", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "id": 2}`, `{"jsonrpc":"2.0","id":2}`, 1, 200))
t.Run("add", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [10], "id": 4}`, `{"jsonrpc":"2.0","id":4}`, 10, 200))
// Batch requests
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 5}`, `{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"Parse error"}}`, 0, 500))
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 6}]`, `[{"jsonrpc":"2.0","id":6}]`, 123, 200))
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 7},{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-122], "id": 8}]`, `[{"jsonrpc":"2.0","id":7},{"jsonrpc":"2.0","id":8}]`, 1, 200))
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 9},{"jsonrpc": "2.0", "params": [-122], "id": 10}]`, `[{"jsonrpc":"2.0","id":9},{"error":{"code":-32601,"message":"method '' not found"},"id":10,"jsonrpc":"2.0"}]`, 123, 200))
t.Run("add", tc(` [{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-1], "id": 11}] `, `[{"jsonrpc":"2.0","id":11}]`, -1, 200))
t.Run("add", tc(``, `{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid request"}}`, 0, 400))
}

func TestReconnection(t *testing.T) {
Expand Down
11 changes: 8 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

const (
rpcParseError = -32700
rpcInvalidRequest = -32600
rpcMethodNotFound = -32601
rpcInvalidParams = -32602
)
Expand Down Expand Up @@ -107,13 +108,17 @@ func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error)
log.Errorf("RPC Error: %s", err)
wf(func(w io.Writer) {
if hw, ok := w.(http.ResponseWriter); ok {
hw.WriteHeader(500)
if code == rpcInvalidRequest {
hw.WriteHeader(400)
} else {
hw.WriteHeader(500)
}
}

log.Warnf("rpc error: %s", err)

if req.ID == nil { // notification
return
if req == nil {
req = &request{}
}

resp := response{
Expand Down

0 comments on commit f007863

Please sign in to comment.