load("../framework/instance_gen.sage")
# Demonstration mode: one detailed computation per parameter set, without
# averaging over many instances. Set `demo = False` to run the averaging mode
# (used to produce the paper's data), which is quite long.
demo = True
if demo:
    logging("--- Demonstration mode (no averaging) ---")


# Frodo parameter sets: name -> (n, q, frodo_distribution table).
# Every set uses m = n. Each table is turned into the secret/error
# distribution by get_distribution_from_table with precision 2 ** 16.
FRODO_PARAMETERS = {
    'CCS1': (352, 2 ** 11, [22528, 15616, 5120, 768]),
    'CCS2': (592, 2 ** 12, [25120, 15840, 3968, 384, 16]),
    'CCS3': (752, 2 ** 15, [19296, 14704, 6496, 1664, 240, 16]),
    'CCS4': (864, 2 ** 15, [19304, 14700, 6490, 1659, 245, 21, 1]),
    # NIST1 FRODOKEM-640
    'NIST1': (640, 2 ** 15, [9288, 8720, 7216, 5264, 3384,
                             1918, 958, 422, 164, 56, 17, 4, 1]),
    # NIST2 FRODOKEM-976
    'NIST2': (976, 65536, [11278, 10277, 7774, 4882, 2545, 1101,
                           396, 118, 29, 6, 1]),
}

