snakeCTF logo

Yet Another Videogame Client

REV

1 file available


Description

Come and use our SNAKEstore for all your player needs. You will be able to play all¹ videogames from just one app² and with top-notch³ encryption!

¹ may not be all
² payment may be required
³ handcrafted by our intern

Uff, there are too many videogames clients these days, another one just came out! A friend of mine managed to infiltrate their production server and inject some code to steal their data, but he won't tell me how to execute it.

Solution

Reversing pyinstaller

We are given a binary file that can be quickly recognised as a file generated by pyinstaller after a quick look at the strings inside. We can use pyinstxtractor to try to extract the original pyc files, but an error message pops up:

[!] Error: Failed to decompress PYZ-00.pyz_extracted/Crypto/__init__.pyc, probably encrypted. Extracting as is.
[!] Error: Failed to decompress PYZ-00.pyz_extracted/Crypto/Cipher/__init__.pyc, probably encrypted. Extracting as is.
[!] Error: Failed to decompress PYZ-00.pyz_extracted/Crypto/Cipher/AES.pyc, probably encrypted. Extracting as is.
[!] Error: Failed to decompress PYZ-00.pyz_extracted/Crypto/Cipher/_EKSBlowfish.pyc, probably encrypted. Extracting as is.
[!] Error: Failed to decompress PYZ-00.pyz_extracted/Crypto/Cipher/_mode_cbc.pyc, probably encrypted. Extracting as is.
...

Trying to extract one of the encrypted files by hand gives the error

zlib.error: Error -3 while decompressing data: incorrect data check

So the crc did not match the data, but the extracted data looks like a correct pyc file, so we just tell pyinstxtractor to ignore the crc errors and we get:

[+] Successfully extracted pyinstaller archive: server

You can now use a python decompiler on the pyc files within the extracted directory

Reversing pyc

Using pycdc to extract server.pyc return almost a perfect python script, the only errors are given by the use of a match and lambda statement that can be fixed by hand looking at the disassembled code. The result is this

Building a client

We now need to recreate a client to speak to the server, looking at the recive method we know that each message is in the form

10 bytes 2 bytes length 4 bytes
Y3J5PnB3bg length encdata crc(encdata)

encryption is done with AES_CBC and a key decided in an initial handshake with ECDH.

Talking with the server, we are offered various options, even a dump_flag that does not work, and we notice that after closing the connection a strange function called h4cK3r is called.

This function refreshes the encryption key to a random value and then sends the flag to us.

Breaking random

We have to recover the internal state of random to be able to predict what the new key will be; luckily, we are given a "guess the number" minigame which will give us 96 bits, but we can only use it 10 times.

But wait the server does not use the normal python random package, instead it uses a library called SecureRandom which does not appear to be publicly available, we have to extract it from the binary file and pyinstxtractor already did it for us.

Opening the module with Ghidra, we can see that the algorithm to generate the random numbers is almost the same as a normal random python module, just some of the constants are different.

Python uses a Mersenne Twister algorithm to generate random numbers, which means that it uses 2 (and one bit) of the values generated in the past to get a new number. Since we need to predict 256 bits to get the key or 8 numbers of 32 bits we need at least 16 numbers in the correct position in the internal array to predict them.

We can use "guess the number hard" to get the numbers (in total we can get 10 96 bits numbers) and "guess the number easy" to position ourselves in the correct position.

Breaking Crypto

There are multiple writeups on how to break a Mersenne Twister prng, do I really have to write another one? You can find the solve script here

Get flag

After predicting the random value, we derive the correct key and decrypt the last message sent from the server to get the flag.

server.py


from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad, pad
import functools
import SecureRandom as random
import sys
import os
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

FLAG = os.environ.get('FLAG', 'fakeCTF{This_is_a_real_flag_i_swear}')


