Skip to content

Commit

Permalink
Support batch push by accepting multiple keys
Browse files Browse the repository at this point in the history
  • Loading branch information
Finb committed Dec 18, 2024
1 parent f271853 commit fe4089b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 10 deletions.
7 changes: 7 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ func main() {
EnvVars: []string{"BARK_SERVER_PROXY_HEADER"},
Value: "",
},
&cli.IntFlag{
Name: "max-batch-push-count",
Usage: "Maximum number of batch pushes allowed, -1 means no limit",
EnvVars: []string{"BARK_SERVER_MAX_BATCH_PUSH_COUNT"},
Value: -1,
Action: func(ctx *cli.Context, v int) error { SetMaxBatchPushCount(v); return nil },
},
&cli.IntFlag{
Name: "max-apns-client-count",
Usage: "Maximum number of APNs client connections",
Expand Down
100 changes: 90 additions & 10 deletions route_push.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package main

import (
"fmt"
"net/url"
"strings"
"sync"

"github.com/gofiber/fiber/v2/utils"

Expand All @@ -11,6 +13,9 @@ import (
"github.com/gofiber/fiber/v2"
)

// Maximum number of batch pushes allowed, -1 means no limit
var maxBatchPushCount = -1

func init() {
// V2 API
registerRoute("push", func(router fiber.Router) {
Expand All @@ -33,6 +38,11 @@ func init() {
})
}

// Set the maximum number of batch pushes allowed
func SetMaxBatchPushCount(count int) {
maxBatchPushCount = count
}

func routeDoPush(c *fiber.Ctx) error {
// Get content-type
contentType := utils.ToLower(utils.UnsafeString(c.Request().Header.ContentType()))
Expand Down Expand Up @@ -61,7 +71,12 @@ func routeDoPush(c *fiber.Ctx) error {
}
}

return push(c, params)
code, err := push(c, params)
if err != nil {
return c.Status(code).JSON(failed(code, err.Error()))
} else {
return c.JSON(success())
}
}

func routeDoPushV2(c *fiber.Ctx) error {
Expand All @@ -74,10 +89,75 @@ func routeDoPushV2(c *fiber.Ctx) error {
c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) {
params[strings.ToLower(string(key))] = string(value)
})
return push(c, params)

var deviceKeys []string
// Get the device_keys array from params
if keys, ok := params["device_keys"]; ok {
switch keys := keys.(type) {
case string:
deviceKeys = strings.Split(keys, ",")
case []interface{}:
for _, key := range keys {
deviceKeys = append(deviceKeys, fmt.Sprint(key))
}
default:
return c.Status(400).JSON(failed(400, "invalid type for device_keys"))
}
delete(params, "device_keys")
}

count := len(deviceKeys)

if count == 0 {
// Single push
code, err := push(c, params)
if err != nil {
return c.Status(code).JSON(failed(code, err.Error()))
} else {
return c.JSON(success())
}
} else {
// Batch push
if count > maxBatchPushCount && maxBatchPushCount != -1 {
return c.Status(400).JSON(failed(400, "batch push count exceeds the maximum limit: %d", maxBatchPushCount))
}

var wg sync.WaitGroup
result := make([]map[string]interface{}, count)
var mu sync.Mutex

for i := 0; i < count; i++ {
// Copy params
newParams := make(map[string]interface{})
for k, v := range params {
newParams[k] = v
}
newParams["device_key"] = deviceKeys[i]

wg.Add(1)
go func(i int, newParams map[string]interface{}) {
defer wg.Done()

// Push
code, err := push(c, newParams)

// Save result
mu.Lock()
result[i] = make(map[string]interface{})
if err != nil {
result[i]["message"] = err.Error()
}
result[i]["code"] = code
result[i]["device_key"] = deviceKeys[i]
mu.Unlock()
}(i, newParams)
}
wg.Wait()
return c.JSON(data(result))
}
}

func push(c *fiber.Ctx, params map[string]interface{}) error {
func push(c *fiber.Ctx, params map[string]interface{}) (int, error) {
// default value
msg := apns.PushMessage{
Body: "",
Expand Down Expand Up @@ -123,27 +203,27 @@ func push(c *fiber.Ctx, params map[string]interface{}) error {
if subtitle := c.Params("subtitle"); subtitle != "" {
str, err := url.QueryUnescape(subtitle)
if err != nil {
return err
return 500, err
}
msg.Subtitle = str
}
if title := c.Params("title"); title != "" {
str, err := url.QueryUnescape(title)
if err != nil {
return err
return 500, err
}
msg.Title = str
}
if body := c.Params("body"); body != "" {
str, err := url.QueryUnescape(body)
if err != nil {
return err
return 500, err
}
msg.Body = str
}

if msg.DeviceKey == "" {
return c.Status(400).JSON(failed(400, "device key is empty"))
return 400, fmt.Errorf("device key is empty")
}

if msg.Body == "" && msg.Title == "" && msg.Subtitle == "" {
Expand All @@ -152,13 +232,13 @@ func push(c *fiber.Ctx, params map[string]interface{}) error {

deviceToken, err := db.DeviceTokenByKey(msg.DeviceKey)
if err != nil {
return c.Status(400).JSON(failed(400, "failed to get device token: %v", err))
return 400, fmt.Errorf("failed to get device token: %v", err)
}
msg.DeviceToken = deviceToken

err = apns.Push(&msg)
if err != nil {
return c.Status(500).JSON(failed(500, "push failed: %v", err))
return 500, fmt.Errorf("push failed: %v", err)
}
return c.JSON(success())
return 200, nil
}

0 comments on commit fe4089b

Please sign in to comment.