Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@export
def write_tee(output):
    """Append ``output`` followed by a newline to the active tee file.

    Does nothing when no tee file is set. The file is flushed after both
    pieces are written.
    """
    global tee_file
    if not tee_file:
        return
    for chunk in (output, "\n"):
        click.echo(chunk, file=tee_file, nl=False)
    tee_file.flush()
@export
def set_favorite_queries(config):
    """Rebind the module-level ``favoritequeries`` to a store built from ``config``."""
    global favoritequeries
    store = FavoriteQueries(config)
    favoritequeries = store
@export
def close_tee():
    """Close the active tee file, if any, and reset the module-level handle to None."""
    global tee_file
    if not tee_file:
        return
    tee_file.close()
    tee_file = None
@export
def get_editor_query(sql):
    """Get the query part of an editor command.

    Strips surrounding whitespace, then repeatedly removes a literal ``\\e``
    marker from the start and/or end of the text until none remains.
    Whitespace exposed by removing a marker is kept as-is.

    :param sql: the raw command text, e.g. ``"select 1 \\e"``.
    :return: the text with its leading/trailing ``\\e`` markers removed.
    """
    sql = sql.strip()

    # The reason we can't simply do .strip('\e') is that it strips characters,
    # not a substring. So it'll strip "e" in the end of the sql also!
    # Ex: "select * from style\e" -> "select * from styl".
    # Raw string: the regex must receive \\e (an escaped backslash, then "e").
    # A plain "\e" is an invalid string escape that Python warns about.
    pattern = re.compile(r"(^\\e|\\e$)")
    while pattern.search(sql):
        sql = pattern.sub("", sql)
    return sql
@export
def execute(cur, sql):
    """Execute a special command and return the results.

    Parses ``sql`` into ``(command, verbose, arg)`` and looks the command up in
    ``COMMANDS``, first by its exact spelling and then lower-cased.

    :param cur: database cursor (not used by the body shown here).
    :param sql: the full special-command text.
    :raises CommandNotFound: if neither ``command`` nor ``command.lower()`` is
        registered, or if only the lower-cased spelling matches a command that
        was registered as case-sensitive. NOTE(review): CommandNotFound is
        defined elsewhere; confirm whether it subclasses KeyError.
    """
    command, verbose, arg = parse_special_command(sql)

    # Reject early when no spelling of the command is registered.
    if (command not in COMMANDS) and (command.lower() not in COMMANDS):
        raise CommandNotFound

    try:
        special_cmd = COMMANDS[command]
    except KeyError:
        # Only the lower-cased spelling matched; that is acceptable only for
        # commands registered as case-insensitive.
        special_cmd = COMMANDS[command.lower()]
        if special_cmd.case_sensitive:
            raise CommandNotFound("Command not found: %s" % command)
    # NOTE(review): as written the function ends here. ``special_cmd``,
    # ``verbose``, ``arg`` and ``cur`` are never used and None is returned;
    # the handler dispatch appears to be missing from this copy.
@export
@special_command(
    "pager",
    "\\P [command]",
    "Set PAGER. Print the query results via PAGER.",
    arg_type=PARSED_QUERY,
    aliases=("\\P",),
    case_sensitive=True,
)
def set_pager(arg, **_):
    """Enable the pager, optionally setting the PAGER environment variable.

    :param arg: pager command. When non-empty it is stored in
        ``os.environ["PAGER"]``; when empty, any existing PAGER is kept.
    :return: a one-row status result ``[(None, None, None, msg)]``, the same
        shape returned by ``disable_pager``.
    """
    if arg:
        os.environ["PAGER"] = arg
        msg = "PAGER set to %s." % arg
    elif "PAGER" in os.environ:
        msg = "PAGER set to %s." % os.environ["PAGER"]
    else:
        # No PAGER configured; paging is still switched on below.
        msg = "Pager enabled."
    # Turn paging back on in every branch so a bare "\P" undoes "\n".
    set_pager_enabled(True)
    return [(None, None, None, msg)]
@export
def editor_command(command):
    """
    Tell whether ``command`` invokes the external editor.

    Both ``\\e filename`` (marker first) and ``SELECT * FROM \\e`` (marker
    last) qualify, so both ends of the stripped text are checked.

    :param command: string
    :return: True if the stripped command starts or ends with ``\\e``.
    """
    marker = "\\e"
    text = command.strip()
    return text.startswith(marker) or text.endswith(marker)
@export
def register_special_command(
    handler,
    command,
    shortcut,
    description,
    arg_type=PARSED_QUERY,
    hidden=False,
    case_sensitive=False,
    aliases=(),
):
    """Register ``handler`` in ``COMMANDS`` under ``command`` and every alias.

    :param handler: callable that implements the command.
    :param command: primary command name.
    :param shortcut: short usage string shown for the command.
    :param description: help text.
    :param arg_type: how the handler receives its arguments (e.g. PARSED_QUERY).
    :param hidden: whether the primary entry is marked hidden.
    :param case_sensitive: when False, keys are stored lower-cased.
    :param aliases: extra names that resolve to the same handler.
    """

    def _key(name):
        # Case-insensitive commands are keyed by their lower-cased name.
        return name if case_sensitive else name.lower()

    COMMANDS[_key(command)] = SpecialCommand(
        handler, command, shortcut, description, arg_type, hidden, case_sensitive
    )
    for alias in aliases:
        # Each alias gets its own key pointing at the same handler. Alias
        # entries are marked hidden so a listing of non-hidden entries shows
        # the command once. NOTE(review): confirm against the help listing.
        COMMANDS[_key(alias)] = SpecialCommand(
            handler, command, shortcut, description, arg_type, True, case_sensitive
        )
@export
def get_filename(sql):
    """Return the filename argument of a leading ``\\e`` editor command.

    :param sql: command text, e.g. ``"\\e query.sql"``.
    :return: the stripped text after the first space following ``\\e``, or
        None when that text is empty or ``sql`` does not start with ``\\e``.
    """
    stripped = sql.strip()
    if not stripped.startswith("\\e"):
        return None
    # Partition the stripped text. Partitioning the raw input would split on a
    # leading space and return the "\e" marker itself as the filename.
    _, _, filename = stripped.partition(" ")
    return filename.strip() or None
@export
@special_command(
    "nopager",
    "\\n",
    "Disable pager, print to stdout.",
    arg_type=NO_QUERY,
    aliases=("\\n",),
    case_sensitive=True,
)
def disable_pager():
    """Turn paging off so results are printed to stdout.

    :return: a one-row status result ``[(None, None, None, "Pager disabled.")]``.
    """
    set_pager_enabled(False)
    status = "Pager disabled."
    row = (None, None, None, status)
    return [row]