class challenge:

    def __init__(self):
        self.masterkey = b'ciaociaociaociao'
        self.lost = 0
        self.secure = False

    def refresh_mk(self, Alice_public_key):
        private_key = ec.generate_private_key(ec.SECP384R1())
        shared = private_key.exchange(ec.ECDH(), Alice_public_key)
        self.send(private_key.public_key().public_bytes(encoding=serialization.Encoding.PEM,
                                                        format=serialization.PublicFormat.SubjectPublicKeyInfo))
        self.masterkey = HKDF(algorithm=hashes.SHA256(), length=32, salt=None, info=b'handshake data').derive(shared)
        self.secure = True

    def play_rps(self):
        self.sendline(b'choose rock, paper or scissors, write "end" to finish playing')
        won = 0
        lost = 0
        tied = 0
        while True:
            p1 = self.receive().decode()
            if p1 == 'end':
                self.sendline(f'''you won {int((won / (won + lost + tied)) * 100)}% of the times'''.encode())
                return None
            p2 = random.choice([
                'rock',
                'paper',
                'scissors',
                'pistol'])
            if (p1, p2) in (('rock', 'scissors'), ('scissors', 'paper'), ('paper', 'rock')):
                self.sendline(b'congratulations!')
                won += 1
            elif p1 == p2:
                self.sendline(b'tie!')
                tied += 1
            else:
                self.sendline(b'try again')
                lost += 1
            continue

    def secret(self):
        self.sendline(b"tell me a secret, don't worry I'll take good care of it")
        secret = self.receive()
        self.sendline(b"interesting, I won't forget")
        open(os.devnull, 'wb').write(secret)

    def guess_number(self):
        self.sendline(b'try to guess the number I am thinking')
        n = random.randrange(0, 100)
        g = int(self.receive().decode())
        if n == g:
            self.sendline(b'congratulations')
            return None
        self.sendline(f'''try again, the number was {n}'''.encode())

    def guess_number_hard(self):
        if self.lost >= 10:
            self.sendline(b'maybe you should try with something easier first')
        self.sendline(b'try to guess the number I am thinking')
        n = random.getrandbits(96)
        g = int(self.receive().decode())
        if n == g:
            self.sendline(b'congratulations')
            return None
        self.sendline(f'''try again, the number was {n}'''.encode())
        self.lost += 1

    def not_implemented(self):
        self.sendline(b'You have to pay for that! ')

    def get_production_secret():
        return FLAG

    get_production_secret = staticmethod(get_production_secret)

    def dump_flag(self):
        # self.sendline(self.get_production_secret().translate((lambda .0: pass# WARNING: Decompyle incomplete)(
        # range(255))).encode())
        pass

    def h4cK3r(self):
        self.refresh_mk(ec.derive_private_key(random.getrandbits(256), ec.SECP384R1()).public_key())
        self.sendline(self.get_production_secret().encode())

    @staticmethod
    def crc(b):
        b = pad(b, 4)
        return functools.reduce(lambda x, y: x ^ y,
                                [int.from_bytes(b[i:i + 4], 'big') for i in range(0, len(b), 4)],
                                0).to_bytes(4, 'little')

    def check_crc(self, p, crc):
        return crc == self.crc(p)

    def decrypt(self, cipher_text, iv):
        cipher = AES.new(key=self.masterkey, iv=iv, mode=AES.MODE_CBC)
        return unpad(cipher.decrypt(cipher_text), 16)

    def encrypt(self, data):
        cipher = AES.new(key=self.masterkey, mode=AES.MODE_CBC)
        cipher_text = cipher.encrypt(pad(data, 16))
        return cipher_text + cipher.iv

    def decodePacket(self, p):
        cipher_text = p[:-4]
        crc = p[-4:]
        if not self.check_crc(cipher_text, crc):
            return None
        return self.decrypt(cipher_text[:-16], cipher_text[-16:])

    def encodePacket(self, data):
        if len(data) >= 65536:
            raise Exception('too much data')
        packet = b''
        data = self.encrypt(data)
        packet += len(data).to_bytes(2, 'little')
        packet += data
        packet += self.crc(data)
        return packet

    def receive(self):
        magic = sys.stdin.buffer.read(10)
        if magic != b'Y3J5PnB3bg':
            raise Exception('malformed packet' + str(magic))
        l = sys.stdin.buffer.read(2)
        l = int.from_bytes(l, 'little')
        packet = sys.stdin.buffer.read(l + 4)
        message = self.decodePacket(packet)
        return message

    def send(self, m):
        out = self.encodePacket(m)
        sys.stdout.buffer.write(out)
        sys.stdout.flush()

    def sendline(self, m):
        self.send(m + b'\n')

    def mainLoop(self):
        # WARNING: Decompile incomplete
        try:

            while True:
                self.send(b'give me a command\n >')
                order = self.receive()

                if not self.secure and order != b'handshake':
                    self.sendline(b'we need to make sure we are safe first\n')
                    continue

                match order:
                    case b'handshake':
                        self.refresh_mk(serialization.load_pem_public_key(self.receive()))
                    case b'rock_paper_scissors':
                        self.play_rps()
                    case b'minecraft':
                        self.not_implemented()
                    case b'chess':
                        self.not_implemented()
                    case b'poker':
                        self.not_implemented()
                    case b'guess':
                        self.guess_number()
                    case b'hard_guess':
                        self.guess_number_hard()
                    case b'secret':
                        self.secret()
                    case b'get_flag':
                        self.dump_flag()
                    case b'exit':
                        break
                    case _:
                        raise Exception('unknown command')

            self.h4cK3r()
        except Exception as e:
            print(e)

            # some strange things with e


