diff --git a/libgm/client.go b/libgm/client.go index 34e52fb..3aaed8d 100644 --- a/libgm/client.go +++ b/libgm/client.go @@ -223,7 +223,7 @@ func (c *Client) DownloadMedia(mediaID string, key []byte) ([]byte, error) { if err != nil { return nil, err } - cryptor, err := crypto.NewImageCryptor(key) + cryptor, err := crypto.NewAESGCMHelper(key) if err != nil { return nil, err } diff --git a/libgm/crypto/aesgcm.go b/libgm/crypto/aesgcm.go index 76c39b9..e41fb1e 100644 --- a/libgm/crypto/aesgcm.go +++ b/libgm/crypto/aesgcm.go @@ -4,23 +4,21 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "encoding/binary" "fmt" "math" ) -type ImageCryptor struct { +type AESGCMHelper struct { key []byte + gcm cipher.AEAD } -func NewImageCryptor(key []byte) (*ImageCryptor, error) { +func NewAESGCMHelper(key []byte) (*AESGCMHelper, error) { if len(key) != 32 { return nil, fmt.Errorf("unsupported AES key length (got=%d expected=32)", len(key)) } - return &ImageCryptor{key: key}, nil -} - -func (ic *ImageCryptor) Encrypt(imageBytes []byte, aad []byte) ([]byte, error) { - block, err := aes.NewCipher(ic.key) + block, err := aes.NewCipher(key) if err != nil { return nil, err } @@ -30,34 +28,29 @@ func (ic *ImageCryptor) Encrypt(imageBytes []byte, aad []byte) ([]byte, error) { return nil, err } - nonce := make([]byte, gcm.NonceSize()) - _, err = rand.Read(nonce) - if err != nil { - return nil, err - } - - ciphertext := gcm.Seal(nonce, nonce, imageBytes, aad) - return ciphertext, nil + return &AESGCMHelper{key: key, gcm: gcm}, nil } -func (ic *ImageCryptor) Decrypt(iv []byte, data []byte, aad []byte) ([]byte, error) { - block, err := aes.NewCipher(ic.key) +func (c *AESGCMHelper) encryptChunk(data []byte, aad []byte) []byte { + nonce := make([]byte, c.gcm.NonceSize(), c.gcm.NonceSize()+len(data)) + _, err := rand.Read(nonce) if err != nil { - return nil, err + panic(fmt.Errorf("out of randomness: %w", err)) } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } + // Pass nonce as the dest, so we have it pre-appended to the output + return c.gcm.Seal(nonce, nonce, data, aad) +} - if len(data) < gcm.NonceSize() { +func (c *AESGCMHelper) decryptChunk(data []byte, aad []byte) ([]byte, error) { + if len(data) < c.gcm.NonceSize() { return nil, fmt.Errorf("invalid encrypted data length (got=%d)", len(data)) } - ciphertext := data[gcm.NonceSize():] + nonce := data[:c.gcm.NonceSize()] + ciphertext := data[c.gcm.NonceSize():] - decrypted, err := gcm.Open(nil, iv, ciphertext, aad) + decrypted, err := c.gcm.Open(nil, nonce, ciphertext, aad) if err != nil { return nil, err } @@ -65,56 +58,38 @@ func (ic *ImageCryptor) Decrypt(iv []byte, data []byte, aad []byte) ([]byte, err return decrypted, nil } -func (ic *ImageCryptor) EncryptData(data []byte) ([]byte, error) { - rawChunkSize := 1 << 15 - chunkSize := rawChunkSize - 28 - var tasks []chan []byte - chunkIndex := 0 +const outgoingRawChunkSize = 1 << 15 +func (c *AESGCMHelper) EncryptData(data []byte) ([]byte, error) { + chunkOverhead := c.gcm.NonceSize() + c.gcm.Overhead() + chunkSize := outgoingRawChunkSize - chunkOverhead + + chunkCount := int(math.Ceil(float64(len(data)) / float64(chunkSize))) + encrypted := make([]byte, 2, 2+len(data)+28*chunkCount) + encrypted[0] = 0 + encrypted[1] = byte(math.Log2(float64(outgoingRawChunkSize))) + + var chunkIndex uint32 for i := 0; i < len(data); i += chunkSize { - if i+chunkSize > len(data) { + isLastChunk := false + if i+chunkSize >= len(data) { chunkSize = len(data) - i + isLastChunk = true } chunk := make([]byte, chunkSize) copy(chunk, data[i:i+chunkSize]) - aad := ic.calculateAAD(chunkIndex, i+chunkSize, len(data)) - tasks = append(tasks, make(chan []byte)) - go func(chunk, aad []byte, task chan []byte) { - encrypted, err := ic.Encrypt(chunk, aad) - if err != nil { - fmt.Println(err) - task <- nil - } else { - task <- encrypted - } - }(chunk, aad, tasks[chunkIndex]) - + aad := c.calculateAAD(chunkIndex, isLastChunk) + encrypted = append(encrypted, c.encryptChunk(data[i:i+chunkSize], aad)...) chunkIndex++ } - var result [][]byte - for _, task := range tasks { - encrypted := <-task - if encrypted == nil { - continue - } - result = append(result, encrypted) - } - - var concatted []byte - for _, r := range result { - concatted = append(concatted, r...) - } - - encryptedHeader := []byte{0, byte(math.Log2(float64(rawChunkSize)))} - - return append(encryptedHeader, concatted...), nil + return encrypted, nil } -func (ic *ImageCryptor) DecryptData(encryptedData []byte) ([]byte, error) { - if len(encryptedData) == 0 || len(ic.key) != 32 { +func (c *AESGCMHelper) DecryptData(encryptedData []byte) ([]byte, error) { + if len(encryptedData) == 0 || len(c.key) != 32 { return encryptedData, nil } if encryptedData[0] != 0 { @@ -124,63 +99,37 @@ func (ic *ImageCryptor) DecryptData(encryptedData []byte) ([]byte, error) { chunkSize := 1 << encryptedData[1] encryptedData = encryptedData[2:] - var tasks []chan []byte - chunkIndex := 0 + var chunkIndex uint32 + chunkCount := int(math.Ceil(float64(len(encryptedData)) / float64(chunkSize))) + decryptedData := make([]byte, 0, len(encryptedData)-28*chunkCount) for i := 0; i < len(encryptedData); i += chunkSize { - if i+chunkSize > len(encryptedData) { + isLastChunk := false + if i+chunkSize >= len(encryptedData) { chunkSize = len(encryptedData) - i + isLastChunk = true } chunk := make([]byte, chunkSize) copy(chunk, encryptedData[i:i+chunkSize]) - iv := chunk[:12] - aad := ic.calculateAAD(chunkIndex, i+chunkSize, len(encryptedData)) - tasks = append(tasks, make(chan []byte)) - go func(iv, chunk, aad []byte, task chan []byte) { - decrypted, err := ic.Decrypt(iv, chunk, aad) - if err != nil { - fmt.Println(err) - task <- nil - } else { - task <- decrypted - } - }(iv, chunk, aad, tasks[chunkIndex]) - + aad := c.calculateAAD(chunkIndex, isLastChunk) + decryptedChunk, err := c.decryptChunk(chunk, aad) + if err != nil { + return nil, fmt.Errorf("failed to decrypt chunk #%d: %w", chunkIndex+1, err) + } + decryptedData = append(decryptedData, decryptedChunk...) chunkIndex++ } - var result [][]byte - for _, task := range tasks { - decrypted := <-task - if decrypted == nil { - continue - } - result = append(result, decrypted) - } - - var concatted []byte - for _, r := range result { - concatted = append(concatted, r...) - } - - return concatted, nil + return decryptedData, nil } -func (ic *ImageCryptor) calculateAAD(index, end, total int) []byte { +func (c *AESGCMHelper) calculateAAD(index uint32, isLastChunk bool) []byte { aad := make([]byte, 5) - - i := 4 - for index > 0 { - aad[i] = byte(index % 256) - index = index / 256 - i-- - } - - if end >= total { + binary.BigEndian.PutUint32(aad[1:5], index) + if isLastChunk { aad[0] = 1 } - return aad } diff --git a/libgm/upload.go b/libgm/upload.go index c138feb..11bc5ca 100644 --- a/libgm/upload.go +++ b/libgm/upload.go @@ -93,7 +93,7 @@ func (c *Client) UploadMedia(data []byte, fileName, mime string) (*gmproto.Media mediaType = MimeToMediaType[strings.Split(mime, "/")[0]] } decryptionKey := crypto.GenerateKey(32) - cryptor, err := crypto.NewImageCryptor(decryptionKey) + cryptor, err := crypto.NewAESGCMHelper(decryptionKey) if err != nil { return nil, err }