SEA

Grégoire Frémion · May 12, 2024

Résolution

En analysant le code on déduit que l’on a affaire à un chiffrement symétrique par bloc de type Feistel, comportant 4 rounds, avec une clé de 256 bits et des clés de round de 64 bits (à noter que la key expansion est à base de shake256 et non réversible).

Puisqu’il y a peu de rounds et qu’il n’y a pas l’air d’avoir d’erreurs d’implémentations qui nuiraient à la sécurité du chiffrement, on peut penser assez rapidement à de la cryptographique différentielle ou linéaire par exemple.

Voici un schéma de la fonction de round utilisée par SEA (S = SBOX_1 et S2 = SBOX_2): sea round function

L’une des premières choses à faire est d’analyser les deux SBOX pour évaluer leur résistance à la cryptalyse linéaire/différentielle. Pour ça, on génère la Differential Distribution Table (DDT) et la Linear Approximation Table (LAT) des deux SBox (voici un exemple d’analyse de SBox), et on affiche ça dans un format agréable :

  • LAT de la première SBox : LAT1
  • DDT de la première SBox : DDT1

On peut faire la même chose pour la deuxième SBox. Rien ne semble choquant à priori sauf un point très foncé dans la DDT de la première SBox. En investigant un peu plus avec sage, on se rend compte que il y a une caractéristique différentielle vérifée avec une probabilité 1 pour la première SBox : 233 -> 3 (avec 0 -> 0). L’attaque semble se diriger vers la cryptanalyse différentielle.

Le seconde SBox ne possède rien de tel, on va se concentrer sur la première SBox. Il faut maintenant trouver une différentielle pour la fonction de round, il n’en existe pas qu’une mais certaines seront meilleures que d’autres au moment de récupérer les clés de round. En voici une : DIFF

Le but va être de casser la 4ème clé de round, puis la 3ème, … jusqu’à avoir toutes les clés et pouvoir déchiffrer le flag.

Voici les différentes différentielles utilisées pour casser les clés de rounds :

Une fois ces différentielles trouvées, la démarche est similaire à celle pour casser FEAL-4 par exemple. Il y a cependant quelques différences :

  • L’idée est de faire une différentielle tronquée pour casser les clés car faire un bruteforce sur le bloc de 64 bits entier est peine perdue. En effet, si l’on regarde la fonction de round, on se rend compte que on peut bruteforcer indépendaments les blocs d’octets en entrée (1,2,3), (4,5), (6,7,8) soit un bruteforce respectivement sur 24 bits, 16 bits et 24 bits, ce qui est largement faisable.
  • Plusieurs clés vont convenir, afin de retrouver les bonnes il faut utiliser la relation de la key expansion pour déterminer quelle suite de clés est la bonne

Une fois les clés de round 4, 3 et 2 en notre possession on peut déchiffrer jusqu’au deuxième round donc on peut avoir immédiatement la moitié droite de la clé de pre-whitening avec la relation K' = A xor B

Pour la clé du premier round, il suffit de prendre une différentielle non nulle et appliquer la même méthode que pour les autres clés. Enfin, on déchiffre le round 1 et on obtient la moitié gauche de la clé de pre-whitening sur le même principe que pour l’autre partie de la clé de pre-whitening.

La solution est implémentée en python3 mais n’est pas facile à lire/peu commentée :

# 404CTF 2024
# Challenge : SEA, hard
# Authors : acmo0

from multiprocessing import Pool
from functools import partial
import sys
import socket
import os
from Cryptodome.Util.number import long_to_bytes, bytes_to_long
from hashlib import shake_256
from sea import SEA


def get_right(c):
	return c&0xffffffffffffffff
def get_left(c):
	return c>>64

