Skip to content

Commit

Permalink
feat: add bandwidth limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
zakuwaki committed Jul 4, 2023
1 parent 07ce5e0 commit cc4e41f
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 4 deletions.
1 change: 1 addition & 0 deletions adapter/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ type Rule interface {
Match(metadata *InboundContext) bool
Outbound() string
String() string
Limiters() []string
}

type DNSRule interface {
Expand Down
4 changes: 4 additions & 0 deletions box.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/sagernet/sing-box/experimental"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/inbound"
"github.com/sagernet/sing-box/limiter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/outbound"
Expand Down Expand Up @@ -72,6 +73,9 @@ func New(options Options) (*Box, error) {
if err != nil {
return nil, E.Cause(err, "create log factory")
}
if len(options.Limiters) > 0 {
ctx = limiter.WithDefault(ctx, logFactory.NewLogger("limiter"), options.Limiters)
}
router, err := route.NewRouter(
ctx,
logFactory,
Expand Down
103 changes: 103 additions & 0 deletions limiter/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package limiter

import (
"context"
"fmt"
"net"
"sync"

"github.com/dustin/go-humanize"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/service"
)

const (
limiterDefault = "default"
limiterUser = "user"
limiterInbound = "inbound"
)

var _ Manager = (*defaultManager)(nil)

type defaultManager struct {
mp *sync.Map
}

func WithDefault(ctx context.Context, logger log.ContextLogger, options []option.Limiter) context.Context {
m := &defaultManager{mp: &sync.Map{}}
for i, option := range options {
if err := m.createLimiter(ctx, option); err != nil {
logger.ErrorContext(ctx, fmt.Sprintf("id=%d, %s", i, err))
} else {
logger.InfoContext(ctx, fmt.Sprintf("id=%d, tag=%s, users=%v, inbounds=%v, download=%s, upload=%s",
i, option.Tag, option.AuthUser, option.Inbound, option.Download, option.Upload))
}
}
return service.ContextWith[Manager](ctx, m)
}

func buildKey(prefix string, tag string) string {
return fmt.Sprintf("%s-%s", prefix, tag)
}

func (m *defaultManager) createLimiter(ctx context.Context, option option.Limiter) (err error) {
var download, upload uint64
if len(option.Download) > 0 {
download, err = humanize.ParseBytes(option.Download)
if err != nil {
return err
}
}
if len(option.Upload) > 0 {
upload, err = humanize.ParseBytes(option.Upload)
}
if download == 0 && upload == 0 {
return E.New("bandwith must be set")
}
l := newLimiter(download, upload)
valid := false
if len(option.Tag) > 0 {
valid = true
m.mp.Store(buildKey(limiterDefault, option.Tag), newLimiter(download, upload))
}
if len(option.AuthUser) > 0 {
valid = true
for _, user := range option.AuthUser {
m.mp.Store(buildKey(limiterUser, user), l)
}
}
if len(option.Inbound) > 0 {
valid = true
for _, inbound := range option.Inbound {
m.mp.Store(buildKey(limiterInbound, inbound), l)
}
}
if !valid {
return E.New("tag or constraint must be set")
}
return
}

func (m *defaultManager) LoadLimiters(tags []string, user, inbound string) (limiters []*limiter) {
for _, t := range tags {
if v, ok := m.mp.Load(buildKey(limiterDefault, t)); ok {
limiters = append(limiters, v.(*limiter))
}
}
if v, ok := m.mp.Load(buildKey(limiterUser, user)); ok {
limiters = append(limiters, v.(*limiter))
}
if v, ok := m.mp.Load(buildKey(limiterInbound, inbound)); ok {
limiters = append(limiters, v.(*limiter))
}
return
}

func (m *defaultManager) NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn {
for _, limiter := range limiters {
conn = &connWithLimiter{Conn: conn, limiter: limiter, ctx: ctx}
}
return conn
}
77 changes: 77 additions & 0 deletions limiter/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package limiter

import (
"context"
"net"

"golang.org/x/time/rate"
)

type limiter struct {
downloadLimiter *rate.Limiter
uploadLimiter *rate.Limiter
}

func newLimiter(download, upload uint64) *limiter {
var downloadLimiter, uploadLimiter *rate.Limiter
if download > 0 {
downloadLimiter = rate.NewLimiter(rate.Limit(float64(download)), int(download))
}
if upload > 0 {
uploadLimiter = rate.NewLimiter(rate.Limit(float64(upload)), int(upload))
}
return &limiter{downloadLimiter: downloadLimiter, uploadLimiter: uploadLimiter}
}

type connWithLimiter struct {
net.Conn
limiter *limiter
ctx context.Context
}

func (conn *connWithLimiter) Read(p []byte) (n int, err error) {
if conn.limiter == nil || conn.limiter.downloadLimiter == nil {
return conn.Conn.Read(p)
}
b := conn.limiter.downloadLimiter.Burst()
if b < len(p) {
p = p[:b]
}
n, err = conn.Conn.Read(p)
if err != nil {
return
}
err = conn.limiter.downloadLimiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return
}

func (conn *connWithLimiter) Write(p []byte) (n int, err error) {
if conn.limiter == nil || conn.limiter.uploadLimiter == nil {
return conn.Conn.Write(p)
}
var nn int
b := conn.limiter.uploadLimiter.Burst()
for {
end := len(p)
if end == 0 {
break
}
if b < len(p) {
end = b
}
err = conn.limiter.uploadLimiter.WaitN(conn.ctx, end)
if err != nil {
return
}
nn, err = conn.Conn.Write(p[:end])
n += nn
if err != nil {
return
}
p = p[end:]
}
return
}
11 changes: 11 additions & 0 deletions limiter/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package limiter

import (
"context"
"net"
)

type Manager interface {
LoadLimiters(tags []string, user, inbound string) []*limiter
NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn
}
1 change: 1 addition & 0 deletions option/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type _Options struct {
Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"`
Limiters []Limiter `json:"limiters,omitempty"`
Experimental *ExperimentalOptions `json:"experimental,omitempty"`
}

Expand Down
9 changes: 9 additions & 0 deletions option/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package option

type Limiter struct {
Tag string `json:"tag"`
Download string `json:"download,omitempty"`
Upload string `json:"upload,omitempty"`
AuthUser Listable[string] `json:"auth_user,omitempty"`
Inbound Listable[string] `json:"inbound,omitempty"`
}
10 changes: 6 additions & 4 deletions option/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type DefaultRule struct {
ClashMode string `json:"clash_mode,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
Limiter Listable[string] `json:"limiter,omitempty"`
}

func (r DefaultRule) IsValid() bool {
Expand All @@ -90,10 +91,11 @@ func (r DefaultRule) IsValid() bool {
}

type LogicalRule struct {
Mode string `json:"mode"`
Rules []DefaultRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
Mode string `json:"mode"`
Rules []DefaultRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
Limiter Listable[string] `json:"limiter,omitempty"`
}

func (r LogicalRule) IsValid() bool {
Expand Down
18 changes: 18 additions & 0 deletions route/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/sagernet/sing-box/common/sniff"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/limiter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/ntp"
"github.com/sagernet/sing-box/option"
Expand All @@ -38,6 +39,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
)

var _ adapter.Router = (*Router)(nil)
Expand Down Expand Up @@ -80,6 +82,7 @@ type Router struct {
timeService adapter.TimeService
clashServer adapter.ClashServer
v2rayServer adapter.V2RayServer
limiterManager limiter.Manager
platformInterface platform.Interface
}

Expand Down Expand Up @@ -487,6 +490,9 @@ func (r *Router) Start() error {
return E.Cause(err, "initialize time service")
}
}
if limiterManger := service.FromContext[limiter.Manager](r.ctx); limiterManger != nil {
r.limiterManager = limiterManger
}
return nil
}

Expand Down Expand Up @@ -688,6 +694,18 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
if !common.Contains(detour.Network(), N.NetworkTCP) {
return E.New("missing supported outbound, closing connection")
}

if r.limiterManager != nil {
var limiterTags []string
if matchedRule != nil {
limiterTags = matchedRule.Limiters()
}
limiters := r.limiterManager.LoadLimiters(limiterTags, metadata.User, metadata.Inbound)
if len(limiters) > 0 {
conn = r.limiterManager.NewConnWithLimiters(ctx, conn, limiters)
}
}

if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, matchedRule)
defer tracker.Leave()
Expand Down
10 changes: 10 additions & 0 deletions route/rule_abstract.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type abstractDefaultRule struct {
allItems []RuleItem
invert bool
outbound string
limiters []string
}

func (r *abstractDefaultRule) Type() string {
Expand Down Expand Up @@ -126,6 +127,10 @@ func (r *abstractDefaultRule) Outbound() string {
return r.outbound
}

func (r *abstractDefaultRule) Limiters() []string {
return r.limiters
}

func (r *abstractDefaultRule) String() string {
if !r.invert {
return strings.Join(F.MapToString(r.allItems), " ")
Expand All @@ -139,6 +144,7 @@ type abstractLogicalRule struct {
mode string
invert bool
outbound string
limiters []string
}

func (r *abstractLogicalRule) Type() string {
Expand Down Expand Up @@ -191,6 +197,10 @@ func (r *abstractLogicalRule) Outbound() string {
return r.outbound
}

func (r *abstractLogicalRule) Limiters() []string {
return r.limiters
}

func (r *abstractLogicalRule) String() string {
var op string
switch r.mode {
Expand Down
6 changes: 6 additions & 0 deletions route/rule_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.Limiter) > 0 {
rule.limiters = append(rule.limiters, options.Limiter...)
}
return rule, nil
}

Expand Down Expand Up @@ -216,5 +219,8 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt
}
r.rules[i] = rule
}
if len(options.Limiter) > 0 {
r.limiters = append(r.limiters, options.Limiter...)
}
return r, nil
}

0 comments on commit cc4e41f

Please sign in to comment.