Skip to content

Commit

Permalink
Use aes256 when payload exceeds 256 bytes
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Couture-Beil <[email protected]>
  • Loading branch information
alexcb committed Jun 14, 2021
1 parent 953febc commit 0113690
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 16 deletions.
Binary file modified build/darwin/amd64/secretshare
Binary file not shown.
159 changes: 143 additions & 16 deletions cmd/secretshare/main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package main

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/pem"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -41,8 +45,87 @@ func generateKey() (string, string, error) {
return pubKeyStr, privKeyStr, nil
}

func encrypt(msg, publicKey string) (string, error) {
parsed, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publicKey))
// encryptAES256 returns a random passphrase and corresponding bytes encrypted with it
func encryptAES256(data []byte) ([]byte, []byte, error) {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, nil, err
}

n := len(data)
buf := new(bytes.Buffer)
if err := binary.Write(buf, binary.LittleEndian, uint64(n)); err != nil {
return nil, nil, err
}
if _, err := buf.Write(data); err != nil {
return nil, nil, err
}

paddingN := aes.BlockSize - (buf.Len() % aes.BlockSize)
if paddingN > 0 {
padding := make([]byte, paddingN)
if _, err := rand.Read(padding); err != nil {
return nil, nil, err
}
if _, err := buf.Write(padding); err != nil {
return nil, nil, err
}
}
plaintext := buf.Bytes()

block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}

ciphertext := make([]byte, aes.BlockSize+len(plaintext))
iv := ciphertext[:aes.BlockSize]
if _, err := rand.Read(iv); err != nil {
return nil, nil, err
}

mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ciphertext[aes.BlockSize:], plaintext)

return key, ciphertext, nil
}

func decryptAES(key, ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}

if len(ciphertext) < aes.BlockSize {
panic("ciphertext too short")
}
iv := ciphertext[:aes.BlockSize]
ciphertext = ciphertext[aes.BlockSize:]

if len(ciphertext)%aes.BlockSize != 0 {
panic("ciphertext is not a multiple of the block size")
}

mode := cipher.NewCBCDecrypter(block, iv)

// works inplace when both args are the same
mode.CryptBlocks(ciphertext, ciphertext)

buf := bytes.NewReader(ciphertext)
var n uint64
if err = binary.Read(buf, binary.LittleEndian, &n); err != nil {
return nil, err
}
payload := make([]byte, n)
if _, err = buf.Read(payload); err != nil {
return nil, err
}

return payload, nil
}

func encrypt(msg, publicKey []byte) (string, error) {
parsed, _, _, _, err := ssh.ParseAuthorizedKey(publicKey)
if err != nil {
return "", err
}
Expand All @@ -56,35 +139,79 @@ func encrypt(msg, publicKey string) (string, error) {
// Finally, we can convert back to an *rsa.PublicKey
pub := pubCrypto.(*rsa.PublicKey)

if len(msg) <= 256 {
// msg is small enough to only use OAEP encryption; this will result in less bytes to transfer.
encryptedBytes, err := rsa.EncryptOAEP(
sha256.New(),
rand.Reader,
pub,
msg,
nil)
if err != nil {
return "", err
}
if len(encryptedBytes) != 256 {
panic(len(encryptedBytes))
}
return base64.StdEncoding.EncodeToString(encryptedBytes), nil
}

// otherwise, encrypt using AES256

key, ciphertext, err := encryptAES256(msg)
if err != nil {
return "", err
}

encryptedBytes, err := rsa.EncryptOAEP(
sha256.New(),
rand.Reader,
pub,
[]byte(msg),
key,
nil)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(encryptedBytes), nil
if len(encryptedBytes) != 256 {
panic(len(encryptedBytes))
}
return base64.StdEncoding.EncodeToString(append(encryptedBytes, ciphertext...)), nil
}

func decrypt(data, priv string) (string, error) {
func decrypt(data, priv string) ([]byte, error) {
data2, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
return nil, err
}

if len(data2) < 256 {
return nil, fmt.Errorf("not enough data to decrypt")
}

block, _ := pem.Decode([]byte(priv))
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", err
return nil, err
}

decrypted, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, key, data2, nil)
oaepData := data2[:256]
aesData := data2[256:]
payload, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, key, oaepData, nil)
if err != nil {
return "", err
return nil, err
}
return string(decrypted), nil

if len(aesData) == 0 {
return payload, nil
}

decryptedAESKey := payload
decrypted, err := decryptAES(decryptedAESKey, aesData)
if err != nil {
return nil, err
}

return decrypted, nil
}

func test() {
Expand All @@ -94,8 +221,8 @@ func test() {
}
pub = strings.TrimPrefix(pub, "ssh-rsa ")

data := "hello test"
encrypted, err := encrypt(data, "ssh-rsa "+pub)
data := []byte("hello test")
encrypted, err := encrypt(data, []byte("ssh-rsa "+pub))
if err != nil {
panic(err)
}
Expand All @@ -105,7 +232,7 @@ func test() {
panic(err)
}

if data != data2 {
if !bytes.Equal(data, data2) {
panic("missmatch")
}
}
Expand Down Expand Up @@ -178,21 +305,21 @@ func main() {

data, err := ioutil.ReadAll(os.Stdin)
if err != nil {
fmt.Fprintf(os.Stderr, "failed while reading from stdin: %s", err.Error())
fmt.Fprintf(os.Stderr, "failed while reading from stdin: %s\n", err.Error())
os.Exit(1)
}

if arg == "decrypt" {
data2, err := decrypt(string(data), priv)
if err != nil {
fmt.Fprintf(os.Stderr, "failed while decrypting: %s", err.Error())
fmt.Fprintf(os.Stderr, "failed while decrypting: %s\n", err.Error())
os.Exit(1)
}
fmt.Printf("%s", data2)
return
}

encrypted, err := encrypt(string(data), "ssh-rsa "+arg)
encrypted, err := encrypt(data, []byte("ssh-rsa "+arg))
if err != nil {
fmt.Fprintf(os.Stderr, "failed while encrypting: %s", err.Error())
os.Exit(1)
Expand Down

0 comments on commit 0113690

Please sign in to comment.