Fork of the official GitHub repository of the Leaky-LWE-Estimator framework, a Sage toolkit to attack and estimate the hardness of LWE with side information: https://github.com/lducas/leaky-LWE-Estimator

prediction_verifications.sage (5.2 KB)

from multiprocessing import Pool
from map_drop import map_drop
from numpy.random import seed as np_seed
import sys

load("../framework/instance_gen.sage")

Derr = build_centered_binomial_law(6)   # noise distribution used for approximate hints
modulus = 11                            # modulus used for modular hints

saving_results = True
results_filename = "results/results.csv"

try:
    N_tests = int(sys.argv[1])
    threads = int(sys.argv[2])
except (IndexError, ValueError):
    N_tests = 5
    threads = 1
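# The N_tests samples are distributed over `threads` processes through
# map_drop() below (presumably a thin wrapper around the multiprocessing.Pool
# imported above); without command-line arguments the script falls back to a
# quick 5-sample, single-threaded run.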
def v(i):
    return canonical_vec(d, i)


qvec_donttouch = 20


def randv():
    vv = v(randint(qvec_donttouch, d - 1))
    vv -= v(randint(qvec_donttouch, d - 1))
    vv += v(randint(qvec_donttouch, d - 1))
    vv -= v(randint(qvec_donttouch, d - 1))
    vv += v(randint(qvec_donttouch, d - 1))
    return vv


def qrandv():
    vv = randint(1, q - 1) * v(randint(qvec_donttouch, d - 1))
    vv -= randint(1, q - 1) * v(randint(qvec_donttouch, d - 1))
    vv += randint(1, q - 1) * v(randint(qvec_donttouch, d - 1))
    vv -= randint(1, q - 1) * v(randint(qvec_donttouch, d - 1))
    vv += randint(1, q - 1) * v(randint(qvec_donttouch, d - 1))
    return vv
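# randv() returns a +/-1 combination of five canonical basis vectors, and
# qrandv() the analogous combination with random coefficients in [1, q-1];
# both draw their indices from [qvec_donttouch, d-1], presumably so that the
# hint directions avoid the first 20 coordinates on which q-vectors are
# integrated below via integrate_q_vectors(q, indices=range(20)).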
def one_experiment(id, aargs):
    (N_hints, T_hints) = aargs
    mu, variance = average_variance(Derr)
    set_random_seed(id)
    np_seed(seed=id)
    A, b, dbdd = initialize_from_LWE_instance(DBDD, n, q,
                                              m, D_e, D_s,
                                              verbosity=0)
    A, b, dbdd_p = initialize_from_LWE_instance(DBDD_predict,
                                                n, q, m, D_e,
                                                D_s,
                                                verbosity=0)
    for j in range(N_hints):
        vv = randv()
        if T_hints == "Perfect":
            dbdd.integrate_perfect_hint(vv, dbdd.leak(vv), estimate=False)
            dbdd_p.integrate_perfect_hint(vv, dbdd_p.leak(vv), estimate=False)
        if T_hints == "Approx":
            dbdd.integrate_approx_hint(vv, dbdd.leak(vv) +
                                       draw_from_distribution(Derr),
                                       variance, estimate=False)
            dbdd_p.integrate_approx_hint(vv, dbdd_p.leak(vv) +
                                         draw_from_distribution(Derr),
                                         variance, estimate=False)
        if T_hints == "Modular":
            dbdd.integrate_modular_hint(vv, dbdd.leak(vv) % modulus,
                                        modulus, smooth=True, estimate=False)
            dbdd_p.integrate_modular_hint(vv, dbdd_p.leak(vv) % modulus,
                                          modulus, smooth=True, estimate=False)
        if T_hints == "Q-Modular":
            vv = qrandv()
            dbdd.integrate_q_modular_hint(vv, dbdd.leak(vv) % q,
                                          q, estimate=False)
            dbdd_p.integrate_q_modular_hint(vv, dbdd_p.leak(vv) % q,
                                            q, estimate=False)

    dbdd_p.integrate_q_vectors(q, indices=range(20))
    dbdd.integrate_q_vectors(q, indices=range(20))
    beta_pred_light, _ = dbdd_p.estimate_attack(probabilistic=True)
    beta_pred_full, _ = dbdd.estimate_attack(probabilistic=True)
    beta, _ = dbdd.attack()
    return (beta, beta_pred_full, beta_pred_light)
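# Each experiment returns three blocksizes for the same seeded instance:
#   beta            -- blocksize at which the actual attack (dbdd.attack())
#                      succeeded on the full DBDD embedding,
#   beta_pred_full  -- probabilistic prediction computed on the full DBDD
#                      instance after integrating the hints,
#   beta_pred_light -- the same prediction computed on the lighter
#                      DBDD_predict instance.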
dic = {" ": None}  # defined but not used below


def get_stats(data, N_tests):
    avg = RR(sum(data)) / N_tests
    var = abs(RR(sum([r**2 for r in data])) / N_tests - avg**2)
    return (avg, var)


def save_results(*args):
    with open(results_filename, 'a') as _file:
        _file.write(';'.join([str(arg) for arg in args]) + '\n')


def validation_prediction(N_tests, N_hints, T_hints):
    # Run N_tests independent experiments in parallel and average the results
    import datetime
    ttt = datetime.datetime.now()
    res = map_drop(N_tests, threads, one_experiment, (N_hints, T_hints))
    beta_real, vbeta_real = get_stats([r[0] for r in res], N_tests)
    beta_pred_full, vbeta_pred_full = get_stats([r[1] for r in res], N_tests)
    beta_pred_light, vbeta_pred_light = get_stats([r[2] for r in res], N_tests)
    print("%d,\t %.3f,\t %.3f,\t %.3f \t\t" %
          (N_hints, beta_real, beta_pred_full, beta_pred_light), end=" \t")
    print("Time:", datetime.datetime.now() - ttt)
    if saving_results:
        save_results(
            T_hints, N_hints, N_tests, datetime.datetime.now() - ttt,
            beta_real, beta_pred_full, beta_pred_light,
            vbeta_real, vbeta_pred_full, vbeta_pred_light,
        )
    return beta_pred_full
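# When saving_results is set, each call appends one semicolon-separated row
# to results/results.csv (no header row), in the order:
#   hint type; N_hints; N_tests; wall-clock time;
#   mean beta (real, pred_full, pred_light);
#   variance of beta (real, pred_full, pred_light).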
logging("Number of threads : %d" % threads, style="DATA")
logging("Number of Samples : %d" % N_tests, style="DATA")
logging(" Validation tests ", style="HEADER")

# LWE parameters of the validation instances
n = 70
m = n
q = 3301
D_s = build_centered_binomial_law(40)
D_e = build_centered_binomial_law(40)
d = m + n

# Baseline: no hint at all
print("\n \n None")
print("hints,\t real,\t pred_full, \t pred_light,")
beta_pred = validation_prediction(N_tests, 0, "None")

# Sweep over the number of hints for each hint type, stopping once the
# predicted blocksize drops below 3
for T_hints in ["Perfect", "Modular", "Q-Modular", "Approx"]:
    hint_range = None
    if T_hints == "Perfect":
        hint_range = range(1, 100)
    elif T_hints == "Modular":
        hint_range = range(2, 200, 2)
    elif T_hints == "Q-Modular":
        hint_range = range(1, 100)
    elif T_hints == "Approx":
        hint_range = range(4, 200, 4)

    print(f"\n \n {T_hints}")
    print("hints,\t real,\t pred_full, \t pred_light,")
    for h in hint_range:
        beta_pred = validation_prediction(N_tests, h, T_hints)
        if beta_pred < 3:
            break
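The script reads its two optional arguments from sys.argv, so it can presumably be launched as `sage prediction_verifications.sage <N_tests> <threads>`, falling back to 5 samples on 1 thread otherwise. Below is a minimal, hypothetical sketch (plain Python, standard library only) of how the rows appended to results/results.csv could be loaded afterwards, assuming the semicolon-separated column order documented next to validation_prediction above:

import csv

# Column order as written by save_results() in prediction_verifications.sage
COLUMNS = ["hint_type", "N_hints", "N_tests", "walltime",
           "beta_real", "beta_pred_full", "beta_pred_light",
           "var_real", "var_pred_full", "var_pred_light"]

with open("results/results.csv", newline="") as f:
    rows = [dict(zip(COLUMNS, r)) for r in csv.reader(f, delimiter=";")]

# Print the gap between the full prediction and the measured blocksize
for row in rows:
    gap = float(row["beta_pred_full"]) - float(row["beta_real"])
    print("%-10s hints=%-4s real=%-8s pred_full=%-8s (gap %+.2f)"
          % (row["hint_type"], row["N_hints"],
             row["beta_real"], row["beta_pred_full"], gap))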