Skip to content

Commit

Permalink
Merge pull request #3 from Reselim/buffer
Browse files Browse the repository at this point in the history
Rewrite to use buffers
  • Loading branch information
Reselim authored Mar 18, 2024
2 parents 31b7810 + 8cd0d65 commit 5928fc1
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 105 deletions.
236 changes: 141 additions & 95 deletions Base64.lua
Original file line number Diff line number Diff line change
@@ -1,119 +1,165 @@
--!native
--!optimize 2

local lookupValueToASCII = {} :: { [number]: number }
local lookupASCIIToValue = {} :: { [number]: number }
local lookupValueToCharacter = buffer.create(64)
local lookupCharacterToValue = buffer.create(256)

local alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
local padding = string.byte("=")

for index = 1, #alphabet do
for index = 1, 64 do
local value = index - 1
local ascii = string.byte(alphabet, index)

lookupValueToASCII[value] = ascii
lookupASCIIToValue[ascii] = value
end

lookupASCIIToValue[string.byte("=")] = 0

local function buildStringFromCodes(values: { number }): string
local chunks = {} :: { string }

for index = 1, #values, 4096 do
table.insert(chunks, string.char(
unpack(values, index, math.min(index + 4096 - 1, #values))
))
end

return table.concat(chunks, "")
local character = string.byte(alphabet, index)

buffer.writeu8(lookupValueToCharacter, value, character)
buffer.writeu8(lookupCharacterToValue, character, value)
end

local function encode(input: string): string
local inputLength = #input
local outputLength = math.ceil(inputLength / 3) * 4

local remainder = inputLength % 3

if remainder == 0 then
-- Since chunks are only 3 characters wide and we're parsing 4 characters, we need
-- to add an extra 0 on the end (which will be discarded anyway)
input ..= string.char(0)
end

local output = table.create(outputLength, 0) :: { number }

for chunkIndex = 0, (outputLength / 4) - (if remainder == 0 then 1 else 2) do
local inputIndex = chunkIndex * 3 + 1
local outputIndex = chunkIndex * 4 + 1

-- Parse this as a single 32-bit integer instead of splitting into multiple and combining after
local chunk = bit32.rshift(string.unpack(">J", input, inputIndex), 8)

output[outputIndex] = lookupValueToASCII[bit32.rshift(chunk, 18)]
output[outputIndex + 1] = lookupValueToASCII[bit32.band(bit32.rshift(chunk, 12), 0b111111)]
output[outputIndex + 2] = lookupValueToASCII[bit32.band(bit32.rshift(chunk, 6), 0b111111)]
output[outputIndex + 3] = lookupValueToASCII[bit32.band(chunk, 0b111111)]
local function encode(input: buffer): buffer
local inputLength = buffer.len(input)
local inputChunks = math.ceil(inputLength / 3)

local outputLength = inputChunks * 4
local output = buffer.create(outputLength)

-- Since we use readu32 and chunks are 3 bytes large, we can't read the last chunk here
for chunkIndex = 1, inputChunks - 1 do
local inputIndex = (chunkIndex - 1) * 3
local outputIndex = (chunkIndex - 1) * 4

local chunk = bit32.byteswap(buffer.readu32(input, inputIndex))

-- 8 + 24 - (6 * index)
local value1 = bit32.rshift(chunk, 26)
local value2 = bit32.band(bit32.rshift(chunk, 20), 0b111111)
local value3 = bit32.band(bit32.rshift(chunk, 14), 0b111111)
local value4 = bit32.band(bit32.rshift(chunk, 8), 0b111111)

buffer.writeu8(output, outputIndex, buffer.readu8(lookupValueToCharacter, value1))
buffer.writeu8(output, outputIndex + 1, buffer.readu8(lookupValueToCharacter, value2))
buffer.writeu8(output, outputIndex + 2, buffer.readu8(lookupValueToCharacter, value3))
buffer.writeu8(output, outputIndex + 3, buffer.readu8(lookupValueToCharacter, value4))
end

local inputRemainder = inputLength % 3

if inputRemainder == 1 then
local chunk = buffer.readu8(input, inputLength - 1)

local value1 = bit32.rshift(chunk, 2)
local value2 = bit32.band(bit32.lshift(chunk, 4), 0b111111)

buffer.writeu8(output, outputLength - 4, buffer.readu8(lookupValueToCharacter, value1))
buffer.writeu8(output, outputLength - 3, buffer.readu8(lookupValueToCharacter, value2))
buffer.writeu8(output, outputLength - 2, padding)
buffer.writeu8(output, outputLength - 1, padding)
elseif inputRemainder == 2 then
local chunk = bit32.bor(
bit32.lshift(buffer.readu8(input, inputLength - 2), 8),
buffer.readu8(input, inputLength - 1)
)

if remainder == 1 then -- AA==
local chunk = string.byte(input, inputLength)
local value1 = bit32.rshift(chunk, 10)
local value2 = bit32.band(bit32.rshift(chunk, 4), 0b111111)
local value3 = bit32.band(bit32.lshift(chunk, 2), 0b111111)

buffer.writeu8(output, outputLength - 4, buffer.readu8(lookupValueToCharacter, value1))
buffer.writeu8(output, outputLength - 3, buffer.readu8(lookupValueToCharacter, value2))
buffer.writeu8(output, outputLength - 2, buffer.readu8(lookupValueToCharacter, value3))
buffer.writeu8(output, outputLength - 1, padding)
elseif inputRemainder == 0 and inputLength ~= 0 then
local chunk = bit32.bor(
bit32.lshift(buffer.readu8(input, inputLength - 3), 16),
bit32.lshift(buffer.readu8(input, inputLength - 2), 8),
buffer.readu8(input, inputLength - 1)
)

output[outputLength - 3] = lookupValueToASCII[bit32.rshift(chunk, 2)]
output[outputLength - 2] = lookupValueToASCII[bit32.band(bit32.lshift(chunk, 4), 0b111111)]
output[outputLength - 1] = 61
output[outputLength] = 61
elseif remainder == 2 then -- AAA=
local chunk = string.unpack(">H", input, inputLength - 1)
local value1 = bit32.rshift(chunk, 18)
local value2 = bit32.band(bit32.rshift(chunk, 12), 0b111111)
local value3 = bit32.band(bit32.rshift(chunk, 6), 0b111111)
local value4 = bit32.band(chunk, 0b111111)

output[outputLength - 3] = lookupValueToASCII[bit32.rshift(chunk, 10)]
output[outputLength - 2] = lookupValueToASCII[bit32.band(bit32.rshift(chunk, 4), 0b111111)]
output[outputLength - 1] = lookupValueToASCII[bit32.band(bit32.lshift(chunk, 2), 0b111111)]
output[outputLength] = 61
buffer.writeu8(output, outputLength - 4, buffer.readu8(lookupValueToCharacter, value1))
buffer.writeu8(output, outputLength - 3, buffer.readu8(lookupValueToCharacter, value2))
buffer.writeu8(output, outputLength - 2, buffer.readu8(lookupValueToCharacter, value3))
buffer.writeu8(output, outputLength - 1, buffer.readu8(lookupValueToCharacter, value4))
end

return buildStringFromCodes(output)
return output
end

local function decode(input: string): string
local inputLength = #input
local outputLength = math.ceil(inputLength / 4) * 3

local padding = 0
if string.byte(input, inputLength - 1) == 61 then
padding = 2
elseif string.byte(input, inputLength) == 61 then
padding = 1
local function decode(input: buffer): buffer
local inputLength = buffer.len(input)
local inputChunks = math.ceil(inputLength / 4)
-- TODO: Support input without padding
local inputPadding = 0
if inputLength ~= 0 then
if buffer.readu8(input, inputLength - 1) == padding then inputPadding += 1 end
if buffer.readu8(input, inputLength - 2) == padding then inputPadding += 1 end
end

local output = table.create(outputLength - padding, 0)

for chunkIndex = 0, (outputLength / 3) - 1 do
local inputIndex = chunkIndex * 4 + 1
local outputIndex = chunkIndex * 3 + 1

local value1, value2, value3, value4 = string.byte(input, inputIndex, inputIndex + 3)

-- Combine all variables into one 24-bit variable to be split up
local compound = bit32.bor(
bit32.lshift(lookupASCIIToValue[value1], 18),
bit32.lshift(lookupASCIIToValue[value2], 12),
bit32.lshift(lookupASCIIToValue[value3], 6),
lookupASCIIToValue[value4]
local outputLength = inputChunks * 3 - inputPadding
local output = buffer.create(outputLength)

for chunkIndex = 1, inputChunks - 1 do
local inputIndex = (chunkIndex - 1) * 4
local outputIndex = (chunkIndex - 1) * 3

local value1 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex))
local value2 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex + 1))
local value3 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex + 2))
local value4 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex + 3))

