Fork of the official GitHub repository of the Leaky-LWE-Estimator framework, a Sage toolkit to attack LWE and estimate its hardness in the presence of side information: https://github.com/lducas/leaky-LWE-Estimator
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

geometry.sage 10KB

  1. from numpy.linalg import inv as np_inv
  2. # from numpy.linalg import slogdet as np_slogdet
  3. from numpy import array
  4. import numpy as np
  5. def dual_basis(B):
  6. """
  7. Compute the dual basis of B
  8. """
  9. return B.pseudoinverse().transpose()
  10. def projection_matrix(A):
  11. """
  12. Construct the projection matrix orthogonally to Span(V)
  13. """
  14. S = A * A.T
  15. return A.T * S.inverse() * A
  16. def project_against(v, X):
  17. """ Project matrix X orthonally to vector v"""
  18. # Pv = projection_matrix(v)
  19. # return X - X * Pv
  20. Z = (X * v.T) * v / scal(v * v.T)
  21. return X - Z
# def make_primitive(B, v):
#     assert False
#     # Project and scale each v in V so that it lies
#     # in the lattice and is primitive in it.
#     # Note: does not make V primitive as a set of vectors!
#     # (e.g. linear dependencies are not eliminated)
#     PB = projection_matrix(B)
#     DT = dual_basis(B).T
#     v = vec(v) * PB
#     w = v * DT
#     den = lcm([x.denominator() for x in w[0]])
#     num = gcd([x for x in w[0] * den])
#     if num==0:
#         return None
#     v *= den/num
#     return v
  38. def vol(B):
  39. return sqrt(det(B * B.T))
  40. def project_and_eliminate_dep(B, W):
  41. # Project v on Span(B)
  42. PB = projection_matrix(B)
  43. V = W * PB
  44. rank_loss = V.nrows() - V.rank()
  45. if rank_loss > 0:
  46. print("WARNING: their were %d linear dependencies out of %d " %
  47. (rank_loss, V.nrows()))
  48. V = V.LLL()
  49. V = V[rank_loss:]
  50. return V
  51. def is_cannonical_direction(v):
  52. v = vec(v)
  53. return sum([x != 0 for x in v[0]]) == 1
  54. def cannonical_param(v):
  55. v = vec(v)
  56. assert is_cannonical_direction(v)
  57. i = [x != 0 for x in v[0]].index(True)
  58. return i, v[0, i]
  59. def xgcd_of_list(a):
  60. (g, s, t) = xgcd(a[0], a[1])
  61. bezouts = [s, t]
  62. for v in a[2:]:
  63. (g, s_, t_) = xgcd(g, v)
  64. bezouts = [b*s_ for b in bezouts]
  65. bezouts.append(t_)
  66. return (g, bezouts)
