import sys
from multiprocessing import Pool

from numpy.random import seed as np_seed

from map_drop import map_drop
load("../framework/instance_gen.sage")

Derr = build_centered_binomial_law(6)  # noise law added to "Approx" hints
modulus = 11  # modulus used for "Modular" hints
saving_results = True  # append one CSV row per experiment to results.csv

# Optional command-line arguments: argv[1] = N_tests, argv[2] = threads.
# Each argument is parsed independently, so supplying only N_tests keeps it;
# only a missing or non-integer argument falls back to its default (and
# Ctrl-C / SystemExit / NameError are no longer swallowed).
try:
    N_tests = int(sys.argv[1])
except (IndexError, ValueError):
    N_tests = 5
try:
    threads = int(sys.argv[2])
except (IndexError, ValueError):
    threads = 1


def v(i):
    """Return ``canonical_vec(d, i)``.

    Presumably the ``i``-th unit vector of length ``d`` (the module-level
    lattice dimension) — see ``canonical_vec`` in the framework.
    """
    dimension = d
    return canonical_vec(dimension, i)


# Coordinates below this index are never used by randv().
qvec_donttouch = 20


def randv():
    """Return a random sparse hint vector.

    Five indices ``k1..k5`` are drawn with ``randint`` from
    ``[qvec_donttouch, d - 1]`` and combined as
    ``v(k1) - v(k2) + v(k3) - v(k4) + v(k5)``. Repeated indices may cancel.
    """
    lo, hi = qvec_donttouch, d - 1
    vv = v(randint(lo, hi))
    # Signs of the remaining four terms, in draw order: -, +, -, +.
    for add in (False, True, False, True):
        term = v(randint(lo, hi))
        if add:
            vv += term
        else:
            vv -= term
    return vv


def _integrate_hint(inst, vv, T_hints, variance):
    """Integrate one hint of kind ``T_hints`` on vector ``vv`` into ``inst``.

    The hint value is computed from ``inst.leak(vv)``. Estimation is
    disabled (``estimate=False``).

    Raises:
        ValueError: if ``T_hints`` is not "Perfect", "Approx", "Modular"
            or "Q-Modular".
    """
    if T_hints == "Perfect":
        inst.integrate_perfect_hint(vv, inst.leak(vv), estimate=False)
    elif T_hints == "Approx":
        inst.integrate_approx_hint(vv, inst.leak(vv) +
                                   draw_from_distribution(Derr),
                                   variance, estimate=False)
    elif T_hints == "Modular":
        inst.integrate_modular_hint(vv, inst.leak(vv) % modulus,
                                    modulus, smooth=True, estimate=False)
    elif T_hints == "Q-Modular":
        inst.integrate_q_modular_hint(vv, inst.leak(vv) % q,
                                      q, estimate=False)
    else:
        raise ValueError(f"Unknown hint type: {T_hints!r}")


def one_experiment(id, aargs):
    """Run one seeded experiment and return attack/prediction blocksizes.

    Args:
        id: seed for both the Sage and NumPy RNGs (name kept for callers).
        aargs: tuple ``(N_hints, T_hints)``: number of hints and hint type.

    Returns:
        ``(beta, beta_pred_full, beta_pred_light)``:
        - ``beta`` comes from ``dbdd.attack()`` on the full DBDD instance.
        - ``beta_pred_full`` and ``beta_pred_light`` are probabilistic
          ``estimate_attack`` results on the DBDD and DBDD_predict
          instances, respectively.
    """
    (N_hints, T_hints) = aargs
    _, variance = average_variance(Derr)
    set_random_seed(id)
    np_seed(seed=id)
    _, _, dbdd = initialize_from_LWE_instance(DBDD, n, q, m, D_e, D_s,
                                              verbosity=0)
    _, _, dbdd_p = initialize_from_LWE_instance(DBDD_predict, n, q, m,
                                                D_e, D_s, verbosity=0)
    for _ in range(N_hints):
        vv = randv()
        # Same vector for both instances; full one first (RNG draw order).
        _integrate_hint(dbdd, vv, T_hints, variance)
        _integrate_hint(dbdd_p, vv, T_hints, variance)

    # randv() only touches coordinates >= qvec_donttouch, so the first
    # qvec_donttouch coordinates are the ones left for q-vectors.
    dbdd_p.integrate_q_vectors(q, indices=range(qvec_donttouch))
    dbdd.integrate_q_vectors(q, indices=range(qvec_donttouch))
    beta_pred_light, _ = dbdd_p.estimate_attack(probabilistic=True)
    beta_pred_full, _ = dbdd.estimate_attack(probabilistic=True)
    beta, _ = dbdd.attack()
    return (beta, beta_pred_full, beta_pred_light)


