/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>

#include "pk11pub.h"
#include "secerr.h"
#include "nss.h"

/*
 * Convert the two hex digits at c2[0] and c2[1] into one byte.
 *
 * c2[0] supplies the high nibble and c2[1] the low nibble; both upper-
 * and lower-case digits are accepted.  *byteval is cleared on entry, so
 * when SECFailure is returned it holds only the nibbles decoded before
 * the offending character.
 */
static SECStatus
hex_to_byteval(const char *c2, unsigned char *byteval)
{
    int pos;

    *byteval = 0;
    for (pos = 0; pos < 2; pos++) {
        const char ch = c2[pos];
        unsigned char nibble;

        if (ch >= '0' && ch <= '9') {
            nibble = (unsigned char)(ch - '0');
        } else if (ch >= 'a' && ch <= 'f') {
            nibble = (unsigned char)(ch - 'a' + 10);
        } else if (ch >= 'A' && ch <= 'F') {
            nibble = (unsigned char)(ch - 'A' + 10);
        } else {
            return SECFailure;
        }
        /* first digit lands in bits 7..4, second digit in bits 3..0 */
        *byteval |= (unsigned char)(nibble << (pos == 0 ? 4 : 0));
    }
    return SECSuccess;
}

/*
 * Encrypt a buffer with CKM_AES_GCM using a key imported into the
 * internal slot.
 *
 * key, keysize      raw AES key bytes and their count
 * iv, ivsize        GCM IV and its length in bytes
 * output            receives the data produced by PK11_Encrypt (the KAT
 *                   driver in this file expects ciphertext followed by
 *                   the tag)
 * outputlen         set by PK11_Encrypt to the number of bytes written
 * maxoutputlen      capacity of output in bytes
 * input, inputlen   plaintext and its length in bytes
 * aad, aadlen       additional authenticated data and its length
 * tagsize           tag length in bytes (passed on as tagsize * 8 bits)
 *
 * Returns SECSuccess on success.  On failure a message is written to
 * stderr and SECFailure is returned.
 */
static SECStatus
aes_encrypt_buf(
    const unsigned char *key, unsigned int keysize,
    const unsigned char *iv, unsigned int ivsize,
    unsigned char *output, unsigned int *outputlen, unsigned int maxoutputlen,
    const unsigned char *input, unsigned int inputlen,
    const unsigned char *aad, unsigned int aadlen, unsigned int tagsize)
{
    SECStatus rv = SECFailure;
    SECItem key_item;
    PK11SlotInfo *slot = NULL;
    PK11SymKey *symKey = NULL;
    CK_GCM_PARAMS gcm_params;
    SECItem param;

    /* Import key into NSS. */
    key_item.type = siBuffer;
    key_item.data = (unsigned char *) key;  /* const cast */
    key_item.len = keysize;
    slot = PK11_GetInternalSlot();
    if (!slot) {
        /* Without a slot neither the import nor PK11_FreeSlot is safe. */
        fprintf(stderr, "PK11_GetInternalSlot failed\n");
        goto loser;
    }
    symKey = PK11_ImportSymKey(slot, CKM_AES_GCM, PK11_OriginUnwrap,
                               CKA_ENCRYPT, &key_item, NULL);
    /* The slot reference is only needed for the import itself. */
    PK11_FreeSlot(slot);
    slot = NULL;
    if (!symKey) {
        fprintf(stderr, "PK11_ImportSymKey failed\n");
        goto loser;
    }

    /*
     * Clear the whole structure first so that any member not assigned
     * below is zero rather than stack garbage (the set of members may
     * differ between PKCS #11 header versions).
     */
    memset(&gcm_params, 0, sizeof(gcm_params));
    gcm_params.pIv = (unsigned char *) iv;  /* const cast */
    gcm_params.ulIvLen = ivsize;
    gcm_params.pAAD = (unsigned char *) aad;  /* const cast */
    gcm_params.ulAADLen = aadlen;
    gcm_params.ulTagBits = tagsize * 8;

    param.type = siBuffer;
    param.data = (unsigned char *) &gcm_params;
    param.len = sizeof(gcm_params);

    if (PK11_Encrypt(symKey, CKM_AES_GCM, &param,
                     output, outputlen, maxoutputlen,
                     input, inputlen) != SECSuccess) {
        fprintf(stderr, "PK11_Encrypt failed\n");
        goto loser;
    }

    rv = SECSuccess;

loser:
    if (symKey != NULL) {
        PK11_FreeSymKey(symKey);
    }
    return rv;
}

