implement native xmlenc decrypt function
This commit is contained in:
192
xmlenc/decrypt.go
Normal file
192
xmlenc/decrypt.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package xmlenc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/des"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ripemd160"
|
||||
)
|
||||
|
||||
var ErrCannotFindEncryptedDataNode = errors.New("cannot find EncryptedData node")
|
||||
|
||||
var ErrPublicKeyMismatch = errors.New("certificate public key does not match provided private key")
|
||||
|
||||
type ErrUnsupportedAlgorithm struct {
|
||||
Algorithm string
|
||||
}
|
||||
|
||||
func (e ErrUnsupportedAlgorithm) Error() string {
|
||||
return fmt.Sprintf("unsupported algorithm: %s", e.Algorithm)
|
||||
}
|
||||
|
||||
func Decrypt(privateKey []byte, doc []byte) ([]byte, error) {
|
||||
decoder := xml.NewDecoder(bytes.NewReader(doc))
|
||||
for {
|
||||
startOffset := decoder.InputOffset()
|
||||
token, err := decoder.Token()
|
||||
if err == io.EOF {
|
||||
return nil, ErrCannotFindEncryptedDataNode
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if startElement, ok := token.(xml.StartElement); ok {
|
||||
if startElement.Name.Space == "http://www.w3.org/2001/04/xmlenc#" && startElement.Name.Local == "EncryptedData" {
|
||||
encryptedData := EncryptedData{}
|
||||
if err := decoder.DecodeElement(&encryptedData, &startElement); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext, err := decrypt(privateKey, encryptedData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
endOffset := decoder.InputOffset()
|
||||
|
||||
rv := append(doc[:startOffset], append(plaintext, doc[endOffset:]...)...)
|
||||
return rv, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decrypt(privateKey []byte, encryptedData EncryptedData) ([]byte, error) {
|
||||
|
||||
var key []byte
|
||||
if encryptedData.KeyInfo.EncryptedKey != nil {
|
||||
var err error
|
||||
key, err = decryptKey(privateKey, *encryptedData.KeyInfo.EncryptedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedData.CipherData.CipherValue.Data)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "base64 decode ciphertext")
|
||||
}
|
||||
|
||||
var block cipher.Block
|
||||
var iv []byte
|
||||
switch *encryptedData.EncryptionMethod.Algorithm {
|
||||
case "http://www.w3.org/2001/04/xmlenc#tripledes-cbc":
|
||||
block, err = des.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "AES init")
|
||||
}
|
||||
iv = ciphertext[:des.BlockSize]
|
||||
ciphertext = ciphertext[des.BlockSize:]
|
||||
case "http://www.w3.org/2001/04/xmlenc#aes128-cbc":
|
||||
fallthrough
|
||||
case "http://www.w3.org/2001/04/xmlenc#aes256-cbc":
|
||||
fallthrough
|
||||
case "http://www.w3.org/2001/04/xmlenc#aes192-cbc":
|
||||
block, err = aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "AES init")
|
||||
}
|
||||
iv = ciphertext[:aes.BlockSize]
|
||||
ciphertext = ciphertext[aes.BlockSize:]
|
||||
default:
|
||||
return nil, ErrUnsupportedAlgorithm{Algorithm: *encryptedData.EncryptionMethod.Algorithm}
|
||||
}
|
||||
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, errors.Wrap(fmt.Errorf("ciphertext is not a multiple of the block size"),
|
||||
"invalid ciphertext")
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
mode.CryptBlocks(ciphertext, ciphertext)
|
||||
|
||||
// strip padding
|
||||
{
|
||||
paddingLen := int(ciphertext[len(ciphertext)-1])
|
||||
ciphertext = ciphertext[:len(ciphertext)-paddingLen]
|
||||
}
|
||||
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
func rsaPublicKeyEquals(a rsa.PublicKey, b rsa.PublicKey) bool {
|
||||
return a.E == b.E && a.N.Cmp(b.N) == 0
|
||||
}
|
||||
|
||||
func decryptKey(privateKey []byte, encryptedKey EncryptedKey) ([]byte, error) {
|
||||
cipherValue, err := base64.StdEncoding.DecodeString(string(encryptedKey.CipherData.CipherValue.Data))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode key base64")
|
||||
}
|
||||
|
||||
// TODO(ross): add support for http://www.w3.org/2001/04/xmlenc#rsa-1_5 once we can
|
||||
// scrounge up some test vectors
|
||||
if *encryptedKey.EncryptionMethod.Algorithm != "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p" {
|
||||
return nil, ErrUnsupportedAlgorithm{Algorithm: *encryptedKey.EncryptionMethod.Algorithm}
|
||||
}
|
||||
|
||||
pemBlock, _ := pem.Decode(privateKey)
|
||||
if pemBlock == nil || pemBlock.Type != "RSA PRIVATE KEY" {
|
||||
return nil, errors.Wrap(fmt.Errorf("invalid private key"), "parse RSA private key")
|
||||
}
|
||||
rsaKey, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "x509.ParsePKCS1PrivateKey")
|
||||
}
|
||||
|
||||
{
|
||||
pemBlock, _ := pem.Decode([]byte(fmt.Sprintf(
|
||||
"-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----\n",
|
||||
string(encryptedKey.KeyInfo.X509Data.X509Certificate.Data))))
|
||||
if pemBlock == nil {
|
||||
return nil, errors.New("cannot parse certificate")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(pemBlock.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "x509.ParseCertificate")
|
||||
}
|
||||
if !rsaPublicKeyEquals(*cert.PublicKey.(*rsa.PublicKey), rsaKey.PublicKey) {
|
||||
return nil, ErrPublicKeyMismatch
|
||||
}
|
||||
}
|
||||
|
||||
label := []byte{}
|
||||
if encryptedKey.EncryptionMethod.OAEPparams != nil {
|
||||
label = encryptedKey.EncryptionMethod.OAEPparams.Data
|
||||
}
|
||||
|
||||
var hashMethod hash.Hash
|
||||
switch encryptedKey.EncryptionMethod.DigestMethod.Algorithm {
|
||||
case "http://www.w3.org/2000/09/xmldsig#sha1":
|
||||
hashMethod = sha1.New()
|
||||
case "http://www.w3.org/2001/04/xmlenc#sha256":
|
||||
hashMethod = sha256.New()
|
||||
case "http://www.w3.org/2001/04/xmlenc#sha512":
|
||||
hashMethod = sha512.New()
|
||||
case "http://www.w3.org/2001/04/xmlenc#ripemd160":
|
||||
hashMethod = ripemd160.New()
|
||||
default:
|
||||
return nil, ErrUnsupportedAlgorithm{Algorithm: encryptedKey.EncryptionMethod.DigestMethod.Algorithm}
|
||||
}
|
||||
|
||||
plaintext, err := rsa.DecryptOAEP(hashMethod, rand.Reader, rsaKey, cipherValue, label)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
193
xmlenc/decrypt_realworld_test.go
Normal file
193
xmlenc/decrypt_realworld_test.go
Normal file
File diff suppressed because one or more lines are too long
296
xmlenc/decrypt_test.go
Normal file
296
xmlenc/decrypt_test.go
Normal file
File diff suppressed because one or more lines are too long
10
xmlenc/setup_test.go
Normal file
10
xmlenc/setup_test.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package xmlenc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
// Hook up gocheck into the "go test" runner.
|
||||
func Test(t *testing.T) { TestingT(t) }
|
||||
64
xmlenc/xmlenc.go
Normal file
64
xmlenc/xmlenc.go
Normal file
@@ -0,0 +1,64 @@
|
||||
// Package xmlenc implements xml encrytion natively
|
||||
// (https://www.w3.org/TR/2002/REC-xmlenc-core-20021210/Overview.html)
|
||||
package xmlenc
|
||||
|
||||
import "encoding/xml"
|
||||
|
||||
type EncryptedData struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2001/04/xmlenc# EncryptedData"`
|
||||
ID *string `xml:"Id,attr"`
|
||||
Type *string `xml:",attr"`
|
||||
EncryptionMethod *EncryptionMethod `xml:"http://www.w3.org/2001/04/xmlenc# EncryptionMethod"`
|
||||
KeyInfo *KeyInfo `xml:"http://www.w3.org/2000/09/xmldsig# KeyInfo"`
|
||||
CipherData *CipherData `xml:"http://www.w3.org/2001/04/xmlenc# CipherData"`
|
||||
}
|
||||
|
||||
type EncryptionMethod struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2001/04/xmlenc# EncryptionMethod"`
|
||||
Algorithm *string `xml:",attr"`
|
||||
OAEPparams *OAEPparams `xml:"http://www.w3.org/2001/04/xmlenc# OAEPparams"`
|
||||
DigestMethod *DigestMethod `xml:"http://www.w3.org/2000/09/xmldsig# DigestMethod"`
|
||||
}
|
||||
|
||||
type DigestMethod struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2000/09/xmldsig# DigestMethod"`
|
||||
Algorithm string `xml:",attr"`
|
||||
}
|
||||
|
||||
type OAEPparams struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2001/04/xmlenc# OAEPparams"`
|
||||
Data []byte `xml:",chardata"`
|
||||
}
|
||||
|
||||
type KeyInfo struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2000/09/xmldsig# KeyInfo"`
|
||||
EncryptedKey *EncryptedKey `xml:"http://www.w3.org/2001/04/xmlenc# EncryptedKey"`
|
||||
X509Data *X509Data `xml:"http://www.w3.org/2000/09/xmldsig# X509Data"`
|
||||
}
|
||||
|
||||
type EncryptedKey struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2001/04/xmlenc# EncryptedKey"`
|
||||
EncryptionMethod *EncryptionMethod `xml:"http://www.w3.org/2001/04/xmlenc# EncryptionMethod"`
|
||||
KeyInfo *KeyInfo `xml:"http://www.w3.org/2000/09/xmldsig# KeyInfo"`
|
||||
CipherData *CipherData `xml:"http://www.w3.org/2001/04/xmlenc# CipherData"`
|
||||
}
|
||||
|
||||
type X509Data struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2000/09/xmldsig# X509Data"`
|
||||
X509Certificate *X509Certificate
|
||||
}
|
||||
|
||||
type X509Certificate struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2000/09/xmldsig# X509Certificate"`
|
||||
Data []byte `xml:",chardata"`
|
||||
}
|
||||
|
||||
type CipherData struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2001/04/xmlenc# CipherData"`
|
||||
CipherValue *CipherValue
|
||||
}
|
||||
|
||||
type CipherValue struct {
|
||||
XMLName xml.Name `xml:"http://www.w3.org/2001/04/xmlenc# CipherValue"`
|
||||
Data string `xml:",chardata"`
|
||||
}
|
||||
Reference in New Issue
Block a user