Skip to content

Commit

Permalink
Add inline rule-set & Add reload for local rule-set
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Jul 13, 2024
1 parent ec167cf commit fafc7e7
Show file tree
Hide file tree
Showing 13 changed files with 298 additions and 270 deletions.
5 changes: 4 additions & 1 deletion cmd/sing-box/cmd_rule_set_compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ func compileRuleSet(sourcePath string) error {
if err != nil {
return err
}
ruleSet := plainRuleSet.Upgrade()
ruleSet, err := plainRuleSet.Upgrade()
if err != nil {
return err
}
var outputPath string
if flagRuleSetCompileOutput == flagRuleSetCompileDefaultOutput {
if strings.HasSuffix(sourcePath, ".json") {
Expand Down
5 changes: 4 additions & 1 deletion cmd/sing-box/cmd_rule_set_match.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ func ruleSetMatch(sourcePath string, domain string) error {
if err != nil {
return err
}
plainRuleSet = compat.Upgrade()
plainRuleSet, err = compat.Upgrade()
if err != nil {
return err
}
case C.RuleSetFormatBinary:
plainRuleSet, err = srs.Read(bytes.NewReader(content), false)
if err != nil {
Expand Down
185 changes: 58 additions & 127 deletions common/tls/ech_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ import (
"strings"

cftls "github.com/sagernet/cloudflare-tls"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"

"github.com/fsnotify/fsnotify"
)

type echServerConfig struct {
Expand All @@ -26,9 +25,8 @@ type echServerConfig struct {
key []byte
certificatePath string
keyPath string
watcher *fsnotify.Watcher
echKeyPath string
echWatcher *fsnotify.Watcher
watcher *fswatch.Watcher
}

func (c *echServerConfig) ServerName() string {
Expand Down Expand Up @@ -66,159 +64,92 @@ func (c *echServerConfig) Clone() Config {
}

func (c *echServerConfig) Start() error {
if c.certificatePath != "" && c.keyPath != "" {
err := c.startWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
}
if c.echKeyPath != "" {
err := c.startECHWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
err := c.startWatcher()
if err != nil {
c.logger.Warn("create credentials watcher: ", err)
}
return nil
}

func (c *echServerConfig) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
var watchPath []string
if c.certificatePath != "" {
err = watcher.Add(c.certificatePath)
if err != nil {
return err
}
watchPath = append(watchPath, c.certificatePath)
}
if c.keyPath != "" {
err = watcher.Add(c.keyPath)
if err != nil {
return err
}
watchPath = append(watchPath, c.keyPath)
}
if c.echKeyPath != "" {
watchPath = append(watchPath, c.echKeyPath)
}
if len(watchPath) == 0 {
return nil
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPath,
Callback: func(path string) {
err := c.credentialsUpdated(path)
if err != nil {
c.logger.Error(E.Cause(err, "reload credentials from ", path))
}
},
})
if err != nil {
return err
}
c.watcher = watcher
go c.loopUpdate()
return nil
}

func (c *echServerConfig) loopUpdate() {
for {
select {
case event, ok := <-c.watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadKeyPair()
func (c *echServerConfig) credentialsUpdated(path string) error {
if path == c.certificatePath || path == c.keyPath {
if path == c.certificatePath {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
c.logger.Error(E.Cause(err, "reload TLS key pair"))
return err
}
case err, ok := <-c.watcher.Errors:
if !ok {
return
c.certificate = certificate
} else {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return err
}
c.logger.Error(E.Cause(err, "fsnotify error"))
c.key = key
}
}
}

func (c *echServerConfig) reloadKeyPair() error {
if c.certificatePath != "" {
certificate, err := os.ReadFile(c.certificatePath)
keyPair, err := cftls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
return E.Cause(err, "parse key pair")
}
c.certificate = certificate
}
if c.keyPath != "" {
key, err := os.ReadFile(c.keyPath)
c.config.Certificates = []cftls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
} else {
echKeyContent, err := os.ReadFile(c.echKeyPath)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
return err
}
c.key = key
}
keyPair, err := cftls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload key pair")
}
c.config.Certificates = []cftls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
return nil
}

func (c *echServerConfig) startECHWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
err = watcher.Add(c.echKeyPath)
if err != nil {
return err
}
c.echWatcher = watcher
go c.loopECHUpdate()
return nil
}

func (c *echServerConfig) loopECHUpdate() {
for {
select {
case event, ok := <-c.echWatcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadECHKey()
if err != nil {
c.logger.Error(E.Cause(err, "reload ECH key"))
}
case err, ok := <-c.echWatcher.Errors:
if !ok {
return
}
c.logger.Error(E.Cause(err, "fsnotify error"))
block, rest := pem.Decode(echKeyContent)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return E.New("invalid ECH keys pem")
}
echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
if err != nil {
return E.Cause(err, "parse ECH keys")
}
echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
if err != nil {
return E.Cause(err, "create ECH key set")
}
c.config.ServerECHProvider = echKeySet
c.logger.Info("reloaded ECH keys")
}
}

func (c *echServerConfig) reloadECHKey() error {
echKeyContent, err := os.ReadFile(c.echKeyPath)
if err != nil {
return err
}
block, rest := pem.Decode(echKeyContent)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return E.New("invalid ECH keys pem")
}
echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
if err != nil {
return E.Cause(err, "parse ECH keys")
}
echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
if err != nil {
return E.Cause(err, "create ECH key set")
}
c.config.ServerECHProvider = echKeySet
c.logger.Info("reloaded ECH keys")
return nil
}

func (c *echServerConfig) Close() error {
var err error
if c.watcher != nil {
err = E.Append(err, c.watcher.Close(), func(err error) error {
return E.Cause(err, "close certificate watcher")
})
}
if c.echWatcher != nil {
err = E.Append(err, c.echWatcher.Close(), func(err error) error {
return E.Cause(err, "close ECH key watcher")
return E.Cause(err, "close credentials watcher")
})
}
return err
Expand Down
61 changes: 19 additions & 42 deletions common/tls/std_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ import (
"os"
"strings"

"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"

"github.com/fsnotify/fsnotify"
)

var errInsecureUnused = E.New("tls: insecure unused")
Expand All @@ -27,7 +26,7 @@ type STDServerConfig struct {
key []byte
certificatePath string
keyPath string
watcher *fsnotify.Watcher
watcher *fswatch.Watcher
}

func (c *STDServerConfig) ServerName() string {
Expand Down Expand Up @@ -88,59 +87,37 @@ func (c *STDServerConfig) Start() error {
}

func (c *STDServerConfig) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
var watchPath []string
if c.certificatePath != "" {
err = watcher.Add(c.certificatePath)
if err != nil {
return err
}
watchPath = append(watchPath, c.certificatePath)
}
if c.keyPath != "" {
err = watcher.Add(c.keyPath)
if err != nil {
return err
}
watchPath = append(watchPath, c.keyPath)
}
c.watcher = watcher
go c.loopUpdate()
return nil
}

func (c *STDServerConfig) loopUpdate() {
for {
select {
case event, ok := <-c.watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadKeyPair()
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPath,
Callback: func(path string) {
err := c.certificateUpdated(path)
if err != nil {
c.logger.Error(E.Cause(err, "reload TLS key pair"))
}
case err, ok := <-c.watcher.Errors:
if !ok {
return
c.logger.Error(err)
}
c.logger.Error(E.Cause(err, "fsnotify error"))
}
},
})
if err != nil {
return err
}
c.watcher = watcher
return nil
}

func (c *STDServerConfig) reloadKeyPair() error {
if c.certificatePath != "" {
func (c *STDServerConfig) certificateUpdated(path string) error {
if path == c.certificatePath {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
}
c.certificate = certificate
}
if c.keyPath != "" {
} else if path == c.keyPath {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
Expand Down
1 change: 1 addition & 0 deletions constant/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const (
)

const (
RuleSetTypeInline = "inline"
RuleSetTypeLocal = "local"
RuleSetTypeRemote = "remote"
RuleSetVersion1 = 1
Expand Down
Loading

0 comments on commit fafc7e7

Please sign in to comment.