swiss hacking challenge 2024 - desperate-intern

Posted on May 1, 2024

Difficulty: medium

Category: crypto

Author: Roxy

We have this new intern, John. Or was it Jack? No one really knows, we don’t see him around much since the accident.

Anyways, he built this very safe cryptosystem for us, but Gerard is wary of his skills and sent me to get you to test it.

There are so many boxes here, make sure you don’t stumble over them! Susan from HR broke her foot last week. Poor Susan. We should get her some flowers, don’t you think?

Files

We get a pretty standard DES implementation except for some suspicious differences:

# S-boxes (Customizable)
# The two comments below come from the handout and do NOT describe the table:
# Each S-box is a 4x16 matrix
# S-boxes taken from FIPS 46-3 Appendix 1
# In reality this is a single 8x8 table, indexed by a 3-bit row and a 3-bit
# column (see substitute()), and entry [row][col] is simply row + col.
# That makes the "S-box" trivially invertible up to a few candidates, which
# is the weakness exploited below.
S_BOXES = [
        [0, 1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [2, 3, 4, 5, 6, 7, 8, 9],
        [3, 4, 5, 6, 7, 8, 9, 10],
        [4, 5, 6, 7, 8, 9, 10, 11],
        [5, 6, 7, 8, 9, 10, 11, 12],
        [6, 7, 8, 9, 10, 11, 12, 13],
        [7, 8, 9, 10, 11, 12, 13, 14]

]

# ...
def substitute(expanded_half):
    """Run every 6-bit chunk of *expanded_half* through the S-box table.

    For each chunk, the first three bits pick the row and the last three
    bits pick the column of the single S_BOXES table. The value found there
    is emitted as a 4-bit binary string. All nibbles are concatenated.
    """
    nibbles = []
    for start in range(0, len(expanded_half), 6):
        six_bits = expanded_half[start:start + 6]
        looked_up = S_BOXES[int(six_bits[:3], 2)][int(six_bits[3:], 2)]
        nibbles.append(format(looked_up, '04b'))
    return ''.join(nibbles)

# ...

# Encryption oracle from the handout (excerpt). des_ecb_encrypt and key are
# defined in parts of the handout not shown here.
data = input('Enter the plaintext you want to encrypt \n>')
cipher_text, feistel_output = des_ecb_encrypt(data, key)
# print() inserts a space after each argument, so the line starts with
# "Ciphertext =  " (two spaces) -- the exploit's recvuntil matches on that.
print("Ciphertext = ", cipher_text)
# The Feistel-network intermediates are printed as well. This leak is what
# makes the last round key recoverable.
print("Feistel output = ", feistel_output)

Exploitation

We can exploit this in the following way:

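  • Generate loads of plaintext/ciphertext pairs (the server also leaks the Feistel output for every block)
  • Because we get the Feistel output, and the S-boxes merely add the row bits to the column bits (so every output has only a few possible inputs), recover the last round key bit by bit with a majority vote
  • Brute-force the 8 key bits that PC2 discards (they never appear in any round key, so they are lost when inverting the key schedule)
  • Decrypt the flag with the recovered key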
#!/usr/bin/env python3

from pwn import *
from des_server_handout import *
from Crypto.Util.number import long_to_bytes

# === Custom DES functions === #

def generate_round_keys_from_partial(key):
    """Derive the round keys from an already-PC1-permuted key bitstring.

    PC1 is not applied here. *key* must already be the 28-bit C half followed
    by the 28-bit D half. One round key (PC2 of the rotated halves) is
    produced per entry of shift_table.
    """
    c_half, d_half = key[:28], key[28:]
    keys = []
    for shift in shift_table:
        c_half = rotate_left(c_half, shift)
        d_half = rotate_left(d_half, shift)
        keys.append(permute(c_half + d_half, PC2))
    return keys

def des_encrypt(plain_text, key):
    """Encrypt a single 64-bit block.

    Args:
        plain_text: 64-character '0'/'1' string.
        key: C||D key bitstring (PC1 already applied), expanded into round
            keys by generate_round_keys_from_partial.

    Returns:
        The 64-character ciphertext bitstring.
    """
    round_keys = generate_round_keys_from_partial(key)

    block = permute(plain_text, IP)
    left_half, right_half = block[:32], block[32:]

    # Feistel rounds: the last element of feistel_network's result is used as
    # the round-function output (presumably the P-permuted S-box output).
    for round_key in round_keys:
        feistel_output = feistel_network(right_half, round_key)
        left_half, right_half = right_half, xor_strings(left_half, feistel_output[-1])

    # Emit R||L to undo the swap of the final round, then apply FP.
    return permute(right_half + left_half, FP)

def rotate_right(key, bits):
    """Cyclically rotate the sequence *key* right by *bits* positions.

    main() uses this to undo the key-schedule rotations, so it is meant
    as the inverse of rotate_left. The shift is reduced modulo len(key):
    shifts of len(key) or more wrap around, and negative shifts rotate left.
    An empty *key* is returned unchanged.
    """
    if not key:
        return key
    bits %= len(key)
    # bits == 0 still works: key[-0:] is the whole key and key[:-0] is empty.
    return key[-bits:] + key[:-bits]

def des_decrypt(ciphertext, key):
    """Decrypt a single 64-bit block.

    This uses the same Feistel structure as des_encrypt, with the round keys
    applied in reverse order.

    Args:
        ciphertext: 64-character '0'/'1' string. Bytes are accepted too,
            because permute() decodes them.
        key: C||D key bitstring (PC1 already applied).

    Returns:
        (plain_text, feistel_output): the 64-character plaintext bitstring
        and the feistel_network result of the final round.
    """
    round_keys = generate_round_keys_from_partial(key)[::-1]

    block = permute(ciphertext, IP)
    left_half, right_half = block[:32], block[32:]

    feistel_output = None
    for round_key in round_keys:
        feistel_output = feistel_network(right_half, round_key)
        left_half, right_half = right_half, xor_strings(left_half, feistel_output[-1])

    # Emit R||L to undo the swap of the final round, then apply FP.
    plain_text = permute(right_half + left_half, FP)
    return plain_text, feistel_output

def des_ecb_decrypt(ciphertext, key):
    """Decrypt a bitstring made of concatenated 64-bit ECB blocks.

    Each 64-character slice is decrypted independently with des_decrypt.
    Only the plaintexts are kept and concatenated; the per-block Feistel
    outputs are discarded.

    Args:
        ciphertext: '0'/'1' string (or bytes) whose length is a multiple of 64.
        key: C||D key bitstring (PC1 already applied).

    Returns:
        The plaintext bitstring.
    """
    return ''.join(
        des_decrypt(ciphertext[start:start + 64], key)[0]
        for start in range(0, len(ciphertext), 64)
    )

# === Permutation === #

def inverse_permutation_table(table, length=None):
    """Build the inverse of a 1-indexed permutation table.

    Args:
        table: 1-indexed source positions, as used by permute().
        length: size of the inverse table. Defaults to len(table). If it is
            larger (e.g. inverting the 48-entry PC2 into 56 slots), the
            positions never selected by *table* are left as 0.

    Returns:
        A list where entry ``table[k] - 1`` holds ``k + 1``.
    """
    size = len(table) if length is None else length
    inverse = [0] * size
    for position, source in enumerate(table, start=1):
        inverse[source - 1] = position
    return inverse

def permute(block, table):
    """Reorder *block* according to a 1-indexed permutation table.

    Args:
        block: str of bits. Bytes are decoded to str first.
        table: 1-indexed positions. Output character k is
            ``block[table[k] - 1]``.

    Returns:
        The permuted string (same length as *table*).

    An entry of 0 selects ``block[-1]`` (Python negative indexing). This
    happens when undoing PC2 with an inverse table that has zeros for the
    bits PC2 drops. Those placeholder bits are overwritten later.
    """
    if isinstance(block, bytes):
        block = block.decode()
    # join instead of repeated += to avoid quadratic string building.
    return ''.join(block[i - 1] for i in table)


# === Pwntools / Remote connection === #

# Inverse tables for walking the cipher backwards:
# FP_INV undoes the final permutation on ciphertext blocks (used in main),
# P_INV undoes the round function's P permutation to expose the raw S-box
# output (used in get_round_key).
FP_INV = inverse_permutation_table(FP)
P_INV = inverse_permutation_table(P)

def conn():
    """Open a tube to the challenge.

    With the pwntools LOCAL argument, the handout server is started as a
    local process. Otherwise the remote TLS instance is used.
    """
    if not args.LOCAL:
        return remote('your-instance.ctf.m0unt41n.ch', 1337, ssl=True)
    return process(['python3', './des_server_handout.py'])

def get_encrypted_flag(r):
    """Request the encrypted flag from the server (menu option 1).

    Args:
        r: connected pwntools tube.

    Returns:
        (flag, feistel_output): the flag ciphertext line as stripped bytes,
        and the Feistel output the server prints alongside it, parsed as a
        Python literal.
    """
    # Local import keeps the name out of reach of the star imports above.
    import ast

    r.sendlineafter(b'>', b'1')
    r.recvuntil(b'Flag =  ')
    flag = r.recvuntil(b'\n').strip()
    r.recvuntil(b'Feistel output =  ')
    # Server output is untrusted. Parse it as a literal instead of eval()-ing
    # it, which would execute whatever Python the server chose to send.
    feistel_output = ast.literal_eval(r.recvuntil(b'\n').strip().decode())
    return flag, feistel_output

def encrypt_message(r, msg):
    """Have the server encrypt *msg* (menu option 2).

    Args:
        r: connected pwntools tube.
        msg: plaintext bytes to send.

    Returns:
        (ct, feistel_output): the ciphertext line as stripped bytes, and the
        Feistel output the server prints alongside it, parsed as a Python
        literal.
    """
    # Local import keeps the name out of reach of the star imports above.
    import ast

    r.sendlineafter(b'>', b'2')
    r.sendlineafter(b'>', msg)
    r.recvuntil(b'Ciphertext =  ')
    ct = r.recvuntil(b'\n').strip()
    r.recvuntil(b'Feistel output =  ')
    # Server output is untrusted. Parse it as a literal instead of eval()-ing
    # it, which would execute whatever Python the server chose to send.
    feistel_output = ast.literal_eval(r.recvuntil(b'\n').strip().decode())
    return ct, feistel_output


# === Exploitation === #

def get_possible_sbox_inputs(substituted_half):
    """List every 6-bit S-box input that produces the given 4-bit output.

    The challenge's S-box returns row + col, where row is the first three
    input bits and col is the last three. The candidates are therefore the
    (row, col) pairs in 0..7 that sum to the output value.

    Returns:
        6-character bitstrings, ordered by ascending row.
    """
    target = int(substituted_half, 2)
    return [
        format(row, '03b') + format(target - row, '03b')
        for row in range(8)
        if 0 <= target - row < 8
    ]

def get_round_key_bits(substituted_half, expanded_half):
    """Count 0/1 votes for each of the 6 key bits feeding one S-box.

    Each S-box input consistent with the 4-bit output *substituted_half* is
    XORed with the observed 6-bit *expanded_half*. That gives one candidate
    key chunk, which casts one vote per bit position.

    Returns:
        (kb_0, kb_1): length-6 lists counting 0 votes and 1 votes.
    """
    zero_votes = [0] * 6
    one_votes = [0] * 6
    for sbox_input in get_possible_sbox_inputs(substituted_half):
        candidate = xor_strings(sbox_input, expanded_half)
        for pos in range(6):
            tally = one_votes if candidate[pos] == '1' else zero_votes
            tally[pos] += 1
    return zero_votes, one_votes

def get_round_key(ct_pairs, feistel):
    """Recover a 48-bit round key by majority vote over many blocks.

    Args:
        ct_pairs: per-block ciphertexts. Only ``len(ct_pairs)`` is used, to
            decide how many entries of *feistel* to consume.
        feistel: per-block Feistel outputs from the server. ``entry[0]`` is
            treated as the expanded half before the key XOR, and
            ``entry[-1]`` as the P-permuted S-box output.

    Returns:
        48-character key bitstring. A bit whose 0 and 1 votes tie becomes '0'.
    """
    one_votes = [0] * 48
    zero_votes = [0] * 48

    for sample_idx in range(len(ct_pairs)):
        sample = feistel[sample_idx]
        # Undo the round function's P permutation (not IP) to get the raw
        # concatenated S-box outputs.
        substituted_half = permute(sample[-1], P_INV)
        expanded_half = sample[0]

        # 4-bit S-box outputs and the matching 6-bit expanded chunks.
        sbox_outputs = [substituted_half[k:k + 4] for k in range(0, len(substituted_half), 4)]
        sbox_chunks = [expanded_half[k:k + 6] for k in range(0, len(expanded_half), 6)]

        for box, sbox_output in enumerate(sbox_outputs):
            kb_0, kb_1 = get_round_key_bits(sbox_output, sbox_chunks[box])
            base = 6 * box
            for offset in range(6):
                zero_votes[base + offset] += kb_0[offset]
                one_votes[base + offset] += kb_1[offset]

    # The true key bit is consistent with every sample, so it wins the vote.
    return ''.join('1' if one_votes[b] > zero_votes[b] else '0' for b in range(48))

# === Main script === #

def _recover_base_key_halves(last_round_key):
    """Undo PC2 and the key-schedule rotations for the last round key.

    Returns:
        The C||D candidate as a list of '0'/'1' characters. The positions PC2
        drops hold placeholders: permute() maps the inverse table's zero
        entries to the last input bit. They must be brute-forced.
    """
    pc2_inv = inverse_permutation_table(PC2, 56)
    cd = permute(last_round_key, pc2_inv)
    left_half, right_half = cd[:28], cd[28:]
    # Undo every rotation applied by the key schedule.
    for shift in shift_table:
        left_half = rotate_right(left_half, shift)
        right_half = rotate_right(right_half, shift)
    return list(left_half + right_half)


def _brute_force_missing_bits(partial_key, known_pt_bits, known_ct):
    """Fill the 8 bit positions PC2 discards by testing every combination.

    Returns:
        The key bitstring whose encryption of *known_pt_bits* equals
        *known_ct*.

    Raises:
        RuntimeError: no combination reproduces the known ciphertext.
    """
    # 0-indexed positions of C||D that PC2 never selects (1-indexed
    # 9, 18, 22, 25, 35, 38, 43, 54 in the standard FIPS PC2). This assumes
    # the standard shift table, whose rotations add up to a full turn, so the
    # same positions are missing after undoing the rotations.
    missing_bits = [8, 17, 21, 24, 34, 37, 42, 53]
    width = len(missing_bits)
    candidate = list(partial_key)
    # 1 << width covers all 256 combinations, including 0b11111111.
    for guess in range(1 << width):
        for pos, bit in zip(missing_bits, format(guess, f'0{width}b')):
            candidate[pos] = bit
        key = ''.join(candidate)
        if des_encrypt(known_pt_bits, key) == known_ct:
            return key
    raise RuntimeError("no key candidate reproduces the known plaintext/ciphertext pair")


def main():
    """Run the full attack against the DES oracle and print the flag.

    Steps:
      1. Encrypt 200 blocks to collect per-block Feistel outputs.
      2. Recover the last round key by majority vote over S-box inputs.
      3. Undo PC2 and the key-schedule rotations. This gives C||D with the
         8 PC2-dropped bits unknown.
      4. Brute-force those bits against one known plaintext/ciphertext pair.
      5. Decrypt the flag with the recovered key.
    """
    r = conn()
    try:
        # 8 * 200 bytes -> 200 ECB blocks of samples for the majority vote.
        ct, feistel = encrypt_message(r, cyclic(8 * 200))

        # Split the ciphertext into 64-bit ECB blocks and undo FP on each.
        ct_pairs = [permute(ct[start:start + 64], FP_INV) for start in range(0, len(ct), 64)]

        last_round_key = get_round_key(ct_pairs, feistel)
        success(last_round_key)

        partial_key = _recover_base_key_halves(last_round_key)

        # Encrypt a known plaintext to test key candidates against.
        known_pt = b"A" * 8
        known_ct = encrypt_message(r, known_pt)[0][:64].decode()
        # Explicit byteorder: int.from_bytes only has a default on Python 3.11+.
        known_pt_bits = int_to_64b_bitstring(int.from_bytes(known_pt, 'big'))

        key = _brute_force_missing_bits(partial_key, known_pt_bits, known_ct)
        success(f"Found key: {key}")

        # Fetch and decrypt the flag.
        flag_ct = get_encrypted_flag(r)[0]
        success(long_to_bytes(int(des_ecb_decrypt(flag_ct, key), 2)))
    finally:
        r.close()

# Run the exploit only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Flag

shc2024{Apes_Dont_Read_Philosophy}

Conclusion

I approached this challenge in completely the wrong way. Instead of simply following my initial idea of recovering the last round key, I went down rabbit holes of cryptanalysis.