Swiss Hacking Challenge 2024 – desperate-intern
Difficulty: medium
Category: crypto
Author: Roxy
We have this new intern, John. Or was it Jack? No one really knows, we don’t see him around much since the accident.
Anyways, he built this very safe cryptosystem for us, but Gerard is wary of his skills and sent me to get you to test it.
There are so many boxes here, make sure you don’t stumble over them! Susan from HR broke her foot last week. Poor Susan. We should get her some flowers, don’t you think?
Files
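We get a fairly standard DES implementation, except for a few suspicious differences: despite what the comments claim, the S-boxes are a single 8x8 table where `S_BOXES[row][col] == row + col`, and the server prints the Feistel output alongside the ciphertext: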
# S-boxes (Customizable)
# NOTE: the handout's comments here are misleading. They claim "Each S-box is
# a 4x16 matrix" and "S-boxes taken from FIPS 46-3 Appendix 1". In fact this
# is ONE 8x8 table, shared by every 6-bit group in substitute(), which uses
# the first 3 bits as the row and the last 3 bits as the column.
# Every entry is simply S_BOXES[row][col] == row + col (values 0..14, so they
# fit in the 4 output bits). This near-linearity is what makes the last round
# key recoverable from the leaked Feistel output.
S_BOXES = [
    [0, 1, 2, 3, 4, 5, 6, 7],
    [1, 2, 3, 4, 5, 6, 7, 8],
    [2, 3, 4, 5, 6, 7, 8, 9],
    [3, 4, 5, 6, 7, 8, 9, 10],
    [4, 5, 6, 7, 8, 9, 10, 11],
    [5, 6, 7, 8, 9, 10, 11, 12],
    [6, 7, 8, 9, 10, 11, 12, 13],
    [7, 8, 9, 10, 11, 12, 13, 14]
]
# ...
def substitute(expanded_half):
    """Run every 6-bit group of ``expanded_half`` through the S-box table.

    Within a group, the first three bits pick the row and the remaining bits
    pick the column of ``S_BOXES``; each looked-up value is emitted as a
    4-bit binary string, and the pieces are concatenated in order.
    """
    pieces = []
    for start in range(0, len(expanded_half), 6):
        group = expanded_half[start:start + 6]
        row, col = int(group[:3], 2), int(group[3:], 2)
        pieces.append(format(S_BOXES[row][col], '04b'))
    return ''.join(pieces)
# ...
# Excerpt of the server's "encrypt" handler; the surrounding code is elided
# (presumably a menu - the exploit selects this path by sending '2').
# `key` and `des_ecb_encrypt` come from the elided part of the handout.
data = input('Enter the plaintext you want to encrypt \n>')
cipher_text, feistel_output = des_ecb_encrypt(data, key)
print("Ciphertext = ", cipher_text)
# The Feistel output is leaked next to the ciphertext; encrypt_message() in the
# exploit parses this line to recover the last round key.
print("Feistel output = ", feistel_output)
Exploitation
We can exploit this in the following way:
- Generate loads of plaintext/ciphertext pairs
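- Because the server leaks the Feistel data (the expanded half and the P-permuted S-box output) and each S-box output is just `row + col`, recover the last round key. For every S-box, enumerate the few 6-bit inputs consistent with its 4-bit output. XOR them with the corresponding expanded bits to get candidate key bits, then take a per-bit majority vote over many blocks.
- Brute-force the 8 bits of the 56-bit key-schedule state that PC2 drops (a round key only contains 48 of them)
- Decrypt the flag with the recovered key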
#!/usr/bin/env python3
import ast

from pwn import *
from des_server_handout import *
from Crypto.Util.number import long_to_bytes
# === Custom DES functions === #
def generate_round_keys_from_partial(key):
    """Derive one round key per entry of ``shift_table``.

    ``key`` is the 56-bit key-schedule state as a bit string: ``key[:28]``
    and ``key[28:]`` are the two 28-bit halves. PC1 is NOT applied here, so
    ``key`` is expected in post-PC1 order (presumably why it is "partial").
    Before each round both halves are rotated left by that round's shift and
    their concatenation is compressed with PC2.

    Returns:
        A list of round keys, in encryption order.
    """
    left_half, right_half = key[:28], key[28:]
    round_keys = []
    for shift_bits in shift_table:
        left_half = rotate_left(left_half, shift_bits)
        right_half = rotate_left(right_half, shift_bits)
        round_keys.append(permute(left_half + right_half, PC2))
    return round_keys
def des_encrypt(plain_text, key):
    """Encrypt one 64-bit block (a '0'/'1' string) with the handout's DES.

    Args:
        plain_text: the 64-bit plaintext block as a bit string.
        key: the 56-bit key-schedule state accepted by
            ``generate_round_keys_from_partial`` (PC1 already applied).

    Returns:
        The 64-bit ciphertext block as a bit string.
    """
    round_keys = generate_round_keys_from_partial(key)
    block = permute(plain_text, IP)
    left_half, right_half = block[:32], block[32:]
    # Every round is the same Feistel step. feistel_network's last element is
    # the value XORed into the left half (assumed to be the round function's
    # final output - TODO confirm against the handout).
    for round_key in round_keys:
        feistel_output = feistel_network(right_half, round_key)
        left_half, right_half = right_half, xor_strings(left_half, feistel_output[-1])
    # Swap the halves back before the final permutation.
    return permute(right_half + left_half, FP)
def rotate_right(key, bits):
    """Rotate the sequence ``key`` right by ``bits`` positions.

    ``bits`` is reduced modulo ``len(key)``, so shifts of a full length or
    more wrap around correctly; negative ``bits`` rotate left. An empty
    ``key`` is returned unchanged.
    """
    if not key:
        return key
    split = len(key) - bits % len(key)
    return key[split:] + key[:split]
def des_decrypt(ciphertext, key):
    """Decrypt one 64-bit block by running DES with the round keys reversed.

    Args:
        ciphertext: the 64-bit ciphertext block (bit string, or bytes that
            ``permute`` decodes).
        key: the 56-bit key-schedule state accepted by
            ``generate_round_keys_from_partial``.

    Returns:
        ``(plain_text, feistel_output)`` - the 64-bit plaintext bit string and
        the ``feistel_network`` result of the final round.
    """
    round_keys = generate_round_keys_from_partial(key)[::-1]
    block = permute(ciphertext, IP)
    left_half, right_half = block[:32], block[32:]
    feistel_output = None
    for round_key in round_keys:
        feistel_output = feistel_network(right_half, round_key)
        left_half, right_half = right_half, xor_strings(left_half, feistel_output[-1])
    # Swap the halves back before the final permutation.
    plain_text = permute(right_half + left_half, FP)
    return plain_text, feistel_output
def des_ecb_decrypt(ciphertext, key):
    """Decrypt ``ciphertext`` in ECB mode and return the plaintext bit string.

    ``ciphertext`` is cut into consecutive 64-character blocks. Each block is
    decrypted independently with ``des_decrypt``, and only the plaintext part
    of its result is kept.
    """
    return "".join(
        des_decrypt(ciphertext[start:start + 64], key)[0]
        for start in range(0, len(ciphertext), 64)
    )
# === Permutation === #
def inverse_permutation_table(table, length=None):
    """Invert a 1-based permutation table.

    The result satisfies ``result[v - 1] == i + 1`` whenever
    ``table[i] == v``. ``length`` sets the size of the result (defaulting to
    ``len(table)``); positions that ``table`` never references stay 0.
    """
    size = len(table) if length is None else length
    inverse = [0] * size
    for position, source in enumerate(table, start=1):
        inverse[source - 1] = position
    return inverse
def permute(block, table):
    """Rearrange ``block`` according to the 1-based index list ``table``.

    ``bytes`` input is decoded to ``str`` first. Output position ``k`` is
    ``block[table[k] - 1]``. An entry of 0 therefore reads ``block[-1]``
    (Python negative indexing). main() relies on this for the positions that
    ``inverse_permutation_table(PC2, 56)`` leaves as 0.
    """
    if isinstance(block, bytes):
        block = block.decode()
    # Build the result in one pass instead of repeated string concatenation.
    return "".join(block[i - 1] for i in table)
# === Pwntools / Remote connection === #
# Inverses of the handout's final permutation (FP) and P-box permutation (P),
# used to peel those permutations off observed values.
FP_INV, P_INV = inverse_permutation_table(FP), inverse_permutation_table(P)
def conn():
    """Open a pwntools tube to the challenge.

    With the ``LOCAL`` argument, spawn the handout server as a local process;
    otherwise connect to the remote instance over TLS.
    """
    if not args.LOCAL:
        return remote('your-instance.ctf.m0unt41n.ch', 1337, ssl=True)
    return process(['python3', './des_server_handout.py'])
def get_encrypted_flag(r):
    """Request the encrypted flag from the server (menu option 1).

    Returns:
        ``(flag, feistel_output)`` - the flag ciphertext line as stripped
        bytes, and the "Feistel output" line parsed as a Python literal.
    """
    r.sendlineafter(b'>', b'1')
    r.recvuntil(b'Flag = ')
    flag = r.recvuntil(b'\n').strip()
    r.recvuntil(b'Feistel output = ')
    # The server's output is untrusted. ast.literal_eval only accepts Python
    # literals, whereas eval would execute whatever the server sends.
    feistel_output = ast.literal_eval(r.recvuntil(b'\n').strip().decode())
    return flag, feistel_output
def encrypt_message(r, msg):
    """Ask the server to encrypt ``msg`` (menu option 2).

    Returns:
        ``(ct, feistel_output)`` - the ciphertext line as stripped bytes, and
        the "Feistel output" line parsed as a Python literal.
    """
    r.sendlineafter(b'>', b'2')
    r.sendlineafter(b'>', msg)
    r.recvuntil(b'Ciphertext = ')
    ct = r.recvuntil(b'\n').strip()
    r.recvuntil(b'Feistel output = ')
    # The server's output is untrusted. ast.literal_eval only accepts Python
    # literals, whereas eval would execute whatever the server sends.
    feistel_output = ast.literal_eval(r.recvuntil(b'\n').strip().decode())
    return ct, feistel_output
# === Exploitation === #
def get_possible_sbox_inputs(substituted_half):
    """List every 6-bit S-box input whose output equals ``substituted_half``.

    The S-box value is simply row + col (3-bit row = first three input bits,
    3-bit column = last three). For each row, at most one column can
    reach the target. Results are ordered by ascending row.
    """
    target = int(substituted_half, 2)
    candidates = []
    for row in range(8):
        col = target - row
        if 0 <= col < 8:
            candidates.append(format(row, '03b') + format(col, '03b'))
    return candidates
def get_round_key_bits(substituted_half, expanded_half):
    """Tally candidate values for the 6 key bits feeding one S-box.

    Each S-box input consistent with ``substituted_half`` is XORed with the
    6 ``expanded_half`` bits to give one key candidate. For every bit position,
    count how many candidates have a '1' there; all other candidates count
    as '0'.

    Returns:
        ``(zeros, ones)`` - two 6-element lists of per-bit counts.
    """
    candidates = get_possible_sbox_inputs(substituted_half)
    ones = [0] * 6
    for candidate in candidates:
        guess = xor_strings(candidate, expanded_half)
        for pos in range(6):
            if guess[pos] == '1':
                ones[pos] += 1
    # Every candidate contributes to exactly one of the two tallies per bit.
    zeros = [len(candidates) - count for count in ones]
    return zeros, ones
def get_round_key(ct_pairs, feistel):
    """Recover a 48-bit round key from leaked Feistel data by per-bit voting.

    For each block, ``entry[0]`` of its leaked Feistel data is used as the
    expanded half, and ``entry[-1]`` as the P-permuted S-box output (assumed
    to be from the last round - TODO confirm against the server).
    ``get_round_key_bits`` turns each S-box's 4-bit output and matching 6
    expanded bits into per-bit tallies. These are summed over all blocks and
    each key bit takes the majority value.

    Args:
        ct_pairs: ciphertext blocks. Only used to bound how many entries of
            ``feistel`` are consumed (one per block).
        feistel: per-block leaked Feistel data as returned by the server.

    Returns:
        The round key as a 48-character '0'/'1' string; ties resolve to '0'.
    """
    ones = [0] * 48
    zeros = [0] * 48
    for _, entry in zip(ct_pairs, feistel):
        expanded_half = entry[0]
        # Undo the P-box permutation (not IP) to get the raw S-box outputs.
        substituted_half = permute(entry[-1], P_INV)
        sbox_outputs = [substituted_half[k:k + 4] for k in range(0, len(substituted_half), 4)]
        expanded_chunks = [expanded_half[k:k + 6] for k in range(0, len(expanded_half), 6)]
        for box, (sbox_output, expanded_chunk) in enumerate(zip(sbox_outputs, expanded_chunks)):
            kb_0, kb_1 = get_round_key_bits(sbox_output, expanded_chunk)
            base = 6 * box
            for k in range(6):
                zeros[base + k] += kb_0[k]
                ones[base + k] += kb_1[k]
    return "".join("1" if one > zero else "0" for one, zero in zip(ones, zeros))
# === Main script === #
def _brute_force_missing_bits(partial_key, missing_bits, plain_bits, known_ct):
    """Try every filling of ``missing_bits``; return the matching key or None.

    A candidate key matches when ``des_encrypt(plain_bits, key) == known_ct``.
    All ``2 ** len(missing_bits)`` assignments are tried, including all-ones.
    """
    candidate = list(partial_key)
    width = len(missing_bits)
    for guess in range(1 << width):
        for position, bit in zip(missing_bits, format(guess, f"0{width}b")):
            candidate[position] = bit
        key = "".join(candidate)
        if des_encrypt(plain_bits, key) == known_ct:
            return key
    return None


def main():
    """Run the full attack against the server and print the decrypted flag."""
    r = conn()

    # Encrypt 200 blocks of known plaintext in one request. Each block's
    # leaked Feistel data is one sample for the round-key vote.
    plaintext = cyclic(8 * 200)
    ct, feistel = encrypt_message(r, plaintext)

    # Split the ciphertext into 64-bit ECB blocks and revert FP.
    ct_pairs = [permute(ct[i:i + 64], FP_INV) for i in range(0, len(ct), 64)]

    # Recover the last round key.
    key_15 = get_round_key(ct_pairs, feistel)
    success(key_15)

    # Map the 48 round-key bits back onto the 56-bit key-schedule state. The
    # positions PC2 drops stay 0 in the inverse table, so permute() fills them
    # with a placeholder bit; they are brute-forced below.
    PC2_inv = inverse_permutation_table(PC2, 56)
    approx_key = permute(key_15, PC2_inv)
    left_half, right_half = approx_key[:28], approx_key[28:]

    # Undo every rotation of the key schedule. generate_round_keys_from_partial
    # applies all of shift_table before producing the last round key.
    for shift in shift_table:
        left_half = rotate_right(left_half, shift)
        right_half = rotate_right(right_half, shift)

    # Have the server encrypt a known block to test key candidates against.
    probe = b"A" * 8
    known_ct = encrypt_message(r, probe)[0][:64].decode()
    # byteorder is given explicitly: the no-argument form needs Python 3.11+.
    plain_bits = int_to_64b_bitstring(int.from_bytes(probe, "big"))

    # 0-based positions of the 56-bit state that PC2 does not select
    # (taken from the PC2 table).
    missing_bits = [8, 17, 21, 24, 34, 37, 42, 53]
    key = _brute_force_missing_bits(left_half + right_half, missing_bits, plain_bits, known_ct)
    if key is None:
        raise RuntimeError("no candidate key reproduced the known ciphertext")
    success(f"Found key: {key}")

    # Fetch and decrypt the flag.
    flag = get_encrypted_flag(r)[0]
    success(long_to_bytes(int(des_ecb_decrypt(flag, key), 2)))
# Launch the exploit only when run as a script, not when imported.
if __name__ == "__main__":
    main()
Flag
Conclusion
I initially approached this challenge the wrong way: instead of simply following my first idea of recovering the last round key, I went down rabbit holes of cryptanalysis.