-
Notifications
You must be signed in to change notification settings - Fork 4
/
Base64.lua
168 lines (133 loc) · 6.31 KB
/
Base64.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
--!native
--!optimize 2
-- Shared lookup tables for encode/decode:
--   lookupValueToCharacter[v] -> ASCII byte of alphabet entry v (v = 0..63)
--   lookupCharacterToValue[c] -> 6-bit value of ASCII byte c. Bytes that are not
--                                in the alphabet (including "=") are never written,
--                                so they keep the buffer's initial value of 0.
local lookupValueToCharacter = buffer.create(64)
local lookupCharacterToValue = buffer.create(256)
local alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
local padding = string.byte("=")
for sextet = 0, #alphabet - 1 do
	local asciiByte = alphabet:byte(sextet + 1)
	buffer.writeu8(lookupCharacterToValue, asciiByte, sextet)
	buffer.writeu8(lookupValueToCharacter, sextet, asciiByte)
end
--- Encodes raw bytes as standard Base64 (alphabet "A-Z a-z 0-9 + /") with "="
-- padding, so the result length is always a multiple of 4.
-- @param input buffer holding the bytes to encode (may be empty)
-- @return a new buffer of ASCII Base64 text, 4 * ceil(len / 3) bytes long
local function encode(input: buffer): buffer
	local byteCount = buffer.len(input)
	local groupCount = math.ceil(byteCount / 3)
	local resultLength = groupCount * 4
	local result = buffer.create(resultLength)

	-- Start offset of the final (possibly partial) 3-byte group. Every group
	-- before it goes through the fast path, which reads 4 bytes with readu32.
	-- Doing that on the final group could read past the end of `input`.
	local tailStart = (groupCount - 1) * 3

	local readOffset = 0
	local writeOffset = 0
	while readOffset < tailStart do
		-- readu32 is little-endian. byteswap moves input[readOffset] into the top
		-- byte, so the three group bytes occupy bits 31..8 of `word`.
		local word = bit32.byteswap(buffer.readu32(input, readOffset))
		-- The k-th sextet (k = 1..4) starts at bit 32 - 6k: 26, 20, 14, 8.
		buffer.writeu8(result, writeOffset, buffer.readu8(lookupValueToCharacter, bit32.extract(word, 26, 6)))
		buffer.writeu8(result, writeOffset + 1, buffer.readu8(lookupValueToCharacter, bit32.extract(word, 20, 6)))
		buffer.writeu8(result, writeOffset + 2, buffer.readu8(lookupValueToCharacter, bit32.extract(word, 14, 6)))
		buffer.writeu8(result, writeOffset + 3, buffer.readu8(lookupValueToCharacter, bit32.extract(word, 8, 6)))
		readOffset += 3
		writeOffset += 4
	end

	if groupCount > 0 then
		-- 1, 2 or 3 bytes remain. Pack them into a 24-bit word, zero-filling the
		-- missing low bytes. Then write "=" in place of every sextet that would
		-- consist only of that zero fill.
		local tailSize = byteCount - tailStart
		local word = bit32.lshift(buffer.readu8(input, tailStart), 16)
		if tailSize >= 2 then
			word = bit32.bor(word, bit32.lshift(buffer.readu8(input, tailStart + 1), 8))
		end
		if tailSize == 3 then
			word = bit32.bor(word, buffer.readu8(input, tailStart + 2))
		end

		local tailOut = resultLength - 4
		buffer.writeu8(result, tailOut, buffer.readu8(lookupValueToCharacter, bit32.extract(word, 18, 6)))
		buffer.writeu8(result, tailOut + 1, buffer.readu8(lookupValueToCharacter, bit32.extract(word, 12, 6)))
		if tailSize >= 2 then
			buffer.writeu8(result, tailOut + 2, buffer.readu8(lookupValueToCharacter, bit32.extract(word, 6, 6)))
		else
			buffer.writeu8(result, tailOut + 2, padding)
		end
		if tailSize == 3 then
			buffer.writeu8(result, tailOut + 3, buffer.readu8(lookupValueToCharacter, bit32.extract(word, 0, 6)))
		else
			buffer.writeu8(result, tailOut + 3, padding)
		end
	end

	return result
end
--- Decodes standard Base64 (alphabet "A-Z a-z 0-9 + /") back into raw bytes.
-- Trailing "=" padding is optional: "QQ==" and "QQ" both decode to "A".
-- NOTE: characters are not validated. Any byte outside the alphabet (including
-- an "=" that is not part of the trailing padding) decodes as the value 0,
-- because the reverse lookup table is zero-filled for such bytes.
-- @param input buffer of ASCII Base64 text (may be empty)
-- @return a new buffer holding the decoded bytes
-- @error "Base64 input has an invalid length" when, after stripping padding,
--        a single character is left over (it carries only 6 bits, less than a byte)
local function decode(input: buffer): buffer
	local inputLength = buffer.len(input)

	-- Strip at most two trailing "=" bytes to find where the encoded data ends.
	-- Only consecutive trailing "=" count as padding.
	local dataLength = inputLength
	for _ = 1, 2 do
		if dataLength > 0 and buffer.readu8(input, dataLength - 1) == padding then
			dataLength -= 1
		end
	end

	local fullGroups = math.floor(dataLength / 4)
	local tailLength = dataLength % 4
	if tailLength == 1 then
		error("Base64 input has an invalid length", 2)
	end
	-- A trailing group of 2 or 3 characters yields 1 or 2 bytes respectively.
	local tailBytes = math.max(tailLength - 1, 0)
	local output = buffer.create(fullGroups * 3 + tailBytes)

	-- Complete 4-character groups: 4 sextets -> 24 bits -> 3 bytes.
	for groupIndex = 0, fullGroups - 1 do
		local inputIndex = groupIndex * 4
		local outputIndex = groupIndex * 3
		local value1 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex))
		local value2 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex + 1))
		local value3 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex + 2))
		local value4 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex + 3))
		local chunk = bit32.bor(
			bit32.lshift(value1, 18),
			bit32.lshift(value2, 12),
			bit32.lshift(value3, 6),
			value4
		)
		buffer.writeu8(output, outputIndex, bit32.rshift(chunk, 16))
		buffer.writeu8(output, outputIndex + 1, bit32.band(bit32.rshift(chunk, 8), 0b11111111))
		buffer.writeu8(output, outputIndex + 2, bit32.band(chunk, 0b11111111))
	end

	-- Partial final group (2 or 3 characters). It may have been padded or not.
	if tailLength > 0 then
		local inputIndex = fullGroups * 4
		local outputIndex = fullGroups * 3
		local value1 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex))
		local value2 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex + 1))
		local chunk = bit32.bor(bit32.lshift(value1, 18), bit32.lshift(value2, 12))
		if tailLength == 3 then
			local value3 = buffer.readu8(lookupCharacterToValue, buffer.readu8(input, inputIndex + 2))
			chunk = bit32.bor(chunk, bit32.lshift(value3, 6))
		end
		buffer.writeu8(output, outputIndex, bit32.rshift(chunk, 16))
		if tailLength == 3 then
			buffer.writeu8(output, outputIndex + 1, bit32.band(bit32.rshift(chunk, 8), 0b11111111))
		end
	end

	return output
end
--- Public interface: `encode(buffer): buffer` and `decode(buffer): buffer`.
-- Both operate on `buffer` values. Callers holding strings can convert with
-- `buffer.fromstring` / `buffer.tostring`.
return { decode = decode, encode = encode }