def try_key_range_last_rd(pairs,diff,f,mask,search_range):
	"""
	L'idée est de chiffrer la partie gauche du ciphertext pour chaque paire
	et regarder si l'on a la différentielle tronquée attendue. Si c'est le cas,
	la clé est un candidat potentiel
	"""
	candidates = []
	for candidate in search_range:
		valid = True
		for pair in pairs:

			a1,a2,plain1,plain2 = pair
			right1 = get_right(plain1)
			left1 = get_left(plain1)

			right2 = get_right(plain2)
			left2 = get_left(plain2)

			left1,right1 = right1,left1
			left2,right2 = right2,left2

			out_f1 = f(right1^candidate)
			out_f2 = f(right2^candidate)

			if ((out_f1^left1)^(left2^out_f2))&mask != (diff&mask):
				valid=False
				break
		if valid and candidate != 0:
			candidates.append(candidate)
	return candidates

def generate_pairs_with_diff(cipher,diff,pair_number=8):
	pairs = []
	for i in range(pair_number):
		plain1 = os.urandom(16)
		plain2 = long_to_bytes(bytes_to_long(plain1)^diff)
		if len(plain2) != 16:
			plain2 = (16-len(plain2))*b'\x00'+plain2
		pairs.append([bytes_to_long(plain1),bytes_to_long(plain2),bytes_to_long(cipher.encrypt(plain1)),bytes_to_long(cipher.encrypt(plain2))])

	return pairs

