Files
go-xmlsec/xmlenc/xmlenc.go

212 lines
6.2 KiB
Go

package xmlenc
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/xml"
"errors"
"fmt"
"hash"
"io"
)
// method is the decoding target for an XML element whose only relevant
// content is an Algorithm attribute. It is used for the EncryptionMethod
// child of EncryptedData and for the DigestMethod child of an
// EncryptedKey's EncryptionMethod.
type method struct {
// Algorithm is the value of the element's Algorithm attribute; the code in
// this file compares it against algorithm identifier URIs.
Algorithm string `xml:",attr"`
}
// encryptedData is the decoding target for an EncryptedData element in the
// http://www.w3.org/2001/04/xmlenc# namespace.
type encryptedData struct {
// XMLName pins the element name and namespace that this struct decodes.
XMLName string `xml:"http://www.w3.org/2001/04/xmlenc# EncryptedData"`
// ID is the value of the element's Id attribute.
ID string `xml:"Id,attr"`
// Type is the value of the element's Type attribute.
Type string `xml:",attr"`
// EncryptionMethod carries the algorithm URI that selects the block cipher
// used to decrypt CipherData.
EncryptionMethod method `xml:"EncryptionMethod"`
// KeyInfo optionally carries an EncryptedKey holding the wrapped data key.
KeyInfo keyInfo `xml:"http://www.w3.org/2000/09/xmldsig# KeyInfo"`
// CipherData holds the base64 ciphertext. It is a pointer, so it stays nil
// when the document has no CipherData child.
CipherData *cipherData
}
// keyInfo is the decoding target for a KeyInfo element in the
// http://www.w3.org/2000/09/xmldsig# namespace.
type keyInfo struct {
// XMLName pins the element name and namespace that this struct decodes.
XMLName string `xml:"http://www.w3.org/2000/09/xmldsig# KeyInfo"`
// EncryptedKey is the wrapped key, if any. It is a pointer, so it stays nil
// when the document has no EncryptedKey child.
EncryptedKey *encryptedKey `xml:"http://www.w3.org/2001/04/xmlenc# EncryptedKey"`
// X509Data holds the contents of the X509Data child element.
X509Data x509Data `xml:"http://www.w3.org/2000/09/xmldsig# X509Data"`
}
// encryptedKey is the decoding target for an EncryptedKey element in the
// http://www.w3.org/2001/04/xmlenc# namespace. All child fields are pointers
// and stay nil when the corresponding child element is absent.
type encryptedKey struct {
// XMLName pins the element name and namespace that this struct decodes.
XMLName string `xml:"http://www.w3.org/2001/04/xmlenc# EncryptedKey"`
// EncryptionMethod names the key transport algorithm and its digest method.
EncryptionMethod *encryptionMethod
// KeyInfo is the nested KeyInfo element of the EncryptedKey.
KeyInfo *keyInfo
// CipherData holds the base64-encoded wrapped key.
CipherData *cipherData `xml:"http://www.w3.org/2001/04/xmlenc# CipherData"`
}
// encryptionMethod is the decoding target for the EncryptionMethod child of
// an EncryptedKey. Unlike method, it also captures a nested DigestMethod.
type encryptionMethod struct {
// Algorithm is the value of the element's Algorithm attribute.
Algorithm string `xml:",attr"`
// DigestMethod is the nested xmldsig DigestMethod element; its Algorithm is
// the empty string when the element is absent.
DigestMethod method `xml:"http://www.w3.org/2000/09/xmldsig# DigestMethod"`
}
// x509Data is the decoding target for an X509Data element in the
// http://www.w3.org/2000/09/xmldsig# namespace.
type x509Data struct {
// XMLName pins the element name and namespace that this struct decodes.
XMLName string `xml:"http://www.w3.org/2000/09/xmldsig# X509Data"`
// X509Certificate receives the text content of the X509Certificate child
// element.
X509Certificate string
}
// cipherData is the decoding target for a CipherData element in the
// http://www.w3.org/2001/04/xmlenc# namespace.
type cipherData struct {
// XMLName pins the element name and namespace that this struct decodes.
XMLName string `xml:"http://www.w3.org/2001/04/xmlenc# CipherData"`
// CipherValue is the text content of the CipherValue child element; the
// decryption functions in this file base64-decode it.
CipherValue string `xml:"CipherValue"`
}
// ErrNoEncryptedDataFound is returned by Decrypt when the document is read to
// the end without encountering an EncryptedData element.
var ErrNoEncryptedDataFound = errors.New("no EncryptedData elements found")
// Decrypt searches the serialized XML document `doc` looking for
// an EncryptedData element. When found, it decrypts the element
// and returns the plaintext of the encrypted section.
//
// Key is a PEM-encoded RSA private key, or a binary TDES key or a
// binary AES key, depending on the encryption type in use.
func Decrypt(key []byte, doc []byte) ([]byte, error) {
decoder := xml.NewDecoder(bytes.NewReader(doc))
for {
t, err := decoder.Token()
if err == io.EOF {
break
} else if err != nil {
return nil, err
}
if startElement, ok := t.(xml.StartElement); ok {
if startElement.Name.Space == "http://www.w3.org/2001/04/xmlenc#" && startElement.Name.Local == "EncryptedData" {
d := encryptedData{}
if err := decoder.DecodeElement(&d, &startElement); err != nil {
return nil, err
}
plaintext, err := decryptEncryptedData(key, &d)
if err != nil {
return nil, err
}
return plaintext, nil
}
}
}
return nil, ErrNoEncryptedDataFound
}
// decryptEncryptedData decrypts the EncryptedData element and returns the
// plaintext.
//
// If the element's KeyInfo carries an EncryptedKey, key is treated as the
// key-transport key and is first used to unwrap the data key; otherwise key
// is used directly as the TDES or AES key.
//
// The decoded CipherValue is expected to be the CBC initialization vector
// (one cipher block) followed by one or more whole ciphertext blocks. The
// final plaintext byte gives the number of padding bytes, which are removed
// before the plaintext is returned.
//
// An error is returned when CipherData is missing, the CipherValue is not
// valid base64, the algorithm is unsupported, the key is unusable for the
// cipher, the ciphertext length is not a valid multiple of the block size,
// or the padding length byte is out of range.
func decryptEncryptedData(key []byte, d *encryptedData) ([]byte, error) {
	if d.KeyInfo.EncryptedKey != nil {
		var err error
		key, err = decryptEncryptedKey(key, d.KeyInfo.EncryptedKey)
		if err != nil {
			return nil, err
		}
	}

	// CipherData is a pointer and stays nil when the element is absent.
	if d.CipherData == nil {
		return nil, errors.New("missing CipherData in EncryptedData")
	}
	ciphertext, err := base64.StdEncoding.DecodeString(d.CipherData.CipherValue)
	if err != nil {
		return nil, err
	}

	var blockCipher cipher.Block
	switch d.EncryptionMethod.Algorithm {
	case "http://www.w3.org/2001/04/xmlenc#tripledes-cbc":
		blockCipher, err = des.NewTripleDESCipher(key)
	case "http://www.w3.org/2001/04/xmlenc#aes128-cbc",
		"http://www.w3.org/2001/04/xmlenc#aes192-cbc",
		"http://www.w3.org/2001/04/xmlenc#aes256-cbc":
		blockCipher, err = aes.NewCipher(key)
	default:
		return nil, fmt.Errorf("unsupported encryption method: %s", d.EncryptionMethod.Algorithm)
	}
	if err != nil {
		return nil, err
	}

	// The payload must hold the IV plus at least one full block (padding is
	// always present), and CryptBlocks panics on partial blocks.
	blockSize := blockCipher.BlockSize()
	if len(ciphertext) < 2*blockSize || len(ciphertext)%blockSize != 0 {
		return nil, fmt.Errorf("invalid ciphertext length %d for block size %d",
			len(ciphertext), blockSize)
	}
	iv := ciphertext[:blockSize]
	plaintext := ciphertext[blockSize:]

	// Decrypt in place; the buffer is our own freshly decoded copy.
	cipher.NewCBCDecrypter(blockCipher, iv).CryptBlocks(plaintext, plaintext)

	// XML Encryption block padding: the last byte is the number of padding
	// bytes (1..blockSize); the values of the other padding bytes are not
	// specified, so only the length byte is validated.
	padLen := int(plaintext[len(plaintext)-1])
	if padLen < 1 || padLen > blockSize {
		return nil, fmt.Errorf("invalid padding length %d", padLen)
	}
	return plaintext[:len(plaintext)-padLen], nil
}
// decryptEncryptedKey returns the plaintext version of the EncryptedKey which is
// encrypted using RSA-PKCS1v15 or RSA-OAEP-MGF1P.
//
// key must be a PEM-encoded RSA private key; the PEM payload may be in
// PKCS#1 or PKCS#8 form. The returned bytes are the unwrapped session key.
//
// A DigestMethod is only consulted for RSA-OAEP, where SHA-1 is the sole
// supported digest and is also assumed when no DigestMethod is given.
// RSA-PKCS1v15 does not use a digest, so DigestMethod is ignored for it.
//
// An error is returned when the key cannot be parsed, EncryptionMethod or
// CipherData is missing, the CipherValue is not valid base64, the algorithm
// or digest is unsupported, or the RSA decryption itself fails.
func decryptEncryptedKey(key []byte, encryptedKey *encryptedKey) ([]byte, error) {
	// All the supported encryption schemes are based on RSA, so `key` must be an
	// RSA key. (c.f. http://www.w3.org/TR/2002/REC-xmlenc-core-20021210/Overview.html
	// in the "Key Transport" section)
	pemBlock, _ := pem.Decode(key)
	if pemBlock == nil {
		return nil, fmt.Errorf("Cannot parse key as PEM encoded RSA private key")
	}
	rsaPriv, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes)
	if err != nil {
		// Not PKCS#1; accept a PKCS#8-wrapped RSA key as well. If that
		// fails too, report the original PKCS#1 parse error.
		parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes)
		if pkcs8Err != nil {
			return nil, err
		}
		pkcs8RSA, isRSA := parsed.(*rsa.PrivateKey)
		if !isRSA {
			return nil, fmt.Errorf("private key is %T, not an RSA private key", parsed)
		}
		rsaPriv = pkcs8RSA
	}

	// Both fields are pointers and stay nil when the element is absent.
	em := encryptedKey.EncryptionMethod
	if em == nil {
		return nil, errors.New("missing EncryptionMethod in EncryptedKey")
	}
	if encryptedKey.CipherData == nil {
		return nil, errors.New("missing CipherData in EncryptedKey")
	}

	sessionKeyCiphertext, err := base64.StdEncoding.DecodeString(encryptedKey.CipherData.CipherValue)
	if err != nil {
		return nil, err
	}

	var sessionKeyPlaintext []byte
	switch em.Algorithm {
	case "http://www.w3.org/2001/04/xmlenc#rsa-1_5":
		// NOTE(review): whether DecryptPKCS1v15 returns an error leaks
		// information to an attacker who can submit chosen ciphertexts;
		// consider rsa.DecryptPKCS1v15SessionKey if that is a concern.
		sessionKeyPlaintext, err = rsa.DecryptPKCS1v15(rand.Reader, rsaPriv,
			sessionKeyCiphertext)
	case "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p":
		// The only supported digest is SHA1, which is also the default
		// when DigestMethod is omitted.
		var hashFunc hash.Hash
		switch em.DigestMethod.Algorithm {
		case "", "http://www.w3.org/2000/09/xmldsig#sha1":
			hashFunc = sha1.New()
		default:
			return nil, fmt.Errorf("unsupported digest method: %s",
				em.DigestMethod.Algorithm)
		}
		sessionKeyPlaintext, err = rsa.DecryptOAEP(hashFunc, rand.Reader,
			rsaPriv, sessionKeyCiphertext, nil)
	default:
		return nil, fmt.Errorf("unsupported encryption method: %s",
			em.Algorithm)
	}
	if err != nil {
		return nil, err
	}
	return sessionKeyPlaintext, nil
}