Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _g(self, word, key_expand_round):
    """
    AES key-schedule ``g`` step applied to a single word.

    The word is split into 8-bit chunks and rotated by one byte (chunk 3
    moves to position 0). Each chunk is passed through ``self.sbox``, and the
    round constant is XORed into result chunk 3. Under PyRTL's
    least-significant-first chunk ordering, chunk 3 is presumably the top
    byte; confirm against ``libutils.partition_wire``.

    :param word: wire to transform; it is split into 8-bit chunks
    :param key_expand_round: either a Python number, in which case the
        round constant comes from ``self._rcon_data``, or a wire, in which
        case it is read from ``self.rcon``. Both are indexed with
        ``key_expand_round + 1``.
    :return: the four resulting chunks concatenated into one wire
    """
    import numbers
    self._build_memories_if_not_exists()
    word_bytes = libutils.partition_wire(word, 8)
    substituted = []
    for source_index in (3, 0, 1, 2):  # one-byte circular rotation
        substituted.append(self.sbox[word_bytes[source_index]])
    if not isinstance(key_expand_round, numbers.Number):
        # Round chosen at run time: read the constant from the rcon memory.
        round_constant = self.rcon[key_expand_round + 1]
    else:
        # Round fixed at elaboration time: plain precomputed value.
        round_constant = self._rcon_data[key_expand_round + 1]  # int value
    top_byte = substituted[3]
    substituted[3] = top_byte ^ round_constant
    return pyrtl.concat_list(substituted)
def _mix_col_subgroup(self, in_vector, gm_multipliers):
    """
    Mix a single column of the state.

    For each output byte ``i``, a product
    ``self._galois_mult(byte[(i + k) % 4], gm_multipliers[k])`` is built for
    every entry ``k`` of ``gm_multipliers``. The first four products are then
    XORed together.

    :param in_vector: wire holding one column; it is split into 8-bit bytes.
        NOTE(review): the ``% 4`` wrap assumes exactly four bytes (32 bits).
    :param gm_multipliers: multiplier selectors, each passed to
        ``self._galois_mult`` as its ``mult_table`` argument. At least four
        are required.
    :return: one mixed byte per input byte, concatenated in input chunk order
    """
    column_bytes = libutils.partition_wire(in_vector, 8)
    mixed_bytes = []
    for out_index in range(len(column_bytes)):
        products = [self._galois_mult(column_bytes[(out_index + offset) % 4], multiplier)
                    for offset, multiplier in enumerate(gm_multipliers)]
        mixed_bytes.append(products[0] ^ products[1] ^ products[2] ^ products[3])
    return pyrtl.concat_list(mixed_bytes)
def _sub_bytes(self, in_vector, inverse=False):
    """
    Apply the AES SubBytes step, or InvSubBytes, one byte at a time.

    :param in_vector: wire split into 8-bit chunks; each chunk is looked up
        independently
    :param inverse: when truthy, bytes are looked up in ``self.inv_sbox``
        instead of ``self.sbox``
    :return: the substituted bytes concatenated back together, in input
        chunk order
    """
    self._build_memories_if_not_exists()
    substituted = []
    for byte_wire in libutils.partition_wire(in_vector, 8):
        lookup = self.inv_sbox if inverse else self.sbox
        substituted.append(lookup[byte_wire])
    return pyrtl.concat_list(substituted)
def _key_expansion(self, old_key, key_expand_round):
    """
    Derive the next round key from ``old_key``.

    ``old_key`` is split into 32-bit words; words 0 through 3 are used, so it
    is presumably 128 bits. The first new word is ``w[3] ^ g(w[0])``. Each
    following new word XORs the previous new word with ``w[2]``, ``w[1]`` and
    ``w[0]`` in turn.

    :param old_key: wire holding the current round key
    :param key_expand_round: round index forwarded to ``self._g``
    :return: the new words concatenated, with the first-computed word in
        chunk position 3
    """
    self._build_memories_if_not_exists()
    old_words = libutils.partition_wire(old_key, 32)
    running = old_words[3] ^ self._g(old_words[0], key_expand_round)
    new_words = [running]
    for prev_word in (old_words[2], old_words[1], old_words[0]):
        running = running ^ prev_word
        new_words.append(running)
    # Reverse so the first word derived ends up in the highest chunk position.
    return pyrtl.concat_list(new_words[::-1])
def _mix_columns(self, in_vector, inverse=False):
    """
    Apply AES MixColumns, or InvMixColumns, to each 32-bit column.

    :param in_vector: wire split into 32-bit columns
    :param inverse: when truthy, use the inverse multipliers
        ``[14, 9, 13, 11]``; otherwise use the forward multipliers
        ``[2, 1, 1, 3]``
    :return: the mixed columns concatenated in their original order
    """
    self._build_memories_if_not_exists()
    if inverse:
        multipliers = [14, 9, 13, 11]
    else:
        multipliers = [2, 1, 1, 3]
    mixed = []
    for column in libutils.partition_wire(in_vector, 32):
        mixed.append(self._mix_col_subgroup(column, multipliers))
    return pyrtl.concat_list(mixed)
# NOTE(review): the `def` line for this body is not visible in this view. It
# reads `wire_array_2` and `adder` as inputs, presumably its parameters; confirm.
# Sums a "sparse" column array: wire_array_2[i] is a sequence of wires for
# column i, and only its first two entries are ever read. Leading columns that
# hold a single wire are passed straight through. From the first two-wire
# column onward, the remaining columns are packed into two operands and summed
# with `adder`.
bitwidth = len(wire_array_2)
add_wires = [], []  # the two adder operands, collected as per-column bit lists
result = []  # low-order pass-through bits that need no addition
for single_w_index in range(bitwidth):
if len(wire_array_2[single_w_index]) == 2: # Check if the two wire vectors overlap yet
break
# NOTE(review): a column with zero wires here would raise IndexError.
result.append(wire_array_2[single_w_index][0])
# NOTE(review): if no column has two wires, the loop above ends without `break`.
# single_w_index is then bitwidth - 1, so the last column's wire lands both in
# `result` and in the adder operands below. If bitwidth == 0, single_w_index is
# never bound and the next line raises NameError. Confirm callers exclude both.
for w_loc in range(single_w_index, bitwidth):
for i in range(2):
if len(wire_array_2[w_loc]) >= i + 1:
add_wires[i].append(wire_array_2[w_loc][i])
else:
add_wires[i].append(pyrtl.Const(0))  # pad a missing operand bit with zero
adder_result = adder(pyrtl.concat_list(add_wires[0]), pyrtl.concat_list(add_wires[1]))
# pyrtl.concat takes its arguments most-significant first. The adder output
# forms the high part, and result[0] ends up as the lowest bit.
return pyrtl.concat(adder_result, *reversed(result))
def _inv_shift_rows(in_vector):
    """
    Apply the AES InvShiftRows byte permutation.

    ``in_vector`` is split into 8-bit chunks. Output chunk ``j``, for j = 0..15
    in partition/concat order, is taken from input chunk ``(12 - 3 * j) % 16``.
    That gives the source order 12, 9, 6, 3, 0, 13, 10, 7, 4, 1, 14, 11, 8, 5,
    2, 15. Only the first 16 chunks are read, and fewer than 16 chunks raises
    IndexError.

    :param in_vector: wire holding the cipher state (presumably 128 bits;
        confirm)
    :return: the permuted bytes concatenated into one wire
    """
    state_bytes = libutils.partition_wire(in_vector, 8)
    permuted = tuple(state_bytes[(12 - 3 * out_pos) % 16] for out_pos in range(16))
    return pyrtl.concat_list(permuted)