/*
 * Decrypt and authenticate a buffer with CKM_AES_GCM using a key
 * imported into the internal slot.
 *
 * The ciphertext and tag arrive separately; they are joined in a local
 * buffer (ciphertext first, then tag) and handed to PK11_Decrypt as one
 * input.
 *
 * key, keysize      raw AES key bytes and their count
 * iv, ivsize        GCM IV and its length in bytes
 * output            receives the data produced by PK11_Decrypt
 * outputlen         set by PK11_Decrypt to the number of bytes written
 * maxoutputlen      capacity of output in bytes
 * input, inputlen   ciphertext and its length in bytes
 * aad, aadlen       additional authenticated data and its length
 * tag, tagsize      authentication tag and its length in bytes
 *
 * Returns SECSuccess on success and SECFailure otherwise.  A failing
 * PK11_Decrypt is deliberately not logged here: the KAT driver in this
 * file calls this function for vectors that are expected to fail and
 * inspects PORT_GetError() afterwards.
 */
static SECStatus
aes_decrypt_buf(
    const unsigned char *key, unsigned int keysize,
    const unsigned char *iv, unsigned int ivsize,
    unsigned char *output, unsigned int *outputlen, unsigned int maxoutputlen,
    const unsigned char *input, unsigned int inputlen,
    const unsigned char *aad, unsigned int aadlen,
    const unsigned char *tag, unsigned int tagsize)
{
    SECStatus rv = SECFailure;
    unsigned char concatenated[11*16];     /* 1 to 11 blocks */
    SECItem key_item;
    PK11SlotInfo *slot = NULL;
    PK11SymKey *symKey = NULL;
    CK_GCM_PARAMS gcm_params;
    SECItem param;

    /*
     * Compare without adding the two lengths first: the sum of two
     * unsigned values could wrap around and slip past the limit.
     */
    if (tagsize > sizeof(concatenated) ||
        inputlen > sizeof(concatenated) - tagsize) {
        fprintf(stderr, "aes_decrypt_buf: local buffer too small\n");
        goto loser;
    }
    memcpy(concatenated, input, inputlen);
    memcpy(concatenated + inputlen, tag, tagsize);

    /* Import key into NSS. */
    key_item.type = siBuffer;
    key_item.data = (unsigned char *) key;  /* const cast */
    key_item.len = keysize;
    slot = PK11_GetInternalSlot();
    if (!slot) {
        /* Without a slot neither the import nor PK11_FreeSlot is safe. */
        fprintf(stderr, "PK11_GetInternalSlot failed\n");
        goto loser;
    }
    symKey = PK11_ImportSymKey(slot, CKM_AES_GCM, PK11_OriginUnwrap,
                               CKA_DECRYPT, &key_item, NULL);
    /* The slot reference is only needed for the import itself. */
    PK11_FreeSlot(slot);
    slot = NULL;
    if (!symKey) {
        fprintf(stderr, "PK11_ImportSymKey failed\n");
        goto loser;
    }

    /*
     * Clear the whole structure first so that any member not assigned
     * below is zero rather than stack garbage (the set of members may
     * differ between PKCS #11 header versions).
     */
    memset(&gcm_params, 0, sizeof(gcm_params));
    gcm_params.pIv = (unsigned char *) iv;  /* const cast */
    gcm_params.ulIvLen = ivsize;
    gcm_params.pAAD = (unsigned char *) aad;  /* const cast */
    gcm_params.ulAADLen = aadlen;
    gcm_params.ulTagBits = tagsize * 8;

    param.type = siBuffer;
    param.data = (unsigned char *) &gcm_params;
    param.len = sizeof(gcm_params);

    if (PK11_Decrypt(symKey, CKM_AES_GCM, &param,
                     output, outputlen, maxoutputlen,
                     concatenated, inputlen + tagsize) != SECSuccess) {
        goto loser;
    }

    rv = SECSuccess;

loser:
    if (symKey != NULL) {
        PK11_FreeSymKey(symKey);
    }
    return rv;
}