local chunk = bit32.bor(
bit32.lshift(value1, 18),
bit32.lshift(value2, 12),
bit32.lshift(value3, 6),
value4
)

output[outputIndex] = bit32.rshift(compound, 16)
output[outputIndex + 1] = bit32.band(bit32.rshift(compound, 8), 0b11111111)
output[outputIndex + 2] = bit32.band(compound, 0b11111111)

local character1 = bit32.rshift(chunk, 16)
local character2 = bit32.band(bit32.rshift(chunk, 8), 0b11111111)
local character3 = bit32.band(chunk, 0b11111111)

buffer.writeu8(output, outputIndex, character1)
buffer.writeu8(output, outputIndex + 1, character2)
buffer.writeu8(output, outputIndex + 2, character3)
end

if padding >= 1 then
output[outputLength] = nil

if padding >= 2 then
output[outputLength - 1] = nil

if inputLength ~= 0 then
local lastInputIndex = (inputChunks - 1) * 4
local lastOutputIndex = (inputChunks - 1) * 3

local lastValue1 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, lastInputIndex))
local lastValue2 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, lastInputIndex + 1))
local lastValue3 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, lastInputIndex + 2))
local lastValue4 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, lastInputIndex + 3))

local lastChunk = bit32.bor(
bit32.lshift(lastValue1, 18),
bit32.lshift(lastValue2, 12),
bit32.lshift(lastValue3, 6),
lastValue4
)

