diff --git a/src/daos/AccountDao.ts b/src/daos/AccountDao.ts index d960b957d..410e0a32e 100644 --- a/src/daos/AccountDao.ts +++ b/src/daos/AccountDao.ts @@ -16,6 +16,7 @@ export interface AccountDao { } export const AccountDao = (collection: Collection, encryptionKey: string): AccountDao => { + const encryptWithKey = encrypt(encryptionKey) const modelToEncryptedDocument = asyncPipe(encryptAccount(encryptionKey), modelToDocument) const documentToDecryptedModel = asyncPipe(documentToModel, decryptAccount(encryptionKey)) @@ -44,7 +45,7 @@ export const AccountDao = (collection: Collection, encryptionKey: string): Accou const insertToken = async (filter: Partial, network: Network, token: string): Promise => { const filterDocument = await modelToEncryptedDocument(filter) const array = network === Network.LIVE ? 'apiTokens' : 'testApiTokens' - const apiTokenEncrypted = await Vault.encrypt(token) + const apiTokenEncrypted = encryptWithKey(token) await collection.updateOne(filterDocument, { $push: { [array]: { token: apiTokenEncrypted } }}) } @@ -89,9 +90,14 @@ const modelToDocument = (model: Partial): Partial => { } const encryptAccount = (encryptionKey: string) => async (account: Partial): Promise> => { + const encryptWithKey = encrypt(encryptionKey) + + const encryptApiTokens = async (tokens: ReadonlyArray): Promise> => + tokens.map(tokenObjectToToken).map(encryptWithKey).map(tokenToTokenObject) + const encryptedAccount = { ...account, - privateKey: account.privateKey && encrypt(account.privateKey, encryptionKey), + privateKey: account.privateKey && encryptWithKey(account.privateKey), apiTokens: account.apiTokens && await encryptApiTokens(account.apiTokens), testApiTokens: account.testApiTokens && await encryptApiTokens(account.testApiTokens), } @@ -107,9 +113,14 @@ const encryptAccount = (encryptionKey: string) => async (account: Partial async (account: Partial): Promise> => { + const decrypt = decryptBackwardsCompatible(decryptionKey) + + const decryptApiTokens = async (tokens: ReadonlyArray): Promise> => + Promise.all(tokens.map(tokenObjectToToken).map(decrypt)).then(tokensToTokenObjects) + const decryptedAccount = { ...account, - privateKey: account.privateKey && decrypt(account.privateKey, decryptionKey), + privateKey: account.privateKey && await decrypt(account.privateKey), apiTokens: account.apiTokens && await decryptApiTokens(account.apiTokens), testApiTokens: account.testApiTokens && await decryptApiTokens(account.testApiTokens), } @@ -124,17 +135,14 @@ const decryptAccount = (decryptionKey: string) => async (account: Partial): Promise> => { - const allTokens = tokens.map(({ token }) => token).map(Vault.encrypt, Vault) - const encryptedTokens = await Promise.all(allTokens) - return encryptedTokens.map(token => ({ token })) -} +const tokenToTokenObject = (token: string): Token => ({ token }) +const tokenObjectToToken = ({ token }: Token): string => token +const tokensToTokenObjects = (tokens: ReadonlyArray): ReadonlyArray => tokens.map(tokenToTokenObject) -const decryptApiTokens = async (tokens: ReadonlyArray): Promise> => { - const allTokens = tokens.map(({ token }) => token).map(Vault.decrypt, Vault) - const decryptedTokens = await Promise.all(allTokens) - return decryptedTokens.map(token => ({ token })) -} +const decryptBackwardsCompatible = (key: string) => (plaintext: string) => + plaintext.startsWith('vault') + ? Vault.decrypt(plaintext) + : decrypt(plaintext, key) interface AccountDocument { readonly id?: Buffer | Binary diff --git a/src/helpers/crypto.ts b/src/helpers/crypto.ts index 04c05d88a..8d71c55c5 100644 --- a/src/helpers/crypto.ts +++ b/src/helpers/crypto.ts @@ -3,7 +3,7 @@ import { createCipheriv, createDecipheriv, randomBytes } from 'crypto' const algorithm = 'id-aes256-GCM' const ivLengthInBytes = 96 -export const encrypt = (text: string, key: string): string => { +export const encrypt = (key: string) => (text: string): string => { const iv = randomBytes(ivLengthInBytes) const cipher = createCipheriv(algorithm, Buffer.from(key, 'hex'), iv) const ciphertext = cipher.update(text, 'utf8', 'hex') + cipher.final('hex')