/*
 * Decode the hex value of a "<name> = <hex digits>" line.
 *
 * line    the NUL-terminated input line
 * skip    length of the field name at the start of the line; any
 *         whitespace and '=' characters that follow it are skipped
 * label   field name used in error messages
 * out     destination for the decoded bytes
 * maxout  capacity of out in bytes
 *
 * Returns the number of bytes decoded.  A value that does not fit in
 * out, or that contains an incomplete or invalid hex byte, is treated as
 * malformed input: a message is printed and the process exits with
 * status 1.
 */
static unsigned int
aes_gcm_parse_hex(const char *line, unsigned int skip, const char *label,
                  unsigned char *out, unsigned int maxout)
{
    unsigned int pos = skip;
    unsigned int count = 0;

    /* <ctype.h> classifiers require an unsigned char value. */
    while (isspace((unsigned char) line[pos]) || line[pos] == '=') {
        pos++;
    }
    while (isxdigit((unsigned char) line[pos])) {
        if (count >= maxout) {
            fprintf(stderr, "%s value is too long (max %u bytes)\n",
                    label, maxout);
            exit(1);
        }
        /* Fails when the second character of the pair is not a hex digit. */
        if (hex_to_byteval(&line[pos], &out[count]) != SECSuccess) {
            fprintf(stderr, "%s value contains an incomplete hex byte\n",
                    label);
            exit(1);
        }
        pos += 2;
        count++;
    }
    return count;
}

/*
 * Perform the AES Known Answer Test (KAT) in Galois Counter Mode (GCM).
 *
 * respfn is the pathname of the RESPONSE file.  Its name must contain
 * "Encrypt" or "Decrypt", which selects the direction being tested.
 *
 * The file is read line by line.  Bracketed header lines set the expected
 * field lengths (in bits) for a test group; "Count" starts a new data
 * set; Key/IV/PT/AAD/CT/Tag lines supply hex data.  For encryption the
 * test runs when the Tag line is read; for decryption it runs when the
 * PT line (expected success) or the FAIL line (expected failure) is read.
 *
 * Malformed input terminates the process with exit status 1.  A failed
 * test vector prints a message, skips the final "PASS" line and returns.
 */
