From fe4089b0480d45234f730ef40574e65c397f1218 Mon Sep 17 00:00:00 2001 From: Fin Date: Wed, 18 Dec 2024 15:31:54 +0800 Subject: [PATCH] Support batch push by accepting multiple keys --- main.go | 7 ++++ route_push.go | 100 +++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 97 insertions(+), 10 deletions(-) diff --git a/main.go b/main.go index 229efd59..5dde5ac2 100644 --- a/main.go +++ b/main.go @@ -111,6 +111,13 @@ func main() { EnvVars: []string{"BARK_SERVER_PROXY_HEADER"}, Value: "", }, + &cli.IntFlag{ + Name: "max-batch-push-count", + Usage: "Maximum number of batch pushes allowed, -1 means no limit", + EnvVars: []string{"BARK_SERVER_MAX_BATCH_PUSH_COUNT"}, + Value: -1, + Action: func(ctx *cli.Context, v int) error { SetMaxBatchPushCount(v); return nil }, + }, &cli.IntFlag{ Name: "max-apns-client-count", Usage: "Maximum number of APNs client connections", diff --git a/route_push.go b/route_push.go index 300e0747..713d4b66 100644 --- a/route_push.go +++ b/route_push.go @@ -1,8 +1,10 @@ package main import ( + "fmt" "net/url" "strings" + "sync" "github.com/gofiber/fiber/v2/utils" @@ -11,6 +13,9 @@ import ( "github.com/gofiber/fiber/v2" ) +// Maximum number of batch pushes allowed, -1 means no limit +var maxBatchPushCount = -1 + func init() { // V2 API registerRoute("push", func(router fiber.Router) { @@ -33,6 +38,11 @@ func init() { }) } +// Set the maximum number of batch pushes allowed +func SetMaxBatchPushCount(count int) { + maxBatchPushCount = count +} + func routeDoPush(c *fiber.Ctx) error { // Get content-type contentType := utils.ToLower(utils.UnsafeString(c.Request().Header.ContentType())) @@ -61,7 +71,12 @@ func routeDoPush(c *fiber.Ctx) error { } } - return push(c, params) + code, err := push(c, params) + if err != nil { + return c.Status(code).JSON(failed(code, err.Error())) + } else { + return c.JSON(success()) + } } func routeDoPushV2(c *fiber.Ctx) error { @@ -74,10 +89,75 @@ func routeDoPushV2(c *fiber.Ctx) error { c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) { params[strings.ToLower(string(key))] = string(value) }) - return push(c, params) + + var deviceKeys []string + // Get the device_keys array from params + if keys, ok := params["device_keys"]; ok { + switch keys := keys.(type) { + case string: + deviceKeys = strings.Split(keys, ",") + case []interface{}: + for _, key := range keys { + deviceKeys = append(deviceKeys, fmt.Sprint(key)) + } + default: + return c.Status(400).JSON(failed(400, "invalid type for device_keys")) + } + delete(params, "device_keys") + } + + count := len(deviceKeys) + + if count == 0 { + // Single push + code, err := push(c, params) + if err != nil { + return c.Status(code).JSON(failed(code, err.Error())) + } else { + return c.JSON(success()) + } + } else { + // Batch push + if count > maxBatchPushCount && maxBatchPushCount != -1 { + return c.Status(400).JSON(failed(400, "batch push count exceeds the maximum limit: %d", maxBatchPushCount)) + } + + var wg sync.WaitGroup + result := make([]map[string]interface{}, count) + var mu sync.Mutex + + for i := 0; i < count; i++ { + // Copy params + newParams := make(map[string]interface{}) + for k, v := range params { + newParams[k] = v + } + newParams["device_key"] = deviceKeys[i] + + wg.Add(1) + go func(i int, newParams map[string]interface{}) { + defer wg.Done() + + // Push + code, err := push(c, newParams) + + // Save result + mu.Lock() + result[i] = make(map[string]interface{}) + if err != nil { + result[i]["message"] = err.Error() + } + result[i]["code"] = code + result[i]["device_key"] = deviceKeys[i] + mu.Unlock() + }(i, newParams) + } + wg.Wait() + return c.JSON(data(result)) + } } -func push(c *fiber.Ctx, params map[string]interface{}) error { +func push(c *fiber.Ctx, params map[string]interface{}) (int, error) { // default value msg := apns.PushMessage{ Body: "", @@ -123,27 +203,27 @@ func push(c *fiber.Ctx, params map[string]interface{}) error { if subtitle := c.Params("subtitle"); subtitle != "" { str, err := url.QueryUnescape(subtitle) if err != nil { - return err + return 500, err } msg.Subtitle = str } if title := c.Params("title"); title != "" { str, err := url.QueryUnescape(title) if err != nil { - return err + return 500, err } msg.Title = str } if body := c.Params("body"); body != "" { str, err := url.QueryUnescape(body) if err != nil { - return err + return 500, err } msg.Body = str } if msg.DeviceKey == "" { - return c.Status(400).JSON(failed(400, "device key is empty")) + return 400, fmt.Errorf("device key is empty") } if msg.Body == "" && msg.Title == "" && msg.Subtitle == "" { @@ -152,13 +232,13 @@ func push(c *fiber.Ctx, params map[string]interface{}) error { deviceToken, err := db.DeviceTokenByKey(msg.DeviceKey) if err != nil { - return c.Status(400).JSON(failed(400, "failed to get device token: %v", err)) + return 400, fmt.Errorf("failed to get device token: %v", err) } msg.DeviceToken = deviceToken err = apns.Push(&msg) if err != nil { - return c.Status(500).JSON(failed(500, "push failed: %v", err)) + return 500, fmt.Errorf("push failed: %v", err) } - return c.JSON(success()) + return 200, nil }