SharifCTF 7 – Lobotomized LSB Oracle

[Note: this is a follow-up to the challenge LSB Oracle. It may be useful to read the writeup for that challenge first.]

In LSB Oracle, we were given an oracle that returns the least significant bit of the decryption of any input we feed it. In Lobotomized LSB Oracle, we are given the same oracle, but with noise! It’s important to note that this noise is consistent per input (i.e., some fixed fraction of inputs always give the wrong answer), so we can’t simply correct for it by repeating each query many times and taking the majority response.

While trying this challenge, I found a long line of papers exploring this exact problem (usually phrased as proving that constructing an LSB oracle is at least as hard as decrypting RSA). For a brief overview of what is known, I recommend reading Section 7 of this survey. In general, you can solve this Lobotomized LSB Oracle problem even when each answer is correct with probability only \frac{1}{2} + \frac{1}{\mathrm{poly}(n)}.

However, in this challenge, the probability that the oracle answers incorrectly is far lower than this. In fact, the exact probability was too small to estimate accurately; on 1000 random encryptions, it answered correctly 1000 times. Still, it errs often enough that the unmodified solution to LSB Oracle breaks partway through, though only after recovering a good chunk of the flag. As long as we have any way to correct these super-rare errors, we can just run our original solution.

One way would be to guess by hand when the oracle is making a mistake, flip that bit, and see whether we recover more of the flag that way. We can do something more general, however, by taking advantage of the self-reducibility of RSA.

Recall that, by multiplying the ciphertext by k^e \bmod N, we can use the LSB oracle to answer questions of the form “is kf \bmod N even or odd?”, where f is the plaintext flag and N is the modulus. The oracle may respond incorrectly to this, so we would like to check its answer by asking it another question that ideally should have the same answer. One good question to try is “is (k+2)f \bmod N even or odd?”.
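To make the blinding step concrete, here is a minimal sketch on a toy RSA instance. The primes, the plaintext `f`, and the helper names `lsb_oracle` and `parity_of_multiple` are all made up for illustration (they are not the challenge's values), and the oracle here is noise-free. It is written for Python 3.8+, unlike the Python 2 solution script.

```python
# Toy RSA parameters (illustrative only; far too small for real use).
p, q = 1009, 1013
N = p * q                            # odd modulus, as in real RSA
e = 65537
d = pow(e, -1, (p - 1) * (q - 1))    # private exponent (needs Python 3.8+)

f = 4242                             # stand-in for the secret plaintext
C = pow(f, e, N)                     # the ciphertext we are given

def lsb_oracle(c):
    """Noise-free oracle: least significant bit of the decryption of c."""
    return pow(c, d, N) & 1

def parity_of_multiple(k):
    """Ask the oracle for the parity of (k * f) mod N without knowing f."""
    blinded = (pow(k, e, N) * C) % N  # decrypts to (k * f) mod N
    return lsb_oracle(blinded)

# The oracle's answer on the blinded ciphertext matches the true parity.
for k in (1, 2, 3, 500, 12345):
    assert parity_of_multiple(k) == ((k * f) % N) & 1
```

The attacker only ever calls `parity_of_multiple`; `d` and `f` appear here solely so the toy oracle can be simulated and checked.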

Since 2f is even, kf \bmod N and (k+2)f \bmod N should have the same parity unless there is a multiple of N between kf and (k+2)f; if there is one, the reduction subtracts one extra copy of N, which is odd, and the parity flips. Luckily, it turns out that f is quite small compared to N, with N/f \approx 16000. This means it’s very unlikely (roughly 2f/N, or ~0.01%) that a multiple of N will lie between kf and (k+2)f.
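The parity argument can be checked numerically. The sketch below uses a toy modulus and a toy plaintext, both hypothetical and chosen only so that N/f is roughly 16000 as in the challenge. It verifies that the two parities disagree exactly when a multiple of N is crossed, and compares how often that happens with the estimate 2f/N.

```python
N = 1022117   # toy odd modulus (hypothetical; stands in for the real N)
f = 61        # toy plaintext chosen so that N / f is roughly 16000

def wraps(k):
    """True if a multiple of N lies in (k*f, (k+2)*f]."""
    return (k * f) // N != ((k + 2) * f) // N

def same_parity(k):
    """True if (k*f mod N) and ((k+2)*f mod N) have the same parity."""
    return ((k * f) % N) % 2 == (((k + 2) * f) % N) % 2

ks = range(1, 200001)

# The parities agree exactly when no multiple of N is crossed.
assert all(same_parity(k) != wraps(k) for k in ks)

# Observed fraction of bad k versus the estimate 2f/N.
frac = sum(wraps(k) for k in ks) / len(ks)
print(frac, 2 * f / N)
```

Both printed numbers should be on the order of 0.01%, in line with the estimate in the text.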

This trick also works for (k+2i) in general, as long as 2i is small compared to 16000. We can therefore correct errors by, say, taking the majority of the oracle’s answers on k, k+2, k+4, k+6 and k+8. The following code implements this approach.

from pwn import *
from Crypto.Util.number import *

N = 94169898764475155086179365872915864925768243050855426387910613522303337327416930459077578555524838413579345103633071500300104580298306187507383687796776619261744561887287065152410825040924957174425131901014950571780211869823508452987101620679856181308669517708916215765377471785309709279780997993371462202127
C = 84554310261580598058211620872297995265063480196893812976334022270327838015482739129096939702314740821259766144865677921673974339162910708930818463109733348984687023660294660726179053438750361754457786927212462355725758670143043124242928370865662017903815787388480232771504943423128214544949007416507395402507
E = 65537

p = process(['lobotomized_lsb_oracle.vmp.exe', '/decrypt'])
p.recvline()
p.recvline()
p.recvline()
p.recvline()
p.recvline()

def decrypt_bit(x):
    p.sendline(str(x))
    p.recvline()
    return int(p.recvline())


# find error probability of oracle
TOTAL = 1000
success = 0
for _ in xrange(TOTAL):
    v = getRandomRange(1, N)
    encv = pow(v, E, N)
    b = decrypt_bit(encv)
#    print v%2, b
    if b==(v%2):
        success+=1

print success, '/', TOTAL


int_st = 0
int_end = N

def get_first(st, end, f):
    # finds first x in [st, end-1] s.t. (f*x)%N > N/2
    left = st
    right = end

    while right-left > 1:
        mid = (right+left)/2
        if (f*mid)%N > N/2:
            right = mid
        else:
            left = mid
    return right

pow2 = pow(2, E, N)

cpow = pow2
f = 1

def parity_check(k):
    # try to find the parity of (k*flag) % N
    # fix noise by taking the majority of the oracle's answers over
    # the multipliers [k, k+2, k+4, k+6, k+8]
    vals = [(pow(k+2*i, E, N)*C)%N for i in xrange(5)]
    bits = [decrypt_bit(val) for val in vals]
    return 1 if sum(bits) > 2 else 0

while int_end - int_st > 1:
    imid = get_first(int_st, int_end, f)

    curstr = long_to_bytes(int_st)
    print curstr
    print N/(int_end) # how big is N/flag?


    if parity_check(2*f)==0:
        int_end = imid
    else:
        int_st = imid

    f *= 2

print int_st
print long_to_bytes(int_st)

#5903101931477662455816102201567891605708265729448307752561406636501574632073911483253095075622760142998917634536184890040384933717551952184462545850638876130512058253708637281335591022945131553883548565244629945021422515095852683813088365728589098264874275817445161298719022566193826136024045395632376260
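The majority-vote correction can also be tried offline against a simulated oracle. The sketch below is Python 3.8+ and rests on assumptions throughout: the toy RSA parameters, the plaintext `f`, and above all the noise model (the oracle lies whenever the decrypted value is a multiple of 101) are invented so that the result is reproducible. How the real binary picks the inputs it lies about is not known to me. Like the challenge's noise, this model is consistent per input, so repeating the same query never helps.

```python
# Toy RSA instance (illustrative values, not the challenge's).
p, q = 1009, 1013
N = p * q
e = 65537
d = pow(e, -1, (p - 1) * (q - 1))   # needs Python 3.8+

f = 61                              # toy plaintext, small compared to N
C = pow(f, e, N)

def noisy_oracle(c):
    """LSB of the decryption of c, with a hypothetical deterministic noise
    model: the answer is flipped whenever the plaintext is a multiple of 101."""
    m = pow(c, d, N)
    bit = m & 1
    return bit ^ 1 if m % 101 == 0 else bit

def raw_check(k):
    """Single query: the oracle's claimed parity of (k * f) mod N."""
    return noisy_oracle((pow(k, e, N) * C) % N)

def parity_check(k, votes=5):
    """Majority of the oracle's answers over k, k+2, ..., k+2*(votes-1)."""
    bits = [raw_check(k + 2 * i) for i in range(votes)]
    return 1 if 2 * sum(bits) > votes else 0

ks = range(1, 4001)
truth = {k: ((k * f) % N) & 1 for k in ks}

raw_errors = sum(raw_check(k) != truth[k] for k in ks)
voted_errors = sum(parity_check(k) != truth[k] for k in ks)
print(raw_errors, voted_errors)
```

In this toy setup the multipliers are small enough that no multiple of N is ever crossed, and at most one of the five queried values can hit the noisy set, so the vote removes every error. Against the real oracle the vote only makes errors much rarer rather than impossible.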