def crack_round(cipher,diff,diff_out,masks,pairs,N=8):
	ranges = [range(i,i+(2**24//N)) for i in range(0,2**24,(2**24//N))]	# Intervalles de recher des clés pour multi-threader
	
	pool = Pool(processes=N)

	# Récupérer les 3 derniers octets de la clé de round
	mask = masks[0]	# Masque pour la différentielle tronquée
	map_function = partial(try_key_range_last_rd,pairs,diff_out,cipher.f,mask)
	fhalf_key = []
	for candidate in pool.imap(map_function,ranges):
		for c in candidate:
			fhalf_key.append(c)
	
	if len(fhalf_key) == 0:
		return []

	# Récupère les deux octets du milieu de la clé
	ranges = [range(i,i+(2**40-2**24)//N,2**24) for i in range(2**24,2**40,((2**40-2**24)//N))] # Intervalles de recher des clés pour multi-threader
	mask = masks[1]
	map_function = partial(try_key_range_last_rd,pairs,diff_out,cipher.f,mask)
	middle_key = [] # Masque pour la différentielle tronquée
	for candidate in pool.imap(map_function,ranges):
		for c in candidate:
			middle_key.append(c)
	
	if len(middle_key) == 0:
		return []

	# Récupérer les 3 premiers octets de la clé de round
	ranges = [range(i,i+(2**64-2**40)//N,2**40) for i in range(2**40,2**64,((2**64-2**40)//N))] # Intervalles de recher des clés pour multi-threader
	mask = masks[2]
	map_function = partial(try_key_range_last_rd,pairs,diff_out,cipher.f,mask)
	lhalf_key = [] # Masque pour la différentielle tronquée
	for candidate in pool.imap(map_function,ranges):
		for c in candidate:
			lhalf_key.append(c)
	
	if len(lhalf_key) == 0:
		return []

	# Clés de round possibles
	possible_keys = [((l&0xffffff0000000000)+(m&0xffff000000) + f) for l in lhalf_key for f in fhalf_key for m in middle_key]
	
	return possible_keys


def decrypt_round_with_key(key,f,cipher_text):
	#print("Decripting",hex(cipher_text))
	left,right = get_left(cipher_text),get_right(cipher_text)
	left,right = right,left
	out_f = f(right^key)
	return ((left^out_f)<<64)+right

def decrypt_round_pairs(key,f,pairs):
	decrypted_pairs = []
	for pair in pairs:
		decrypted_pairs.append([pair[0],pair[1]]+[decrypt_round_with_key(key,f,e) for e in pair[2:]])
	return decrypted_pairs

def test_key_pair(key,next_key):
	return shake_256(long_to_bytes(key,blocksize=8)).digest(length=8) == long_to_bytes(next_key,blocksize=8)


class DistantCipher:	# Interface avec l'oracle de chiffrement
	def __init__(self,s):
		self.s = s
		self.SBox1=[
			0x24,0xc1,0x38,0x30,0xe7,0x57,0xdf,0x20,0x3e,0x99,0x1a,0x34,0xca,0xd6,0x52,0xfd,
			0x40,0x6c,0xd3,0x95,0x4a,0x59,0xf8,0x77,0x79,0x61,0x0a,0x56,0xb9,0xd2,0xfc,0xf1,
			0x07,0xf5,0x93,0xcd,0x00,0xb6,0xcb,0xa7,0x63,0x98,0x44,0xbd,0x5f,0x92,0x6b,0x73,
			0x3c,0x4e,0xa2,0x97,0x0b,0x01,0x83,0xa3,0xee,0xe5,0x45,0x67,0xf4,0x13,0xad,0x8b,
			0xbb,0x0c,0x72,0xb4,0x2a,0x3a,0xc5,0x84,0xec,0x9f,0x14,0xc0,0xc4,0x16,0x31,0xd9,
			0xab,0x9e,0x0e,0x1d,0x7c,0x48,0x1b,0x05,0x1c,0xea,0xa5,0xf0,0x8f,0x85,0x50,0x2c,
			0x35,0xbf,0x26,0x28,0x7b,0xe2,0xaa,0xf9,0x4f,0xe3,0xcc,0x2e,0x11,0x76,0xb1,0x8d,
			0xd4,0x5e,0xaf,0xe8,0x42,0xb0,0x6d,0x65,0x82,0x6a,0x58,0x8a,0xdd,0x7e,0x22,0xd8,
			0xe0,0x4c,0x2d,0xcf,0x75,0x12,0x8e,0xb2,0xbc,0x36,0x2b,0x25,0xe1,0x78,0xfa,0xa9,
			0x69,0x81,0x89,0x5b,0x7d,0xde,0xdb,0x21,0x5d,0xd7,0xeb,0xac,0xb3,0x41,0x66,0x6e,
			0x9c,0xef,0xc3,0x17,0x15,0xc7,0xda,0x32,0x0f,0xb8,0xb7,0x71,0x39,0x29,0x87,0xc6,
			0xe9,0x1f,0xf3,0xa6,0x86,0x8c,0x2f,0x53,0x9d,0xa8,0x1e,0x0d,0x4b,0x7f,0x06,0x18,
			0x9b,0x60,0xbe,0x47,0x91,0x5c,0x70,0x68,0xf6,0x04,0xce,0x90,0xb5,0x03,0xa4,0xc8,
			0xe6,0xed,0x64,0x46,0x10,0xf7,0x88,0xae,0x4d,0x3f,0x94,0xa1,0x02,0x08,0xa0,0x80,
			0x9a,0x3d,0x37,0x19,0xd5,0xc9,0xfe,0x51,0xc2,0x27,0x33,0x3b,0x54,0xe4,0x23,0xdc,
			0x62,0x7a,0x55,0x09,0xd1,0xba,0xf2,0xff,0x6f,0x43,0x96,0xd0,0x5a,0x49,0x74,0xfb,
		]
		self.SBox2=[
			0x24,0xc1,0x38,0x30,0xe7,0x57,0xdf,0x20,0x3e,0x99,0x1a,0x34,0xca,0xd6,0x52,0xfd,
			0x40,0x6c,0xd3,0x3d,0x4a,0x59,0xf8,0x77,0xfb,0x61,0x0a,0x56,0xb9,0xd2,0xfc,0xf1,
			0x07,0xf5,0x93,0xcd,0x00,0xb6,0x62,0xa7,0x63,0xfe,0x44,0xbd,0x5f,0x92,0x6b,0x68,
			0x03,0x4e,0xa2,0x97,0x0b,0x60,0x83,0xa3,0x02,0xe5,0x45,0x67,0xf4,0x13,0x08,0x8b,
			0x10,0xce,0xbe,0xb4,0x2a,0x3a,0x96,0x84,0xc8,0x9f,0x14,0xc0,0xc4,0x6f,0x31,0xd9,
			0xab,0xae,0x0e,0x64,0x7c,0xda,0x1b,0x05,0xa8,0x15,0xa5,0x90,0x94,0x85,0x71,0x2c,
			0x35,0x19,0x26,0x28,0x53,0xe2,0x7f,0x3b,0x2f,0xa9,0xcc,0x2e,0x11,0x76,0xed,0x4d,
			0x87,0x5e,0xc2,0xc7,0x80,0xb0,0x6d,0x17,0xb2,0xff,0xe4,0xb7,0x54,0x9d,0xb8,0x66,
			0x74,0x9c,0xdb,0x36,0x47,0x5d,0xde,0x70,0xd5,0x91,0xaa,0x3f,0xc9,0xd8,0xf3,0xf2,
			0x5b,0x89,0x2d,0x22,0x5c,0xe1,0x46,0x33,0xe6,0x09,0xbc,0xe8,0x81,0x7d,0xe9,0x49,
			0xe0,0xb1,0x32,0x37,0xea,0x5a,0xf6,0x27,0x58,0x69,0x8a,0x50,0xba,0xdd,0x51,0xf9,
			0x75,0xa1,0x78,0xd0,0x43,0xf7,0x25,0x7b,0x7e,0x1c,0xac,0xd4,0x9a,0x2b,0x42,0xe3,
			0x4b,0x01,0x72,0xd7,0x4c,0xfa,0xeb,0x73,0x48,0x8c,0x0c,0xf0,0x6a,0x23,0x41,0xec,
			0xb3,0xef,0x1d,0x12,0xbb,0x88,0x0d,0xc3,0x8d,0x4f,0x55,0x82,0xee,0xad,0x86,0x06,
			0xa0,0x95,0x65,0xbf,0x7a,0x39,0x98,0x04,0x9b,0x9e,0xa4,0xc6,0xcf,0x6e,0xdc,0xd1,
			0xcb,0x1f,0x8f,0x8e,0x3c,0x21,0xa6,0xb5,0x16,0xaf,0xc5,0x18,0x1e,0x0f,0x29,0x79,
		]

	def encrypt(self,block):
		self.s.send(block.hex().encode('utf-8')+b"\n")
		return bytes.fromhex(self.s.recv(1024).decode('utf-8').split("\n")[0])
	def f(self,block):
		b1 = (block>>56)
		b2 = (block>>48) & 0xff
		b3 = (block>>40) & 0xff
		b4 = (block>>32) & 0xff
		b5 = (block>>24) & 0xff
		b6 = (block>>16) & 0xff
		b7 = (block>>8) & 0xff
		b8 = block & 0xff

		b2 ^= b3
		b1 ^= b2
		b1 = self.SBox1[b1]
		b2 ^= b1
		b2 = self.SBox2[b2]
		b3 ^= b2
		b3 = self.SBox2[b3]
		b3 ^= b1
		b4 ^= b5
		b4 = self.SBox2[b4]
		b5 ^= b4
		b5 = self.SBox1[b5]
		b7 ^= b6
		b6 = self.SBox1[b6]
		b7^= b6
		b7 = self.SBox2[b7]
		b8 ^= b7
		b6 ^= b7
		b8 = self.SBox1[b8]

		return (b2<<56)+(b3<<48)+(b6<<40)+(b1<<32)+(b4<<24)+(b8<<16)+(b5<<8)+b7



HOST="XXXXXX"
PORT=?

#Connection au serveur
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((HOST,PORT))
recv = s.recv(1024).decode('utf-8')

encrypted_flag = bytes.fromhex(recv.split("\n")[-2].split(" : ")[1])

cipher = DistantCipher(s)
keys = []

# 4ème clé
d1 = 0xea0300e9e9e9eae9
d2 = 0
diff=(d1<<64)+d2
diff_out = 0x0003030300030300
masks = (0xff0000ff00ff,0xff00ff00,0xffff00ff00000000)
pairs = generate_pairs_with_diff(cipher,diff,pair_number=4)
keys_last_rd = set(crack_round(cipher,diff,diff_out,masks,pairs))

print("Found",len(keys_last_rd),"potential keys")

# 3ème clé
d1 = 0x0000000000000000
d2 = 0xea0300e9e9e9eae9
diff=(d1<<64)+d2
diff_out = 0x0003030300030300
pairs = generate_pairs_with_diff(cipher,diff,pair_number=4)
decrypted_last_round = None
key_third_rd = None

for key in keys_last_rd:	# Pour chaque clé potentielle trouvée au round 4
	
	decrypted_last_round = decrypt_round_pairs(key,cipher.f,pairs)	# On déchiffre avec la clé 4 supposée
	founds = set(crack_round(cipher,diff,diff_out,masks,decrypted_last_round))
	for k in founds:	# On vérifie que la relation de la key expansion est vérifiée, si oui on a la bonne clé à priori
		if test_key_pair(k,key):
			print("Found :",hex(k))
			key_third_rd = k
			keys = [key,k]
	if key_third_rd != None:
		break

# Clé 2
d1 = 0
d2 = 0x0003030300030300
diff=(d1<<64)+d2
diff_out = 0x0003030300030300
masks = (0xff0000ff00ff,0xff00ff00,0xffff00ff00000000)
pairs = generate_pairs_with_diff(cipher,diff,pair_number=4)
decrypted_third_round = decrypt_round_pairs(keys[1],cipher.f,decrypt_round_pairs(keys[0],cipher.f,pairs))	# On déchiffre le 4ème et 3ème round
key_second_rd = None

for key in crack_round(cipher,diff,diff_out,masks,decrypted_third_round):	# On garde la clé qui vérifie la relation de key expansion
	if test_key_pair(key,keys[-1]):
		print("Found !",hex(key))
		key_second_rd = key
		break
keys.append(key_second_rd)

# Clé 1
d1 = 0x0003030300030300
d2 = 0x0101010101010101
diff=(d1<<64)+d2
diff_out = 0x0003030300030300
pairs = generate_pairs_with_diff(cipher,diff,pair_number=4)

# Déchiffre les rounds 4, 3 et 2
decrypted_second_round = decrypt_round_pairs(keys[2],cipher.f,decrypt_round_pairs(keys[1],cipher.f,decrypt_round_pairs(keys[0],cipher.f,pairs)))
right_half_init_key = get_left(decrypted_second_round[0][2])^get_right(decrypted_second_round[0][0]) # Récupère le 

print("recovered right half of the init key",hex(right_half_init_key))
key_first_rd = None
for key in crack_round(cipher,diff,diff_out,masks,decrypted_second_round):	# Récupère la clé 1
	if test_key_pair(key,keys[-1]):
		key_first_rd = key
		break

keys.append(key_first_rd)

# Déchiffre tous les rounds
decrypted_first_round = decrypt_round_pairs(keys[3],cipher.f,decrypt_round_pairs(keys[2],cipher.f,decrypt_round_pairs(keys[1],cipher.f,decrypt_round_pairs(keys[0],cipher.f,pairs))))

# Récupère la partie gauche de la clé de pre-whitening
left_half_init_key = get_left(decrypted_first_round[0][2])^get_left(decrypted_first_round[0][0])
keys.append((left_half_init_key<<64)+right_half_init_key)

print("Recovered keys :",[hex(e) for e in keys][::-1])


cipher2 = SEA(SBOX_1,SBOX_2,os.urandom(32))
cipher2.round_keys = keys
print(cipher2.decrypt(encrypted_flag)) # Affiche le flag

Twitter, Facebook