Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
validate_private_function(code, sig)
# Get nonreentrant lock
nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(sig, context.global_ctx)
# Create callback_ptr, this stores a destination in the bytecode for a private
# function to jump to after a function has executed.
clampers = []
# Allocate variable space.
context.memory_allocator.increase_memory(sig.max_copy_size)
_post_callback_ptr = f"{sig.name}_{sig.method_id}_post_callback_ptr"
context.callback_ptr = context.new_placeholder(typ=BaseType('uint256'))
clampers.append(
LLLnode.from_list(
['mstore', context.callback_ptr, 'pass'],
annotation='pop callback pointer',
)
)
if sig.total_default_args > 0:
clampers.append(['label', _post_callback_ptr])
# private functions without return types need to jump back to
# the calling function, as there is no return statement to handle the
# jump.
if sig.output_type is None:
stop_func = [['jump', ['mload', context.callback_ptr]]]
else:
stop_func = [['stop']]
# Generate copiers
def get_sig_statements(sig, pos):
    """Build the two dispatch pieces for a single function signature.

    Returns a ``(sig_compare, private_label)`` pair:

    * private signature: ``sig_compare`` is the constant ``0`` and
      ``private_label`` is an LLL ``label`` node named ``priv_<method_id>``;
    * otherwise: ``sig_compare`` is an ``eq`` expression comparing the word
      at memory offset 0 with the method id, and ``private_label`` is
      ``['pass']``.

    Both LLL nodes are tagged with ``pos`` and annotated with ``sig.sig``.
    """
    sig_annotation = '%s' % sig.sig
    # Built for every signature (public and private alike), although only
    # the public branch below uses it.
    method_id_node = LLLnode.from_list(sig.method_id, pos=pos, annotation=sig_annotation)

    if not sig.private:
        return ['eq', ['mload', 0], method_id_node], ['pass']

    # A constant-zero condition: when used as an `if` condition the body is
    # never entered by comparison, only by jumping to this label.
    label_node = LLLnode.from_list(
        ['label', 'priv_{}'.format(sig.method_id)],
        pos=pos,
        annotation=sig_annotation,
    )
    return 0, label_node