if __name__ == '__main__':
    c = challenge()
    c.mainLoop()

solver.py

#!/usr/bin/env python3
from pwn import *
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from Crypto.Util.Padding import unpad, pad
from Crypto.Cipher import AES
import functools

HOST = args.HOST if args.HOST else "rev.snakectf.org"
PORT = args.PORT if args.PORT else 1501
p = remote(HOST, PORT)


class RandCrack:
    def __init__(self, comunicator):
        self.N = 624
        self.M = 397

        self.magia = [0, 0xB3A21A75]
        self.UPPER_MASK = 0x80000000
        self.LOWER_MASK = 0x7FFFFFFF

        self.mt = [None] * self.N
        self.nex = [None] * self.N
        self.c = comunicator

    def waste(self, n):
        self.c.receive()
        self.c.send(b"rock_paper_scissors")
        self.c.receive()
        for i in range(n):
            self.c.send(b"rock")
            self.c.receive()
        self.c.send(b"end")
        self.c.receive()

    def getbits(self):
        self.c.receive()
        self.c.send(b"hard_guess")
        self.c.receive()
        self.c.send(b"0")
        x = int(self.c.receive().decode().strip().split()[-1])
        for _ in range(3):
            yield self.untemper(x & ((1 << 32) - 1))
            x >>= 32
        return

    @staticmethod
    def unshiftRight(x, shift):
        res = x
        for i in range(32):
            res = x ^ res >> shift
        return res

    @staticmethod
    def unshiftLeft(x, shift, mask):
        res = x
        for i in range(32):
            res = x ^ (res << shift & mask)
        return res

    def untemper(self, v):
        v = self.unshiftRight(v, 18)
        v = self.unshiftLeft(v, 15, 0xC56C0000)
        v = self.unshiftLeft(v, 7, 0xB786FC80)
        v = self.unshiftRight(v, 11)
        return v

    @staticmethod
    def temper(y):
        y ^= y >> 11
        y ^= (y << 7) & 0xB786FC80
        y ^= (y << 15) & 0xC56C0000
        y ^= y >> 18
        return y

    def predict_key(self):
        i = 0
        while i < 12:  # 4 uses
            for x in self.getbits():
                self.mt[i] = x
                i += 1

        self.waste(386)  # reach next twist +1

        i = 1
        while i < 10:  # 3 uses
            for x in self.getbits():
                self.nex[i] = x
                i += 1

        self.waste(160)  # M * 2 - N

        i = self.M * 2 - self.N
        while i < self.M * 2 - self.N + 9:  # 3 uses
            for x in self.getbits():
                self.nex[i] = x
                i += 1

        self.waste(219)  # next twist +1

        for i in range(1, 10):
            y = (self.mt[i] & self.UPPER_MASK) | (self.mt[i + 1] & self.LOWER_MASK)
            test = self.nex[i] ^ (y >> 1) ^ self.magia[y & 1]

            self.mt[self.M + i] = test

        self.mt[self.M] = 0  # maybe 0 or 1<<31 makes no difference for us
        for i in range(self.M, self.M + 9):
            y = (self.mt[i] & self.UPPER_MASK) | (self.mt[i + 1] & self.LOWER_MASK)
            test = self.nex[i + self.M - self.N] ^ (y >> 1) ^ self.magia[y & 1]

            self.nex[i] = test

        predict = []

        for i in range(1, 9):
            y = (self.nex[i] & self.UPPER_MASK) | (self.nex[i + 1] & self.LOWER_MASK)
            test = self.nex[i + self.M] ^ (y >> 1) ^ self.magia[y & 1]
            predict.append(test)

        return predict