for params in ['CCS1', 'CCS2', 'CCS3', 'CCS4', 'NIST1', 'NIST2']:
    logging("Set of parameters: " + params)

    n, q, frodo_distribution = FRODO_PARAMETERS[params]
    m = n
    D_s = get_distribution_from_table(frodo_distribution, 2 ** 16)
    # Pre-computed side-channel data for this parameter set; the loaded files
    # presumably define Dguess, center_aposteriori, variance_aposteriori and
    # proba_best_guess_correct used below -- NOTE(review): confirm.
    load("Frodo_Single_data/simulation_distribution_" + params + ".sage")
    load("Frodo_Single_data/aposteriori_distribution_" + params + ".sage")

    # Original security (no side-channel hints)
    A, b, dbdd = initialize_from_LWE_instance(DBDD_predict_diag, n, q, m,
                                              D_s, D_s, verbosity=0)
    dbdd.integrate_q_vectors(q, indices=range(n, n + m))
    (beta, _) = dbdd.estimate_attack()
    logging("Attack without hints:  %3.2f bikz" % beta, style="HEADER")

    """  Refined Side channel attack  """

    def simu_measured(secret):
        """
        Simulate the output of the Bos et al. attack on one secret coefficient.

        The measurement is drawn from Dguess, a distribution obtained with a
        large amount of data from the Bos et al. suite (in Matlab).
        :secret: an integer, the secret value (recentered before lookup)
        :returns: an integer representing the output of the Bos et al. attack
        """
        guess_distribution = renormalize_dist(Dguess[recenter(secret)])
        return draw_from_distribution(guess_distribution)


    def ordered_indices(sorted_guesses, measured):
        """
        List secret-coefficient indices for the bruteforce (hybrid) attack.

        Indices are grouped by measurement value, following the order of
        `sorted_guesses` (best guesses first). Within a group they appear in
        increasing order. Indices whose measurement does not occur in
        `sorted_guesses` are omitted.
        :sorted_guesses: measurement values in decreasing order of likelihood
        :measured: the measurement for each coefficient
        :returns: the coefficient indices ordered according to
            Probability[secret[i] = measured[i]]
        """
        return [idx
                for guess in sorted_guesses
                for idx, value in enumerate(measured)
                if value == guess]


    def estimate_SCA(report_every, dbdd, measured, max_guesses):
        """
        Evaluate the security loss after the Bos et al. attack.

        Phase 1 integrates one hint per secret coefficient, based on its
        (simulated) measurement:
        - an approximate a-posteriori hint when the variance is non-zero;
        - a perfect hint when the variance is zero;
        - nothing when the variance is None.
        Phase 2 (hybrid attack) then guesses coefficients in decreasing order
        of guess likelihood and integrates at most `max_guesses` of them as
        perfect hints.

        :report_every: an integer giving the frequency of logging
            (None for no logging)
        :dbdd: instance of the class DBDD
        :measured: table of the (simulated) measurements given by the
            Bos et al. attack, one per secret coefficient
        :max_guesses: upper bound on the number of guesses
        :returns: (beta, beta_hybrid, proba_success). beta is the estimate
            after phase 1, beta_hybrid is dbdd.beta after the guesses, and
            proba_success is the product of the probabilities that each
            integrated guess is correct.
        """
        Id = identity_matrix(n + m)
        logging_on = report_every is not None

        # Phase 1: side-channel hints, one per secret coefficient.
        for i in range(n):
            v = vec(Id[i])
            if logging_on and ((i % report_every == 0) or (i == n - 1)):
                verbose = 2
            else:
                verbose = 0
            dbdd.verbosity = verbose
            if verbose == 2:
                logging("[...%d]" % report_every, newline=False)

            variance = variance_aposteriori[measured[i]]
            if variance is None:
                continue
            if variance != 0:
                dbdd.integrate_approx_hint(v,
                                           center_aposteriori[measured[i]],
                                           variance,
                                           aposteriori=True, estimate=verbose)
            else:
                # Zero a-posteriori variance: integrate as a perfect hint.
                dbdd.integrate_perfect_hint(v, center_aposteriori[measured[i]],
                                            estimate=verbose)

        if logging_on:
            dbdd.integrate_q_vectors(q, indices=range(n, n + m),
                                     report_every=report_every)
        else:
            dbdd.integrate_q_vectors(q, indices=range(n, n + m))
        (beta, attack_aux) = dbdd.estimate_attack()

        # Phase 2: hybrid attack (guessing).
        if logging_on:
            logging("     Hybrid attack estimation     ", style="HEADER")

        # Measurement values by decreasing probability that the best guess is
        # correct. Values whose guess is already certain (probability 1.) are
        # dropped.
        sorted_guesses = [meas for meas, proba in
                          sorted(proba_best_guess_correct.items(),
                                 key=lambda kv: - kv[1])
                          if proba != 1.]
        proba_success = 1.
        dbdd.verbosity = 0
        guesses = 0
        for j, i in enumerate(ordered_indices(sorted_guesses, measured), 1):
            # Strict comparison so that at most max_guesses hints are guessed;
            # once the budget is spent no further work can happen.
            if guesses >= max_guesses:
                break
            v = vec(Id[i])
            # NOTE(review): the hint value passed here is the second item
            # returned by dbdd.estimate_attack(), not the guessed coefficient.
            # Presumably the prediction-only DBDD class ignores the value;
            # confirm before reusing with a DBDD class that uses it.
            if dbdd.integrate_perfect_hint(v, attack_aux):
                guesses += 1
                proba_success *= proba_best_guess_correct[measured[i]]
            if logging_on and (j % report_every == 0):
                logging("[...%d]" % report_every, newline=False)
                dbdd.integrate_q_vectors(q, indices=range(n, n + m))
                logging("dim=%3d \t delta=%.6f \t beta=%3.2f \t guesses=%4d" %
                        (dbdd.dim(), dbdd.delta, dbdd.beta, guesses),
                        style="VALUE", newline=False)
                logging("Proba success = %s" % proba_success, style="VALUE",
                        newline=True)
        # NOTE(review): when logging is off, q-vectors are not re-integrated
        # after the guesses (the logging branch does so periodically); confirm
        # dbdd.beta is up to date in that mode.
        return beta, dbdd.beta, proba_success

    if demo:
        # One detailed, verbose run.
        A, b, dbdd = initialize_from_LWE_instance(DBDD_predict_diag,
                                                  n,
                                                  q, m, D_s, D_s, verbosity=2)
        measured = [simu_measured(dbdd.u[0, i]) for i in range(n)]
        estimate_SCA(50, dbdd, measured, 200)
    else:
        # Averaging: average the measures over many random instances to get
        # accurate data for the paper. The averaging mode is quite long.
        nb_tests_per_params = 100

        # Chosen values for the number of guesses. A dict lookup fails loudly
        # for an unlisted parameter set instead of silently reusing the
        # previous iteration's value.
        max_guesses = {
            'CCS1': 100,
            'CCS2': 350,
            'CCS3': 300,
            'CCS4': 150,
            'NIST1': 50,
            'NIST2': 100,
        }[params]

        beta = 0
        beta_hybrid = 0
        proba_success = 0
        for _trial in range(nb_tests_per_params):
            A, b, dbdd = initialize_from_LWE_instance(DBDD_predict_diag,
                                                      n,
                                                      q, m, D_s, D_s, verbosity=0)
            measured = [simu_measured(dbdd.u[0, i]) for i in range(n)]
            # Distinct names so the LWE vector `b` is not overwritten by a
            # per-run estimate.
            run_beta, run_beta_hybrid, run_proba = estimate_SCA(
                None, dbdd, measured, max_guesses)
            beta += run_beta
            beta_hybrid += run_beta_hybrid
            proba_success += run_proba

        beta /= nb_tests_per_params
        beta_hybrid /= nb_tests_per_params
        proba_success /= nb_tests_per_params
        logging("Attack with hints: %3.2f bikz" % beta, style="HEADER")
        logging("Attack with hints and guess: %3.2f bikz" % beta_hybrid, style="HEADER")
        logging("Number of guesses: %4d" % max_guesses, style="HEADER")
        logging("Success probability: %3.2f" % proba_success, style="HEADER")