def attribute(self):
# x.balance: balance of address x
if self.expr.attr == 'balance':
addr = Expr.parse_value_expr(self.expr.value, self.context)
if not is_base_type(addr.typ, 'address'):
raise TypeMismatchException(
"Type mismatch: balance keyword expects an address as input",
self.expr
)
return LLLnode.from_list(
['balance', addr],
typ=BaseType('uint256', {'wei': 1}),
location=None,
pos=getpos(self.expr),
)
# x.codesize: codesize of address x
elif self.expr.attr == 'codesize' or self.expr.attr == 'is_contract':
addr = Expr.parse_value_expr(self.expr.value, self.context)
if not is_base_type(addr.typ, 'address'):
raise TypeMismatchException(
"Type mismatch: codesize keyword expects an address as input",
self.expr,
)
if self.expr.attr == 'codesize':
eval_code = ['extcodesize', addr]
output_type = 'int128'
else:
if not isinstance(key, int):
raise TypeMismatchException(
"Expecting a static index; cannot access element %r" % key, pos
)
attrs = list(range(len(typ.members)))
index = key
annotation = None
if location == 'storage':
return LLLnode.from_list(
['add', ['sha3_32', parent], LLLnode.from_list(index, annotation=annotation)],
typ=subtype,
location='storage',
)
elif location == 'storage_prehashed':
return LLLnode.from_list(
['add', parent, LLLnode.from_list(index, annotation=annotation)],
typ=subtype,
location='storage',
)
elif location in ('calldata', 'memory'):
offset = 0
for i in range(index):
offset += 32 * get_size_of_type(typ.members[attrs[i]])
return LLLnode.from_list(['add', offset, parent],
typ=typ.members[key],
location=location,
annotation=annotation)
else:
raise TypeMismatchException("Not expecting a member variable access", pos)
elif isinstance(typ, MappingType):
return LLLnode.from_list(['with', '_L', left, ['seq'] + subs], typ=None)
# If the right side is a null
# CC 20190619 probably not needed as of #1106
elif isinstance(right.typ, NullType):
subs = []
for i in range(left.typ.count):
subs.append(make_setter(add_variable_offset(
left_token,
LLLnode.from_list(i, typ='int128'),
pos=pos,
array_bounds_check=False,
), LLLnode.from_list(None, typ=NullType()), location, pos=pos))
return LLLnode.from_list(['with', '_L', left, ['seq'] + subs], typ=None)
# If the right side is a variable
else:
right_token = LLLnode.from_list('_R', typ=right.typ, location=right.location)
subs = []
for i in range(left.typ.count):
subs.append(make_setter(add_variable_offset(
left_token,
LLLnode.from_list(i, typ='int128'),
pos=pos,
array_bounds_check=False,
), add_variable_offset(
right_token,
LLLnode.from_list(i, typ='int128'),
pos=pos,
array_bounds_check=False,
), location, pos=pos))
return LLLnode.from_list([
'with', '_L', left, [
'with', '_R', right, ['seq'] + subs]
if isinstance(typ, BaseType):
if isinstance(arg, LLLnode):
value = unwrap_location(arg)
else:
value = Expr(arg, context).lll_node
value = base_type_conversion(value, value.typ, typ, pos)
holder.append(LLLnode.from_list(['mstore', placeholder, value], typ=typ, location='memory'))
elif isinstance(typ, ByteArrayLike):
if isinstance(arg, LLLnode): # Is prealloacted variable.
source_lll = arg
else:
source_lll = Expr(arg, context).lll_node
# Set static offset, in arg slot.
holder.append(LLLnode.from_list(['mstore', placeholder, ['mload', dynamic_offset_counter]]))
# Get the biginning to write the ByteArray to.
dest_placeholder = LLLnode.from_list(
['add', datamem_start, ['mload', dynamic_offset_counter]],
typ=typ, location='memory', annotation="pack_args_by_32:dest_placeholder")
copier = make_byte_array_copier(dest_placeholder, source_lll, pos=pos)
holder.append(copier)
item_maxlen = source_lll.typ.maxlen
# Add zero padding.
holder.append(
zero_pad(dest_placeholder, item_maxlen, zero_pad_i=zero_pad_i)
)
# Increment offset counter.
increment_counter = LLLnode.from_list([
'mstore', dynamic_offset_counter,
[
['mload', MemoryPositions.MAXNUM],
])
elif new_typ.typ == 'decimal':
p.append([
'clamp',
['mload', MemoryPositions.MINDECIMAL],
arith,
['mload', MemoryPositions.MAXDECIMAL],
])
elif new_typ.typ == 'uint256':
p.append(arith)
else:
raise Exception(f"{arith} {new_typ}")
p = ['with', 'l', left, ['with', 'r', right, p]]
return LLLnode.from_list(p, typ=new_typ, pos=pos)
raise InvalidLiteralException("Number out of range: " + str(orig.value), pos)
# Special Case: Literals in function calls should always convey unit type as well.
if in_function_call and not (frm.unit == to.unit and frm.positional == to.positional):
raise InvalidLiteralException(
"Function calls require explicit unit definitions on calls, expected %r" % to, pos
)
if not isinstance(frm, (BaseType, NullType)) or not isinstance(to, BaseType):
raise TypeMismatchException(
"Base type conversion from or to non-base type: %r %r" % (frm, to), pos
)
elif is_base_type(frm, to.typ) and are_units_compatible(frm, to):
return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
elif isinstance(frm, ContractType) and to == BaseType('address'):
return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
elif is_valid_int128_to_decimal:
return LLLnode.from_list(
['mul', orig, DECIMAL_DIVISOR],
typ=BaseType('decimal', to.unit, to.positional),
)
elif isinstance(frm, NullType):
if to.typ not in ('int128', 'bool', 'uint256', 'address', 'bytes32', 'decimal'):
# This is only to future proof the use of base_type_conversion.
raise TypeMismatchException( # pragma: no cover
"Cannot convert null-type object to type %r" % to, pos
)
return LLLnode.from_list(0, typ=to)
# Integer literal conversion.
elif (frm.typ, to.typ, frm.is_literal) == ('int128', 'uint256', True):
return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
else:
raise TypeMismatchException(
"Typecasting from base type %r to %r unavailable" % (frm, to), pos
def constants(self):
    """Lower a ``True`` / ``False`` / ``None`` name-constant expression to LLL.

    ``True`` and ``False`` become literal ``bool`` nodes holding ``1`` and
    ``0`` respectively; ``None`` becomes a ``NullType`` node. Every returned
    node carries the source position of ``self.expr``.

    Raises:
        Exception: if ``self.expr.value`` is none of ``True``, ``False``,
            ``None``. The message names the offending value.
    """
    value = self.expr.value
    if value is True:
        return LLLnode.from_list(
            1,
            typ=BaseType('bool', is_literal=True),
            pos=getpos(self.expr),
        )
    if value is False:
        return LLLnode.from_list(
            0,
            typ=BaseType('bool', is_literal=True),
            pos=getpos(self.expr),
        )
    if value is None:
        return LLLnode.from_list(None, typ=NullType(), pos=getpos(self.expr))
    # Report the value itself. It is a plain object here (e.g. Ellipsis or a
    # number), so looking up a further `.value` on it would raise
    # AttributeError and hide this diagnostic.
    raise Exception(f"Unknown name constant: {value!r}")
else:
_clampers = clampers
# Function with default parameters.
o = LLLnode.from_list(
['seq',
sig_chain,
['if', 0, # can only be jumped into
['seq',
['label', function_routine] if not sig.private else ['pass'],
['seq'] + _clampers + [parse_body(c, context) for c in code.body] + stop_func]]], typ=None, pos=getpos(code))
else:
# Function without default parameters.
sig_compare, private_label = get_sig_statements(sig, getpos(code))
o = LLLnode.from_list(
['if',
sig_compare,
['seq'] + [private_label] + clampers + [parse_body(c, context) for c in code.body] + stop_func], typ=None, pos=getpos(code))
# Check for at leasts one return statement if necessary.
if context.return_type and context.function_return_count == 0:
raise FunctionDeclarationException(
"Missing return statement in function '%s' " % sig.name, code
)
o.context = context
o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
o.func_name = sig.name
return o