class Comunicator:
    def __init__(self):
        self.masterkey = b"ciaociaociaociao"

    @staticmethod
    def crc(b: bytes):
        b = pad(b, 4)
        return functools.reduce(
            lambda x, y: x ^ y,
            [int.from_bytes(b[i: i + 4], "big") for i in range(0, len(b), 4)],
            0,
        ).to_bytes(4, "little")

    def check_crc(self, p, crc):
        return crc == self.crc(p)

    def decrypt(self, cipher_text, iv):
        cipher = AES.new(key=self.masterkey, iv=iv, mode=AES.MODE_CBC)

        return unpad(cipher.decrypt(cipher_text), 16)

    def encrypt(self, data):
        cipher = AES.new(key=self.masterkey, mode=AES.MODE_CBC)
        cipher_text = cipher.encrypt(pad(data, 16))

        return cipher_text + cipher.iv

    def decodePacket(self, p: bytes):
        cipher_text = p[:-4]
        crc = p[-4:]

        if not self.check_crc(cipher_text, crc):
            return None

        return self.decrypt(cipher_text[:-16], cipher_text[-16:])

    def encodePacket(self, data: bytes):
        if len(data) >= 1 << 16:
            raise Exception("too much data")

        packet = b"Y3J5PnB3bg"
        data = self.encrypt(data)
        packet += len(data).to_bytes(2, "little")

        packet += data
        packet += self.crc(data)

        return packet

    def send(self, m):
        p.send(self.encodePacket(m))

    def receive(self):
        magic = p.recv(10)
        if magic != b"Y3J5PnB3bg":
            print("something went wrong", magic)
            p.recvall()
            return None

        l = p.recv(2)
        l = int.from_bytes(l, "little")

        packet = p.recv(l + 4)

        message = self.decodePacket(packet)
        return message

    def handshake(self):
        private_key = ec.generate_private_key(ec.SECP384R1())

        self.send(
            private_key.public_key().public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            )
        )

        message = self.receive()

        A = serialization.load_pem_public_key(message)

        shared = private_key.exchange(ec.ECDH(), A)
        masterkey = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=None,
            info=b"handshake data",
        ).derive(shared)

        self.masterkey = masterkey


c = Comunicator()


def local_handshake(private_key):
    message = c.receive()

    A = serialization.load_pem_public_key(message)

    shared = private_key.exchange(ec.ECDH(), A)
    masterkey = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=b"handshake data",
    ).derive(shared)

    c.masterkey = masterkey


def solve():
    c.receive()
    c.send(b"handshake")
    c.handshake()

    r = RandCrack(c)
    k = r.predict_key()

    key = 0
    for i in range(8):
        key += r.temper(k[i]) << (32 * i)

    c.receive()
    c.send(b"exit")

    local_handshake(ec.derive_private_key(key, ec.SECP384R1()))
    flag = c.receive()
    print(flag)


if __name__ == "__main__":
    solve()