Skip to content

Commit

Permalink
Merge pull request #57 from DrmagicE/refactor
Browse files Browse the repository at this point in the history
Refactor: refactor the way to expose grpc and http endpoint
  • Loading branch information
DrmagicE authored Jan 28, 2021
2 parents 5ee07a7 + 4be42cd commit afcf8c8
Show file tree
Hide file tree
Showing 33 changed files with 1,046 additions and 336 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ FROM alpine:3.12

WORKDIR /gmqttd
COPY --from=builder /go/src/github.com/DrmagicE/gmqtt/build/gmqttd .
RUN mkdir /etc/gmqtt
COPY ./cmd/gmqttd/default_config.yml /etc/gmqtt/gmqttd.yml
ENV PATH=$PATH:/gmqttd
RUN chmod +x gmqttd
ENTRYPOINT ["gmqttd","start"]
Expand Down
40 changes: 16 additions & 24 deletions cmd/gmqttd/command/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ import (
)

var (
DefaultConfigFile string
ConfigFile string
logger *zap.Logger
ConfigFile string
logger *zap.Logger
)

func must(err error) {
if err != nil {
fmt.Println(err)
fmt.Fprint(os.Stderr, err)
os.Exit(1)
}
}
Expand All @@ -53,8 +52,10 @@ func installSignal(srv server.Server) {
srv.ApplyConfig(c)
logger.Info("gmqtt reloaded")
case <-stopSignalCh:
srv.Stop(context.Background())
return
err := srv.Stop(context.Background())
if err != nil {
fmt.Fprint(os.Stderr, err.Error())
}
}
}

