Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nikooo777 committed Oct 31, 2023
1 parent 5468f62 commit c67f69a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 16 deletions.
2 changes: 1 addition & 1 deletion firewall/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var whitelist = map[string]bool{
"51.210.0.171": true,
}

func IsIpAbusingResources(ip string, endpoint string) (bool, int) {
func CheckAndRateLimitIp(ip string, endpoint string) (bool, int) {
if ip == "" {
return false, 0
}
Expand Down
20 changes: 8 additions & 12 deletions firewall/firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestCheckIPAccess(t *testing.T) {
Expand All @@ -12,24 +14,18 @@ func TestCheckIPAccess(t *testing.T) {
WindowSize = 7 * time.Second
// Test the first five accesses for an IP don't exceed the limit
for i := 1; i <= 6; i++ {
result, _ := IsIpAbusingResources(ip, endpoint+strconv.Itoa(i))
if result {
t.Errorf("Expected result to be false, got %v for endpoint %s", result, endpoint+strconv.Itoa(i))
}
result, _ := CheckAndRateLimitIp(ip, endpoint+strconv.Itoa(i))
assert.False(t, result, "Expected result to be false, got %v for endpoint %s", result, endpoint+strconv.Itoa(i))
}

// Test the sixth access for an IP exceeds the limit
result, _ := IsIpAbusingResources(ip, endpoint+"7")
if !result {
t.Errorf("Expected result to be true, got %v for endpoint %s", result, endpoint+"6")
}
result, _ := CheckAndRateLimitIp(ip, endpoint+"7")
assert.True(t, result, "Expected result to be true, got %v for endpoint %s", result, endpoint+"7")

// Wait for the window size to elapse
time.Sleep(WindowSize)

// Test the access for an IP after the window size elapses doesn't exceed the limit
result, _ = IsIpAbusingResources(ip, endpoint+"7")
if result {
t.Errorf("Expected result to be false, got %v for endpoint %s", result, endpoint+"7")
}
result, _ = CheckAndRateLimitIp(ip, endpoint+"7")
assert.False(t, result, "Expected result to be false, got %v for endpoint %s", result, endpoint+"7")
}
6 changes: 3 additions & 3 deletions player/http_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ func (h *RequestHandler) Handle(c *gin.Context) {

flagged := true
for header, v := range c.Request.Header {
//check if the header is not one we want to check, so we skip it
if header != "User-Agent" && header != "Referer" && header != "Origin" && header != "X-Requested-With" && !allowedSpecialHeaders[strings.ToLower(header)] {
hasHeaderToCheck := header != "User-Agent" && header != "Referer" && header != "Origin" && header != "X-Requested-With"
if hasHeaderToCheck && !allowedSpecialHeaders[strings.ToLower(header)] {
continue
}
if strings.ToLower(header) == "origin" && allowedOrigins[v[0]] {
Expand Down Expand Up @@ -257,7 +257,7 @@ func (h *RequestHandler) Handle(c *gin.Context) {
c.String(http.StatusForbidden, "this content cannot be accessed")
return
}
abusiveIP, abuseCount := firewall.IsIpAbusingResources(ip, stream.ClaimID)
abusiveIP, abuseCount := firewall.CheckAndRateLimitIp(ip, stream.ClaimID)
if abusiveIP {
Logger.Warnf("IP %s is abusing resources (count: %d): %s - %s", ip, abuseCount, stream.ClaimID, stream.claim.Name)
if abuseCount > 10 {
Expand Down

0 comments on commit c67f69a

Please sign in to comment.