Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _sql_rank(col, na_option, func_name, partition = False):
    """Build a ranking window expression over ``col``.

    Parameters
    ----------
    col:
        Column expression used as the window's ORDER BY.
    na_option:
        Only ``"keep"`` gets NULL-aware handling (delegated to
        ``_sql_rank_over``). Any other value takes the plain path and
        triggers a ``warn_arg_default`` call.
    func_name:
        Name of the ranking function looked up on ``sql.func``.
    partition:
        Forwarded to ``_sql_rank_over``; only used when ``na_option`` is
        ``"keep"``.

    Returns
    -------
    A ``RankOver`` clause, or the CASE expression built by
    ``_sql_rank_over`` when ``na_option == "keep"``.
    """
    sa_rank = getattr(sql.func, func_name)

    if na_option != "keep":
        # Plain path: no special treatment of NULLs in col. The warning helper
        # is told that the argument default is None while "keep" is the
        # alternative (exact message is defined by warn_arg_default).
        warn_arg_default(func_name, 'na_option', None, "keep")
        return RankOver(sa_rank(), order_by = col)

    return _sql_rank_over(sa_rank, col, partition = partition)
def _nth_sql(x, n, order_by = None, default = None) -> ClauseElement:
    """Translate a Python-style ``nth(x, n)`` into a SQL ``nth_value`` window.

    Parameters
    ----------
    x:
        Column expression whose value is selected.
    n:
        Zero-based position, as in Python indexing. Negative values count from
        the end (-1 is the last row). They are implemented by reversing
        ``order_by`` and converting to the equivalent non-negative position.
    order_by:
        Ordering clause for the window. It is required when ``n`` is negative,
        because there is nothing to reverse otherwise.
    default:
        Not supported; must be None.

    Returns
    -------
    A ``RankOver`` clause wrapping ``sql.func.nth_value``.

    Raises
    ------
    NotImplementedError
        If ``default`` is given, or if ``n`` is negative and ``order_by`` is
        None.
    """
    if default is not None:
        raise NotImplementedError("default argument not implemented")

    if n < 0:
        if order_by is None:
            # Without an ordering to reverse, n + 1 would be <= 0, which is
            # not a valid nth_value position in SQL.
            raise NotImplementedError(
                "must explicitly pass order_by when using nth with n < 0 in SQL"
            )
        # e.g. -1 in python is 0, -2 is 1, once the ordering is reversed
        n = abs(n + 1)
        order_by = order_by.desc()

    # nth_value positions are 1-based; python indexing is 0-based.
    # NOTE(review): with an ORDER BY and the default SQL window frame
    # (up to the current row), nth_value is NULL for rows before the nth one.
    # Confirm whether RankOver or the caller sets a full frame.
    return RankOver(sql.func.nth_value(x, n + 1), order_by = order_by)
def _sql_rank_over(rank_func, col, partition):
    """Rank ``col`` while keeping NULL rows out of the ranking.

    Parameters
    ----------
    rank_func:
        Callable (e.g. an ``sql.func`` attribute) returning the window function.
    col:
        Column expression used as the ORDER BY.
    partition:
        When truthy, the window is partitioned by ``col IS NOT NULL``.

    Returns
    -------
    ``CASE WHEN col IS NOT NULL THEN <window> END``. No ``else_`` is given, so
    rows where ``col`` is NULL get NULL.
    """
    # Partitioning on NOT NULL keeps aggregates that depend on the total row
    # count correct (e.g. percent rank, cume_dist and friends).
    partition_by = None
    if partition:
        partition_by = col.isnot(None)

    ranked = RankOver(rank_func(), order_by = col, partition_by = partition_by)

    return sql.case({col.isnot(None): ranked})
return lambda col: RankOver(sa_func(), order_by = col)