Expand All @@ -69,15 +70,15 @@ func GetListeners(c config.Config) (tcpListeners []net.Listener, websockets []*s
Path: v.Websocket.Path,
}
if v.TLSOptions != nil {
ws.KeyFile = v.KeyFile
ws.CertFile = v.CertFile
ws.KeyFile = v.Key
ws.CertFile = v.Cert
}
websockets = append(websockets, ws)
continue
}
if v.TLSOptions != nil {
var cert tls.Certificate
cert, err = tls.LoadX509KeyPair(v.CertFile, v.KeyFile)
cert, err = tls.LoadX509KeyPair(v.Cert, v.Key)
if err != nil {
return
}
Expand All @@ -94,39 +95,30 @@ func GetListeners(c config.Config) (tcpListeners []net.Listener, websockets []*s

// NewStartCmd creates a *cobra.Command object for start command.
func NewStartCmd() *cobra.Command {
cfg := config.DefaultConfig()
cmd := &cobra.Command{
Use: "start",
Short: "Start gmqtt broker",
Run: func(cmd *cobra.Command, args []string) {
var err error
must(err)
c, err := config.ParseConfig(ConfigFile)
var useDefault bool
if os.IsNotExist(err) {
if DefaultConfigFile != ConfigFile {
fmt.Println(err)
return
}
// if config file not exist, use default configration.
c = cfg
useDefault = true
must(err)
} else {
must(err)
}
err = c.Validate()
must(err)
pid, err := pidfile.New(c.PidFile)
must(err)
if err != nil {
must(fmt.Errorf("open pid file failed: %s", err))
}
defer pid.Remove()
tcpListeners, websockets, err := GetListeners(c)
must(err)
l, err := c.GetLogger(c.Log)
must(err)
logger = l
if useDefault {
l.Warn("config file not exist, use default configration")
}
s := server.New(
server.WithConfig(c),
server.WithTCPListener(tcpListeners...),
Expand All @@ -139,13 +131,13 @@ func NewStartCmd() *cobra.Command {
os.Exit(1)
return
}
go installSignal(s)
err = s.Run()
if err != nil {
fmt.Println(err)
fmt.Fprint(os.Stderr, err.Error())
os.Exit(1)
return
}
installSignal(s)
},
}
return cmd
Expand Down
11 changes: 11 additions & 0 deletions cmd/gmqttd/config_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// +build !windows

package main

var (
DefaultConfigDir = "/etc/gmqtt"
)

func getDefaultConfigDir() (string, error) {
return DefaultConfigDir, nil
}
12 changes: 12 additions & 0 deletions cmd/gmqttd/config_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// +build windows

package main

import (
"os"
"path/filepath"
)

func getDefaultConfigDir() (string, error) {
return filepath.Join(os.Getenv("programdata"), "gmqtt"), nil
}
40 changes: 28 additions & 12 deletions cmd/gmqttd/default_config.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,35 @@
# Path to pid file, default to /var/run/gmqttd.pid
# pid_file:

listeners:
# bind address
- address: ":1883"
# tls setting
# tls:
# cert_file: "path_to_cert_file"
# key_file: "path_to_key_file"
# tls:
# cacert: "path_to_ca_cert_file"
# cert: "path_to_cert_file"
# key: "path_to_key_file"

- address: ":8883"
# websocket setting
websocket:
path: "/"

api:
grpc:
- address: "unix:///var/run/gmqttd.sock" # The gRPC server listen address.
# tls:
# cacert: "path_to_ca_cert_file"
# cert: "path_to_cert_file"
# key: "path_to_key_file"
http:
# The HTTP server listen address. This is a reverse-proxy server in front of gRPC server.
- address: "tcp://127.0.0.1:8083"
map: "unix:///var/run/gmqttd.sock" # The backend gRPC server endpoint,
# tls:
# cacert: "path_to_ca_cert_file"
# cert: "path_to_cert_file"
# key: "path_to_key_file"

mqtt:
session_expiry: 2h
session_expiry_check_timer: 20s
Expand Down Expand Up @@ -56,23 +76,19 @@ plugins:
prometheus:
path: "/metrics"
listen_address: ":8082"
admin:
http:
enable: true
addr: :8083
grpc:
addr: 8084
auth:
# Password hash type. (plain | md5 | sha256 | bcrypt)
# Default to MD5.
hash: md5
# The file to store password. Default to $HOME/gmqtt_password.yml
# The file to store password. If it is a relative path, it locates in the same directory as the config file.
# (e.g: ./gmqtt_password => /etc/gmqtt/gmqtt_password.yml)
# Default to ./gmqtt_password.yml
# password_file:

# plugin loading orders
plugin_order:
# Uncomment auth to enable authentication.
#- auth
# - auth
- prometheus
- admin
log:
Expand Down
13 changes: 6 additions & 7 deletions cmd/gmqttd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"net/http"
_ "net/http/pprof"
"os"
"path"

"github.com/mitchellh/go-homedir"
"github.com/spf13/cobra"

"github.com/DrmagicE/gmqtt/cmd/gmqttd/command"
Expand All @@ -25,17 +25,16 @@ var (

func must(err error) {
if err != nil {
fmt.Println(err)
fmt.Fprint(os.Stderr, err.Error())
os.Exit(1)
}
}

func init() {
d, err := homedir.Dir()
configDir, err := getDefaultConfigDir()
must(err)
command.DefaultConfigFile = d + "/gmqtt.yml"
rootCmd.PersistentFlags().StringVarP(&command.ConfigFile, "config", "c", command.DefaultConfigFile, "The configuration file path")

command.ConfigFile = path.Join(configDir, "gmqttd.yml")
rootCmd.PersistentFlags().StringVarP(&command.ConfigFile, "config", "c", command.ConfigFile, "The configuration file path")
rootCmd.AddCommand(command.NewStartCmd())
rootCmd.AddCommand(command.NewReloadCommand())
}
Expand All @@ -51,7 +50,7 @@ func main() {
http.ListenAndServe(":6060", nil)
}()
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
fmt.Fprint(os.Stderr, err.Error())
os.Exit(1)
}
}
77 changes: 77 additions & 0 deletions config/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package config

import (
"fmt"
"net"
"strings"
)

// API is the configuration for API server.
// The API server use gRPC-gateway to provide both gRPC and HTTP endpoints.
type API struct {
// GRPC is the gRPC endpoint configuration.
GRPC []*Endpoint `yaml:"grpc"`
// HTTP is the HTTP endpoint configuration.
HTTP []*Endpoint `yaml:"http"`
}

// Endpoint represents a gRPC or HTTP server endpoint.
type Endpoint struct {
// Address is the bind address of the endpoint.
// Format: [tcp|unix://][<host>]:<port>
// e.g :
// * unix:///var/run/gmqttd.sock
// * tcp://127.0.0.1:8080
// * :8081 (equal to tcp://:8081)
Address string `yaml:"address"`
// Map maps the HTTP endpoint to gRPC endpoint.
// Must be set if the endpoint is representing a HTTP endpoint.
Map string `yaml:"map"`
// TLS is the tls configuration.
TLS *TLSOptions `yaml:"tls"`
}

var DefaultAPI API

func (a API) validateAddress(address string, fieldName string) error {
if address == "" {
return fmt.Errorf("%s cannot be empty", fieldName)
}
epParts := strings.SplitN(address, "://", 2)
if len(epParts) == 1 && epParts[0] != "" {
epParts = []string{"tcp", epParts[0]}
}
if len(epParts) != 0 {
switch epParts[0] {
case "tcp":
_, _, err := net.SplitHostPort(epParts[1])
if err != nil {
return fmt.Errorf("invalid %s: %s", fieldName, err.Error())
}
case "unix":
default:
return fmt.Errorf("invalid %s schema: %s", fieldName, epParts[0])
}
}
return nil
}

func (a API) Validate() error {
for _, v := range a.GRPC {
err := a.validateAddress(v.Address, "endpoint")
if err != nil {
return err
}
}
for _, v := range a.HTTP {
err := a.validateAddress(v.Address, "endpoint")
if err != nil {
return err
}
err = a.validateAddress(v.Map, "map")
if err != nil {
return err
}
}
return nil
}
Loading

0 comments on commit afcf8c8

Please sign in to comment.