static void
aes_gcm_kat(const char *respfn)
{
    char buf[512];      /* holds one line from the input RESPONSE file.
                         * needs to be large enough to hold the longest
                         * line "CT = <320 hex digits>\n".
                         */
    FILE *aesresp;      /* input stream from the RESPONSE file */
    unsigned int test_group = 0;
    unsigned int num_tests = 0;
    PRBool is_encrypt;
    unsigned char key[32];              /* 128, 192, or 256 bits */
    unsigned int keysize = 0;
    unsigned char iv[10*16];            /* 1 to 10 blocks */
    unsigned int ivsize = 0;
    unsigned char plaintext[10*16];     /* 1 to 10 blocks */
    unsigned int plaintextlen = 0;
    unsigned char aad[10*16];           /* 1 to 10 blocks */
    unsigned int aadlen = 0;
    unsigned char ciphertext[10*16];    /* 1 to 10 blocks */
    unsigned int ciphertextlen = 0;
    unsigned char tag[16];
    unsigned int tagsize = 0;
    unsigned char output[10*16];         /* 1 to 10 blocks */
    unsigned int outputlen = 0;

    unsigned int expected_keylen = 0;
    unsigned int expected_ivlen = 0;
    unsigned int expected_ptlen = 0;
    unsigned int expected_aadlen = 0;
    unsigned int expected_taglen = 0;
    SECStatus rv;

    if (strstr(respfn, "Encrypt") != NULL) {
        is_encrypt = PR_TRUE;
    } else if (strstr(respfn, "Decrypt") != NULL) {
        is_encrypt = PR_FALSE;
    } else {
        fprintf(stderr, "Input file name must contain Encrypt or Decrypt\n");
        exit(1);
    }
    aesresp = fopen(respfn, "r");
    if (aesresp == NULL) {
        fprintf(stderr, "Cannot open input file %s\n", respfn);
        exit(1);
    }
    while (fgets(buf, sizeof buf, aesresp) != NULL) {
        /* a comment or blank line */
        if (buf[0] == '#' || buf[0] == '\n') {
            continue;
        }
        /* [Keylen = ...], [IVlen = ...], etc. */
        if (buf[0] == '[') {
            if (strncmp(&buf[1], "Keylen = ", 9) == 0) {
                expected_keylen = atoi(&buf[10]);
            } else if (strncmp(&buf[1], "IVlen = ", 8) == 0) {
                expected_ivlen = atoi(&buf[9]);
            } else if (strncmp(&buf[1], "PTlen = ", 8) == 0) {
                expected_ptlen = atoi(&buf[9]);
            } else if (strncmp(&buf[1], "AADlen = ", 9) == 0) {
                expected_aadlen = atoi(&buf[10]);
            } else if (strncmp(&buf[1], "Taglen = ", 9) == 0) {
                expected_taglen = atoi(&buf[10]);

                /* Taglen is treated as the last header of a test group. */
                test_group++;
                if (test_group > 1) {
                    /* Report num_tests for the previous test group. */
                    printf("%u tests\n", num_tests);
                }
                num_tests = 0;
                printf("Keylen = %u, IVlen = %u, PTlen = %u, AADlen = %u, "
                       "Taglen = %u: ", expected_keylen, expected_ivlen,
                        expected_ptlen, expected_aadlen, expected_taglen);
                /* Convert lengths in bits to lengths in bytes. */
                PORT_Assert(expected_keylen % 8 == 0);
                expected_keylen /= 8;
                PORT_Assert(expected_ivlen % 8 == 0);
                expected_ivlen /= 8;
                PORT_Assert(expected_ptlen % 8 == 0);
                expected_ptlen /= 8;
                PORT_Assert(expected_aadlen % 8 == 0);
                expected_aadlen /= 8;
                PORT_Assert(expected_taglen % 8 == 0);
                expected_taglen /= 8;
            } else {
                fprintf(stderr, "Unexpected input line: %s\n", buf);
                exit(1);
            }
            continue;
        }
        /* "Count = x" begins a new data set */
        if (strncmp(buf, "Count", 5) == 0) {
            /* zeroize the variables for the test with this data set */
            memset(key, 0, sizeof key);
            keysize = 0;
            memset(iv, 0, sizeof iv);
            ivsize = 0;
            memset(plaintext, 0, sizeof plaintext);
            plaintextlen = 0;
            memset(aad, 0, sizeof aad);
            aadlen = 0;
            memset(ciphertext, 0, sizeof ciphertext);
            ciphertextlen = 0;
            memset(output, 0, sizeof output);
            outputlen = 0;
            num_tests++;
            continue;
        }
        /* Key = ... */
        if (strncmp(buf, "Key", 3) == 0) {
            keysize = aes_gcm_parse_hex(buf, 3, "Key", key,
                                        (unsigned int) sizeof key);
            if (keysize != expected_keylen) {
                fprintf(stderr, "Unexpected key length: %u vs. %u\n",
                        keysize, expected_keylen);
                exit(1);
            }
            continue;
        }
        /* IV = ... */
        if (strncmp(buf, "IV", 2) == 0) {
            ivsize = aes_gcm_parse_hex(buf, 2, "IV", iv,
                                       (unsigned int) sizeof iv);
            if (ivsize != expected_ivlen) {
                fprintf(stderr, "Unexpected IV length: %u vs. %u\n",
                        ivsize, expected_ivlen);
                exit(1);
            }
            continue;
        }
        /* PT = ... */
        if (strncmp(buf, "PT", 2) == 0) {
            plaintextlen = aes_gcm_parse_hex(buf, 2, "PT", plaintext,
                                             (unsigned int) sizeof plaintext);
            if (plaintextlen != expected_ptlen) {
                fprintf(stderr, "Unexpected PT length: %u vs. %u\n",
                        plaintextlen, expected_ptlen);
                exit(1);
            }

            /* For decryption the PT line triggers the test itself. */
            if (!is_encrypt) {
                rv = aes_decrypt_buf(key, keysize, iv, ivsize,
                    output, &outputlen, sizeof output,
                    ciphertext, ciphertextlen, aad, aadlen, tag, tagsize);
                if (rv != SECSuccess) {
                    fprintf(stderr, "aes_decrypt_buf failed\n");
                    goto loser;
                }
                if (outputlen != plaintextlen) {
                    fprintf(stderr, "aes_decrypt_buf: wrong output size\n");
                    goto loser;
                }
                if (memcmp(output, plaintext, plaintextlen) != 0) {
                    fprintf(stderr, "aes_decrypt_buf: wrong plaintext\n");
                    goto loser;
                }
            }
            continue;
        }
        /* FAIL */
        if (strncmp(buf, "FAIL", 4) == 0) {
            plaintextlen = 0;

            /* This vector must be rejected by the decryption. */
            PORT_Assert(!is_encrypt);
            rv = aes_decrypt_buf(key, keysize, iv, ivsize,
                output, &outputlen, sizeof output,
                ciphertext, ciphertextlen, aad, aadlen, tag, tagsize);
            if (rv != SECFailure) {
                fprintf(stderr, "aes_decrypt_buf succeeded unexpectedly\n");
                goto loser;
            }
            if (PORT_GetError() != SEC_ERROR_BAD_DATA) {
                fprintf(stderr, "aes_decrypt_buf failed with incorrect "
                        "error code\n");
                goto loser;
            }
            continue;
        }
        /* AAD = ... */
        if (strncmp(buf, "AAD", 3) == 0) {
            aadlen = aes_gcm_parse_hex(buf, 3, "AAD", aad,
                                       (unsigned int) sizeof aad);
            if (aadlen != expected_aadlen) {
                fprintf(stderr, "Unexpected AAD length: %u vs. %u\n",
                        aadlen, expected_aadlen);
                exit(1);
            }
            continue;
        }
        /* CT = ... */
        if (strncmp(buf, "CT", 2) == 0) {
            ciphertextlen = aes_gcm_parse_hex(buf, 2, "CT", ciphertext,
                                              (unsigned int) sizeof ciphertext);
            /* The ciphertext is checked against the plaintext length. */
            if (ciphertextlen != expected_ptlen) {
                fprintf(stderr, "Unexpected CT length: %u vs. %u\n",
                        ciphertextlen, expected_ptlen);
                exit(1);
            }
            continue;
        }
        /* Tag = ... */
        if (strncmp(buf, "Tag", 3) == 0) {
            tagsize = aes_gcm_parse_hex(buf, 3, "Tag", tag,
                                        (unsigned int) sizeof tag);
            if (tagsize != expected_taglen) {
                fprintf(stderr, "Unexpected tag length: %u vs. %u\n",
                        tagsize, expected_taglen);
                exit(1);
            }

            /* For encryption the Tag line triggers the test itself. */
            if (is_encrypt) {
                rv = aes_encrypt_buf(key, keysize, iv, ivsize,
                    output, &outputlen, sizeof output,
                    plaintext, plaintextlen, aad, aadlen, tagsize);
                if (rv != SECSuccess) {
                    fprintf(stderr, "aes_encrypt_buf failed\n");
                    goto loser;
                }
                /* The output is compared as ciphertext followed by tag. */
                if (outputlen != plaintextlen + tagsize) {
                    fprintf(stderr, "aes_encrypt_buf: wrong output size\n");
                    goto loser;
                }
                if (memcmp(output, ciphertext, plaintextlen) != 0) {
                    fprintf(stderr, "aes_encrypt_buf: wrong ciphertext\n");
                    goto loser;
                }
                if (memcmp(output + plaintextlen, tag, tagsize) != 0) {
                    fprintf(stderr, "aes_encrypt_buf: wrong tag\n");
                    goto loser;
                }
            }
            continue;
        }
    }
    /* Report num_tests for the last test group. */
    printf("%u tests\n", num_tests);
    printf("%u test groups\n", test_group);
    printf("PASS\n");
    /*
     * NOTE(review): a failed vector jumps here without printing PASS, but
     * nothing is reported to the caller because the function is void;
     * confirm that the test harness keys off the PASS line.
     */
loser:
    fclose(aesresp);
}

int main(int argc, char **argv)
{
    if (argc < 2) exit(1);

    NSS_NoDB_Init(NULL);

    /*************/
    /*   AES     */
    /*************/
    if (strcmp(argv[1], "aes") == 0) {
	/* argv[2]=kat argv[3]=gcm argv[4]=<test name>.rsp */
	if (strcmp(argv[2], "kat") == 0) {
	    /* Known Answer Test (KAT) */
	    aes_gcm_kat(argv[4]);
	}
    }

    NSS_Shutdown();
    return 0;
}