if inputPadding <= 2 then
local lastCharacter1 = bit32.rshift(lastChunk, 16)
buffer.writeu8(output, lastOutputIndex, lastCharacter1)

if inputPadding <= 1 then
local lastCharacter2 = bit32.band(bit32.rshift(lastChunk, 8), 0b11111111)
buffer.writeu8(output, lastOutputIndex + 1, lastCharacter2)

if inputPadding == 0 then
local lastCharacter3 = bit32.band(lastChunk, 0b11111111)
buffer.writeu8(output, lastOutputIndex + 2, lastCharacter3)
end
end
end
end

return buildStringFromCodes(output)
return output
end

return {
Expand Down
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ A pretty fast Luau Base64 encoder
Add the following to your `wally.toml` under `[dependencies]`:

```toml
base64 = "reselim/base64@2.0.3"
base64 = "reselim/base64@3.0.0"
```

### Manual
Expand All @@ -21,20 +21,22 @@ base64 = "reselim/[email protected]"
```lua
local Base64 = require(path.to.Base64)

local data = "Hello, world!"
local data = buffer.fromstring("Hello, world!")

local encodedData = Base64.encode(data) -- "SGVsbG8sIHdvcmxkIQ=="
local decodedData = Base64.decode(encodedData) -- "Hello, world!"
local encodedData = Base64.encode(data) -- buffer: "SGVsbG8sIHdvcmxkIQ=="
local decodedData = Base64.decode(encodedData) -- buffer: "Hello, world!"

print(buffer.tostring(decodedData)) -- "Hello, world!"
```

## Benchmarks

Benchmarks ran in Roblox Studio with a payload of **10,000,000** characters running on a **Ryzen 5900X** and **32GB RAM @ 3200MHz**, as of **2023/11/03**
Benchmarks ran in Roblox Studio with a payload of **100,000,000** characters running on a **Ryzen 5900X** and **32GB RAM @ 3200MHz**, as of **2024/01/11**

#### Native mode OFF:
- Encode: 569.976ms (17,544,586/s)
- Decode: 333.244ms (30,008,033/s)
- Encode: 3303.27ms (30,273,037/s)
- Decode: 3747.17ms (26,686,826/s)

#### Native mode ON:
- Encode: 365.399ms (27,367,321/s)
- Decode: 166.880ms (59,923,405/s)
- Encode: 461.23ms (216,813,496/s)
- Decode: 596.37ms (167,680,012/s)
2 changes: 1 addition & 1 deletion wally.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "reselim/base64"
description = "A pretty fast Luau Base64 encoder"
version = "2.0.3"
version = "3.0.0"
registry = "https://github.com/UpliftGames/wally-index"
realm = "shared"
include = ["default.project.json", "Base64.lua"]

0 comments on commit 5928fc1

Please sign in to comment.