def remove_linear_dependencies(B, dim=None):
    """
    Remove the linear dependencies among the rows of B.

    INPUT:
      - B: matrix whose rows are (possibly dependent) lattice generators.
      - dim: expected rank of B, if known. When given and B has at most
        `dim` rows, B is returned unchanged. When None, the number of
        dependencies r is the dimension of the left kernel of B.

    OUTPUT: B.LLL() with its first r rows dropped (r = kernel dimension, or
    nrows - dim). When dim is None, the LLL reduction is applied even if
    r == 0.

    NOTE(review): slicing off the first r rows assumes that LLL places the
    zero rows produced by the dependencies first. Confirm against Sage's
    LLL documentation.
    """
    nrows = B.nrows()
    if dim is None or nrows > dim:
        # Determine the number of dependencies
        K, r = None, None
        if dim is None:
            K = B.left_kernel().basis_matrix() # I assume that the cost of "left_kernel" is negligible compared to "LLL"
            r = K.dimensions()[0]
        else:
            r = nrows-dim
        # NOTE(review): this branch is disabled by `and False` and never runs.
        # If re-enabled, it modifies the rows of B in place (B[i] += ...).
        # Also, if no admissible pivot is found, pivot_value stays None and
        # the comparison `pivot_value < 0` raises TypeError.
        if r == 1 and False:
            print("Use Better Algo")
            # Find a linear dependency
            if K is None:
                K = B.left_kernel().basis_matrix()
            assert K.dimensions()[0] == 1
            combinaison = K[0]
            # Scale the kernel vector to clear denominators
            combinaison *= lcm([v.denominator() for v in combinaison])
            print(combinaison)
            # Prefer a pivot whose coefficient is +-1
            pivot, pivot_value = None, None
            for ind, value in enumerate(combinaison):
                if abs(value) == 1:
                    pivot, pivot_value = ind, value
                    break
            if pivot_value is None:
                print('Complex case')
                # Otherwise take the smallest nonzero coefficient whose
                # complement has gcd 1
                for ind, value in enumerate(combinaison):
                    if abs(value) > 0 and gcd([v for i,v in enumerate(combinaison) if i!=ind]) == 1:
                        if pivot is None or abs(value)<abs(pivot_value):
                            pivot, pivot_value = ind, value
            if pivot_value < 0:
                combinaison = vector(ZZ, [-v for v in combinaison])
            _, bezouts = xgcd_of_list([v for i,v in enumerate(combinaison) if i!=pivot])
            # Unimodular row operations that bring the pivot coefficient to 1
            factor = combinaison[pivot]-1
            for i in range(len(combinaison)):
                ind = i if i < pivot else i-1
                if i != pivot:
                    B[i] += factor*bezouts[ind]*B[pivot]
            combinaison[pivot] += -factor
            assert (combinaison*B).is_zero(), 'It is not a linear dependency anymore !'
            assert abs(combinaison[pivot]) == 1, f'Error abs(combinaison[pivot]) == {abs(combinaison[pivot])} != 1'
            # The pivot row is now an integer combination of the others: drop it
            B = B[[i for i in range(B.dimensions()[0]) if i != pivot]]
        else:
            B = B.LLL()
            B = B[r:]
    return B
  113. def lattice_orthogonal_section(D, V, maintains_basis=True):
  114. """
  115. Compute the intersection of the lattice L(B)
  116. with the hyperplane orthogonal to Span(V).
  117. (V can be either a vector or a matrix)
  118. INPUT AND OUTPUT DUAL BASIS
  119. Algorithm:
  120. - project V onto Span(B)
  121. - project the dual basis onto orth(V)
  122. - eliminate linear dependencies (LLL)
  123. - go back to the primal.
  124. """
  125. #V = project_and_eliminate_dep(D, V) ## No need because D is full-rank
  126. #r = V.nrows()
  127. # Project the dual basis orthogonally to v
  128. PV = projection_matrix(V)
  129. D = D - D * PV
  130. # Eliminate linear dependencies
  131. if maintains_basis:
  132. D = remove_linear_dependencies(D)
  133. # Go back to the primal
  134. return D
  135. def lattice_project_against(B, V, maintains_basis=True):
  136. """
  137. Compute the projection of the lattice L(B) orthogonally to Span(V). All vectors if V
  138. (or at least their projection on Span(B)) must belong to L(B).
  139. Algorithm:
  140. - project V onto Span(B)
  141. - project the basis onto orth(V)
  142. - eliminate linear dependencies (LLL)
  143. """
  144. # Project v on Span(B)
  145. #V = project_and_eliminate_dep(B, V) ## No need because D is full-rank
  146. #r = V.nrows() # Useless
  147. # Check that V belogs to L(B)
  148. D = dual_basis(B)
  149. M = D * V.T
  150. if not lcm([x.denominator() for x in M.list()]) == 1:
  151. raise ValueError("Not in the lattice")
  152. # Project the basis orthogonally to v
  153. PV = projection_matrix(V)
  154. B = B - B * PV
  155. # Eliminate linear dependencies
  156. if maintains_basis:
  157. B = remove_linear_dependencies(B)
  158. # Go back to the primal
  159. return B
  160. def lattice_modular_intersection(D, V, k, maintains_basis=True):
  161. """
  162. Compute the intersection of the lattice L(B) with
  163. the lattice {x | x*V = 0 mod k}
  164. (V can be either a vector or a matrix)
  165. Algorithm:
  166. - project V onto Span(B)
  167. - append the equations in the dual
  168. - eliminate linear dependencies (LLL)
  169. - go back to the primal.
  170. """
  171. # Project v on Span(B)
  172. #V = project_and_eliminate_dep(D, V) ## No need because D is full-rank
  173. #r = V.nrows() # Useless
  174. # append the equation in the dual
  175. V /= k
  176. D = D.stack(V)
  177. # Eliminate linear dependencies
  178. if maintains_basis:
  179. D = remove_linear_dependencies(D)
  180. # Go back to the primal
  181. return D
  182. def is_diagonal(M):
  183. if M.nrows() != M.ncols():
  184. return False
  185. A = M.numpy()
  186. return np.all(A == np.diag(np.diagonal(A)))
  187. def logdet(M, exact=False):
  188. """
  189. Compute the log of the determinant of a large rational matrix,
  190. tryping to avoid overflows.
  191. """
  192. if not exact:
  193. MM = array(M, dtype=float)
  194. _, l = slogdet(MM)
  195. return l
  196. a = abs(M.det())
  197. l = 0
  198. while a > 2**32:
  199. l += RR(32 * ln(2))
  200. a /= 2**32
  201. l += ln(RR(a))
  202. return l
  203. def degen_inverse(S, B=None):
  204. """ Compute the inverse of a symmetric matrix restricted
  205. to its span
  206. """
  207. # Get an orthogonal basis for the Span of B
  208. if B is None:
  209. # Get an orthogonal basis for the Span of B
  210. V = S.echelon_form()
  211. V = V[:V.rank()]
  212. P = projection_matrix(V)
  213. else:
  214. P = projection_matrix(B)
  215. # make S non-degenerated by adding the complement of span(B)
  216. C = identity_matrix(S.ncols()) - P
  217. Sinv = (S + C).inverse() - C
  218. assert S * Sinv == P, "Consistency failed (probably not your fault)."
  219. assert P * Sinv == Sinv, "Consistency failed (probably not your fault)."
  220. return Sinv
  221. def degen_logdet(S, B=None):
  222. """ Compute the determinant of a symmetric matrix
  223. sigma (m x m) restricted to the span of the full-rank
  224. rectangular (k x m, k <= m) matrix V
  225. """
  226. # Get an orthogonal basis for the Span of B
  227. if B is None:
  228. # Get an orthogonal basis for the Span of B
  229. V = S.echelon_form()
  230. V = V[:V.rank()]
  231. P = projection_matrix(V)
  232. else:
  233. P = projection_matrix(B)
  234. # Check that S is indeed supported by span(B)
  235. #assert S == P.T * S * P
  236. assert (S - P.T * S * P).norm() <= 1e-10
  237. # make S non-degenerated by adding the complement of span(B)
  238. C = identity_matrix(S.ncols()) - P
  239. l3 = logdet(S + C)
  240. return l3
  241. def square_root_inverse_degen(S, B=None):
  242. """ Compute the determinant of a symmetric matrix
  243. sigma (m x m) restricted to the span of the full-rank
  244. rectangular (k x m, k <= m) matrix V
  245. """
  246. if B is None:
  247. # Get an orthogonal basis for the Span of B
  248. V = S.echelon_form()
  249. V = V[:V.rank()]
  250. P = projection_matrix(V)
  251. else:
  252. P = projection_matrix(B)
  253. # make S non-degenerated by adding the complement of span(B)
  254. C = identity_matrix(S.ncols()) - P
  255. S_inv = np_inv(array((S + C), dtype=float))
  256. S_inv = array(S_inv, dtype=float)
  257. L_inv = cholesky(S_inv)
  258. L_inv = round_matrix_to_rational(L_inv)
  259. L = L_inv.inverse()
  260. return L, L_inv
  261. def build_standard_substitution_matrix(V, pivot=None, data=None, output_data=False):
  262. _, pivot = V.nonzero_positions()[0] if (pivot is None) else (None, pivot)
  263. assert V[0,pivot] != 0, 'The value of the pivot must be non-zero.'
  264. dim = V.ncols()
  265. V1 = - V[0,:pivot] / V[0,pivot]
  266. V2 = - V[0,pivot+1:] / V[0,pivot]
  267. Gamma = zero_matrix(QQ, dim,dim-1)
  268. Gamma[:pivot,:pivot] = identity_matrix(pivot)
  269. Gamma[pivot,:pivot] = V1
  270. Gamma[pivot,pivot:] = V2
  271. Gamma[pivot+1:,pivot:] = identity_matrix(dim-pivot-1)
  272. data = data if not output_data else {}
  273. if data is not None:
  274. norm2 = 1 + scal(V1*V1.T) + scal(V2*V2.T)
  275. normalization_matrix = zero_matrix(QQ, dim-1,dim-1)
  276. normalization_matrix[:pivot,:pivot] = identity_matrix(pivot) - V1.T*V1 / norm2
  277. normalization_matrix[pivot:,pivot:] = identity_matrix(dim-pivot-1) - V2.T*V2 / norm2
  278. normalization_matrix[:pivot,pivot:] = - V1.T*V2 / norm2
  279. normalization_matrix[pivot:,:pivot] = - V2.T*V1 / norm2
  280. data['det'] = norm2
  281. data['normalization_matrix'] = normalization_matrix
  282. if output_data:
  283. return Gamma, data
  284. else:
  285. return Gamma