# NOTE(review): ``dic`` is not referenced by any of the code visible here;
# it looks like a leftover — confirm nothing else uses it before removing.
dic = {" ": None}

def get_stats(data, N_tests):
    """Return ``(mean, variance)`` of ``data`` as ``RR`` values.

    The variance is the population variance (divided by the sample count).
    Both are normalised by ``len(data)``, the number of results actually
    present. ``N_tests`` (the number requested) is kept for backward
    compatibility. Without this, a short result list would bias the mean
    towards zero.

    Returns ``(NaN, NaN)`` when ``data`` is empty.
    """
    count = len(data)
    if count == 0:
        nan = RR('NaN')
        return (nan, nan)
    avg = RR(sum(data)) / count
    # Two-pass formula: avoids the cancellation of E[x^2] - E[x]^2 and can
    # never go negative, so no abs() correction is needed.
    var = sum((RR(r) - avg)**2 for r in data) / count
    return (avg, var)

def validation_prediction(N_tests, N_hints, T_hints):
    """Run ``N_tests`` experiments and report real vs. predicted beta.

    Experiments are dispatched through ``map_drop`` with ``threads`` workers.
    The function then:
    - prints one table row with the mean real, full-prediction and
      light-prediction blocksizes;
    - when ``saving_results`` is set, appends a ``;``-separated row to
      ``results.csv``.

    Returns:
        The mean predicted beta of the full DBDD instance.
    """
    import datetime
    start = datetime.datetime.now()
    res = map_drop(N_tests, threads, one_experiment, (N_hints, T_hints))
    beta_real, vbeta_real = get_stats([r[0] for r in res], N_tests)
    beta_pred_full, vbeta_pred_full = get_stats([r[1] for r in res], N_tests)
    beta_pred_light, vbeta_pred_light = get_stats([r[2] for r in res],
                                                  N_tests)
    # Measured once so the console and the CSV report the same duration.
    elapsed = datetime.datetime.now() - start

    print("%d,\t %.3f,\t %.3f,\t %.3f \t\t" %
          (N_hints, beta_real, beta_pred_full, beta_pred_light), end=" \t")
    print("Time:", elapsed)

    if saving_results:
        fields = [T_hints, N_hints, N_tests, elapsed,
                  beta_real, beta_pred_full, beta_pred_light,
                  vbeta_real, vbeta_pred_full, vbeta_pred_light]
        # Every field is followed by ';' (including the last), then newline.
        with open('results.csv', 'a') as _file:
            _file.write(''.join(f'{x};' for x in fields) + '\n')
    return beta_pred_full


logging("Number of threads : %d" % threads, style="DATA")
logging("Number of Samples : %d" % N_tests, style="DATA")
logging("     Validation tests     ", style="HEADER")

# LWE parameters, read as module globals by one_experiment(), v() and randv().
n = 70
m = n
q = 3301
D_s = build_centered_binomial_law(40)
D_e = build_centered_binomial_law(40)
d = m + n

# A sweep stops as soon as the predicted (full DBDD) beta drops below this.
beta_stop = 3


def _sweep(T_hints, hint_counts):
    """Print a result table for hint type ``T_hints``.

    Calls validation_prediction for each count in ``hint_counts``, stopping
    early once the predicted beta falls below ``beta_stop``. Returns the
    last predicted beta.
    """
    print("\n \n %s" % T_hints)
    print("hints,\t real,\t pred_full, \t pred_light,")
    beta_pred = None
    for h in hint_counts:
        beta_pred = validation_prediction(N_tests, h, T_hints)
        if beta_pred < beta_stop:
            break
    return beta_pred


beta_pred = _sweep("None", [0])
beta_pred = _sweep("Perfect", range(1, 100))
beta_pred = _sweep("Modular", range(2, 200, 2))
beta_pred = _sweep("Q-Modular", range(1, 100))
beta_pred = _sweep("Approx", range(4, 200, 4))