Skip to content

Commit

Permalink
lint (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc authored Mar 20, 2024
1 parent ffb2f47 commit c764fab
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 22 deletions.
26 changes: 20 additions & 6 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"bytes"
"context"
"errors"
"flag"
"fmt"
"os"
Expand All @@ -18,7 +19,6 @@ import (

"github.com/crowdsecurity/crowdsec/pkg/models"
csbouncer "github.com/crowdsecurity/go-cs-bouncer"

"github.com/crowdsecurity/go-cs-lib/version"

"github.com/crowdsecurity/cs-aws-waf-bouncer/pkg/cfg"
Expand All @@ -30,8 +30,8 @@ var wafInstances = make([]*waf.WAF, 0)
func resourceCleanup() {
for _, w := range wafInstances {
w.Logger.Infof("Cleaning up resources")
err := w.Cleanup()
if err != nil {

if err := w.Cleanup(); err != nil {
log.Errorf("Error cleaning up WAF: %s", err)
}
}
Expand All @@ -45,13 +45,14 @@ func HandleSignals(ctx context.Context) error {
case s := <-signalChan:
switch s {
case syscall.SIGTERM:
return fmt.Errorf("received SIGTERM")
return errors.New("received SIGTERM")
case os.Interrupt: // cross-platform SIGINT
return fmt.Errorf("received interrupt")
return errors.New("received interrupt")
}
case <-ctx.Done():
return ctx.Err()
}

return nil
}

Expand All @@ -70,6 +71,7 @@ func processDecisions(decisions *models.DecisionsStreamResponse, supportedAction
if !slices.Contains(supportedActions, decisionType) {
decisionType = "fallback"
}

if strings.ToLower(*decision.Scope) == "ip" || strings.ToLower(*decision.Scope) == "range" {
if strings.Contains(*decision.Value, ":") {
if !strings.Contains(*decision.Value, "/") {
Expand All @@ -96,6 +98,7 @@ func processDecisions(decisions *models.DecisionsStreamResponse, supportedAction
if !slices.Contains(supportedActions, decisionType) {
decisionType = "fallback"
}

if strings.ToLower(*decision.Scope) == "ip" || strings.ToLower(*decision.Scope) == "range" {
if strings.Contains(*decision.Value, ":") {
if !strings.Contains(*decision.Value, "/") {
Expand Down Expand Up @@ -136,6 +139,7 @@ func Execute() error {
}

configBytes := []byte{}

var err error

if configPath != nil && *configPath != "" {
Expand Down Expand Up @@ -169,12 +173,15 @@ func Execute() error {
if *testConfig {
for _, wafConfig := range config.WebACLConfig {
log.Debugf("Create WAF instance with config: %+v", wafConfig)

_, err := waf.NewWaf(wafConfig)
if err != nil {
return fmt.Errorf("configuration error: %w", err)
}
}

log.Info("valid config")

return nil
}

Expand All @@ -198,17 +205,21 @@ func Execute() error {

for _, wafConfig := range config.WebACLConfig {
log.Debugf("Create WAF instance with config: %+v", wafConfig)

w, err := waf.NewWaf(wafConfig)
if err != nil {
return fmt.Errorf("could not create waf instance: %w", err)
}

err = w.Init()
if err != nil {
if os.Getenv("CS_AWS_WAF_BOUNCER_TESTING") == "" {
return fmt.Errorf("could not initialize waf instance: %w", err)
}

log.Errorf("could not initialize waf instance: %v+", err)
}

wafInstances = append(wafInstances, w)
}

Expand All @@ -220,7 +231,7 @@ func Execute() error {

g.Go(func() error {
bouncer.Run(ctx)
return fmt.Errorf("bouncer stream halted")
return errors.New("bouncer stream halted")
})

if config.Daemon {
Expand All @@ -232,13 +243,16 @@ func Execute() error {

g.Go(func() error {
log.Info("Starting processing decisions")

for {
select {
case <-ctx.Done():
log.Info("terminating bouncer process")

for _, w := range wafInstances {
w.T.Kill(nil)
}

return nil
case decisions := <-bouncer.Stream:
log.Info("Polling decisions")
Expand Down
22 changes: 18 additions & 4 deletions pkg/cfg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,13 @@ var validScopes = []string{"REGIONAL", "CLOUDFRONT"}
var validIpHeaderPosition = []string{"FIRST", "LAST", "ANY"}

func getConfigFromEnv(config *bouncerConfig) {
var key string
var value string
var acl *AclConfig
var err error
var (
key string
value string
acl *AclConfig
err error
)

acls := make(map[byte]*AclConfig, 0)

for _, env := range os.Environ() {
Expand All @@ -73,12 +76,14 @@ func getConfigFromEnv(config *bouncerConfig) {
if k2[0] < '0' || k2[0] > '9' || len(k2) < 3 {
log.Warnf("Invalid name for %s: BOUNCER_WAF_CONFIG_* must be in the form BOUNCER_WAF_CONFIG_0_XXX, BOUNCER_WAF_CONFIG_1_XXX", key)
}

if _, ok := acls[k2[0]]; !ok {
acl = &AclConfig{}
acls[k2[0]] = acl
} else {
acl = acls[k2[0]]
}

k2 = k2[2:]
switch k2 {
case "WEB_ACL_NAME":
Expand Down Expand Up @@ -177,6 +182,7 @@ func getConfigFromEnv(config *bouncerConfig) {
}
}
}

for _, v := range acls {
config.WebACLConfig = append(config.WebACLConfig, *v)
}
Expand Down Expand Up @@ -216,16 +222,20 @@ func (c *bouncerConfig) ValidateAndSetDefaults() error {
if len(c.WebACLConfig) == 0 {
return fmt.Errorf("waf_config is required")
}

for _, c := range c.WebACLConfig {
if c.FallbackAction == "" {
return fmt.Errorf("fallback_action is required")
}

if !slices.Contains(ValidActions, c.FallbackAction) {
return fmt.Errorf("fallback_action must be one of %v", ValidActions)
}

if c.RuleGroupName == "" {
return fmt.Errorf("rule_group_name is required")
}

if c.Scope == "" {
return fmt.Errorf("scope is required")
}
Expand All @@ -241,9 +251,11 @@ func (c *bouncerConfig) ValidateAndSetDefaults() error {
if !slices.Contains(validScopes, c.Scope) {
return fmt.Errorf("scope must be one of %v", validScopes)
}

if c.IpsetPrefix == "" {
return fmt.Errorf("ipset_prefix is required")
}

if c.Region == "" && strings.ToUpper(c.Scope) == "REGIONAL" {
return fmt.Errorf("region is required when scope is REGIONAL")
}
Expand Down Expand Up @@ -274,9 +286,11 @@ func (c *bouncerConfig) ValidateAndSetDefaults() error {
func MergedConfig(configPath string) ([]byte, error) {
patcher := yamlpatch.NewPatcher(configPath, ".local")
data, err := patcher.MergedPatchContent()

if err != nil {
return nil, err
}

return data, nil
}

Expand Down
3 changes: 3 additions & 0 deletions pkg/cfg/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,17 @@ func (c *LoggingConfig) validate() error {
if c.LogMedia != "stdout" && c.LogMedia != "file" {
return fmt.Errorf("log_media should be either 'stdout' or 'file'")
}

return nil
}

func (c *LoggingConfig) setup(fileName string) error {
c.setDefaults()

if err := c.validate(); err != nil {
return err
}

log.SetLevel(*c.LogLevel)

if c.LogMedia == "stdout" {
Expand Down
16 changes: 16 additions & 0 deletions pkg/waf/ipset.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ func (w *WAFIpSet) Add(ip string) {
if w.Size() >= 10000 {
return
}

if w.Contains(ip) {
return
}

w.ips = append(w.ips, ip)
w.stale = true
}
Expand All @@ -53,6 +55,7 @@ func (w *WAFIpSet) ContainsAll(ips []string) bool {
return false
}
}

return true
}

Expand Down Expand Up @@ -86,6 +89,7 @@ func (w *WAFIpSet) ToStatement(ipHeader string, ipHeaderPosition string) *wafv2.
ARN: aws.String(w.arn),
}
}

return &wafv2.IPSetReferenceStatement{
ARN: aws.String(w.arn),
IPSetForwardedIPConfig: &wafv2.IPSetForwardedIPConfig{
Expand All @@ -98,9 +102,11 @@ func (w *WAFIpSet) ToStatement(ipHeader string, ipHeaderPosition string) *wafv2.

func (w *WAFIpSet) getIPSet() (*wafv2.IPSet, *string, error) {
w.logger.Debugf("Getting IPSet %s", w.name)

if w.id == "" {
return nil, nil, &wafv2.WAFNonexistentItemException{}
}

r, err := w.client.GetIPSet(&wafv2.GetIPSetInput{
Name: aws.String(w.name),
Scope: aws.String(w.scope),
Expand All @@ -109,12 +115,14 @@ func (w *WAFIpSet) getIPSet() (*wafv2.IPSet, *string, error) {
if err != nil {
return nil, nil, err
}

return r.IPSet, r.LockToken, nil
}

func (w *WAFIpSet) createIpSet() (*wafv2.IPSetSummary, error) {
w.logger.Infof("Creating IPSet %s", w.name)
w.logger.Tracef("Set name: %s | Type: %s | Decision: %s | Scope: %s | %d IPS", w.name, w.ipType, w.decisionType, w.scope, w.Size())

r, err := w.client.CreateIPSet(&wafv2.CreateIPSetInput{
Name: aws.String(w.name),
Addresses: aws.StringSlice(w.ips),
Expand All @@ -124,6 +132,7 @@ func (w *WAFIpSet) createIpSet() (*wafv2.IPSetSummary, error) {
if err != nil {
return nil, err
}

return r.Summary, nil
}

Expand All @@ -144,11 +153,13 @@ func (w *WAFIpSet) DeleteIpSet() error {
if err != nil {
return err
}

return nil
}

func (w *WAFIpSet) Commit() error {
w.logger.Infof("Updating IPSet %s", w.name)

currSet, token, err := w.getIPSet()
if err != nil {
switch err.(type) {
Expand All @@ -157,11 +168,13 @@ func (w *WAFIpSet) Commit() error {
return err
}
}

if currSet == nil {
summary, err := w.createIpSet()
if err != nil {
return fmt.Errorf("failed to create IPSet %s: %w", w.name, err)
}

w.arn = *summary.ARN
w.id = *summary.Id
} else {
Expand All @@ -176,13 +189,16 @@ func (w *WAFIpSet) Commit() error {
return err
}
}

w.stale = false

return nil
}

func NewIpSet(setPrefix string, ipType string, decisionType string, scope string, client *wafv2.WAFV2) *WAFIpSet {
u := uuid.New()
setName := setPrefix + "-" + ipType + "-" + decisionType + "-" + u.String()

return &WAFIpSet{
name: setName,
ipType: ipType,
Expand Down
Loading

0 comments on commit c764fab

Please sign in to comment.