diff --git a/changelog.md b/changelog.md index 96154b07..fa1b334f 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ Upcoming (TBD) ============== +Features +--------- +* Silently accept forward slash to introduce special commands. + + Internal -------- * Add test coverage for `client_commands.py`. diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py index edbc64cb..c3f73856 100644 --- a/mycli/clibuffer.py +++ b/mycli/clibuffer.py @@ -3,7 +3,10 @@ from prompt_toolkit.filters import Condition, Filter from mycli.packages.special import iocommands -from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.packages.special.main import ( + CASE_INSENSITIVE_COMMANDS, + CASE_SENSITIVE_COMMANDS, +) def cli_is_multiline(mycli) -> Filter: @@ -26,12 +29,13 @@ def _multiline_exception(text: str) -> bool: # Multi-statement favorite query is a special case. Because there will # be a semicolon separating statements, we can't consider semicolon an # EOL. Let's consider an empty line an EOL instead. - if first_word.startswith("\\fs"): + if first_word.startswith(("\\fs", "/fs")): return orig.endswith("\n") return ( # Special Command first_word.startswith("\\") + or (first_word.startswith('/') and not first_word.startswith('/*')) or text.endswith(( # Ended with the current delimiter (usually a semi-column) iocommands.get_current_delimiter(), @@ -44,10 +48,10 @@ def _multiline_exception(text: str) -> bool: )) or # non-backslashed special commands such as "exit" or "help" don't need semicolon - first_word in SPECIAL_COMMANDS + first_word in CASE_SENSITIVE_COMMANDS or # uppercase variants accepted - first_word.lower() in SPECIAL_COMMANDS + first_word.lower() in CASE_INSENSITIVE_COMMANDS or # just a plain enter without any text (first_word == "") diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py index 416b25c4..1c3d2247 100644 --- a/mycli/main_modes/repl.py +++ b/mycli/main_modes/repl.py @@ -108,7 +108,8 @@ def complete_while_typing_filter() -> bool: last_word = text[-MIN_COMPLETION_TRIGGER:] if len(last_word) == text_len: return text_len >= MIN_COMPLETION_TRIGGER - if text[:6].lower() in ['source', r'\.']: + # does \. make sense with text[:6] ? + if text[:6].lower() in ['source', r'\.', '/.']: # Different word characters for paths; see comment below. # In fact, it might be nice if paths had a different threshold. return not bool(re.search(r'[\s!-,:-@\[-^\{\}-]', last_word)) diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index f623a38c..a1126887 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -751,7 +751,7 @@ def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any] # but the statement won't have a first token tok1 = statement.token_first() # lenient because \. will parse as two tokens - if tok1 and tok1.value.startswith('\\'): + if tok1 and tok1.value.startswith(('\\', '/')) and not tok1.value.startswith('/*'): return suggest_special(text_before_cursor) elif tok1: if tok1.value.lower() in SPECIAL_COMMANDS: @@ -771,22 +771,22 @@ def suggest_special(text: str) -> list[dict[str, Any]]: # Trying to complete the special command itself return [{"type": "special"}] - if cmd in ("\\u", "\\r"): + if cmd in ("\\u", "/u", "\\r", "/r"): return [{"type": "database"}] - if cmd.lower() in ('use', 'connect'): + if cmd.lower() in ('use', '/use', 'connect', '/connect'): return [{'type': 'database'}] - if cmd in (r'\T', r'\Tr'): + if cmd in (r'\T', '/T', r'\Tr', '/Tr'): return [{"type": "table_format"}] - if cmd.lower() in ('tableformat', 'redirectformat'): + if cmd.lower() in ('tableformat', '/tableformat', 'redirectformat', '/redirectformat'): return [{"type": "table_format"}] - if cmd in ["\\f", "\\fs", "\\fd"]: + if cmd in ["\\f", "/f", "\\fs", "/fs", "\\fd", "/fd"]: return [{"type": "favoritequery"}] - if cmd in ["\\dt", "\\dt+"]: + if cmd in ["\\dt", "/dt", "\\dt+", "/dt+"]: return [ {"type": "table", "schema": []}, {"type": "view", "schema": []}, @@ -794,19 +794,26 @@ def suggest_special(text: str) -> list[dict[str, Any]]: ] elif cmd.lower() in [ r'\.', + r'/.', 'source', + '/source', r'\o', + '/o', r'\once', - r'tee', + '/once', + 'tee', + '/tee', ]: return [{"type": "file_name"}] # todo: why is \edit case-sensitive? elif cmd in [ r'\e', + '/e', r'\edit', + '/edit', ]: return [{"type": "file_name"}] - if cmd in ["\\llm", "\\ai"]: + if cmd in ["\\llm", "/llm", "\\ai", "/ai"]: return [{"type": "llm"}] return [{"type": "keyword"}, {"type": "special"}] diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 2a29c7cf..c8411279 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -205,20 +205,19 @@ def editor_command(command: str) -> bool: :param command: string """ # special case: allow help on the \edit command - if re.match(r'^([Hh][Ee][Ll][Pp])\s+(\\e|\\edit)\s*(;|\\G|\\g)?\s*$', command): + if re.match(r'^/?([Hh][Ee][Ll][Pp])\s+(\\e|\\edit|/e|/edit)\s*(;|\\G|\\g)?\s*$', command): return False # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check # for both conditions. return ( - command.strip().endswith("\\e") - or command.strip().startswith("\\e ") - or command.strip().endswith("\\edit") - or command.strip().startswith("\\edit ") + command.strip().endswith(("\\e", "\\edit")) + or command.strip().startswith(("\\e ", "/e ", "\\edit ", "/edit ")) + or command.strip() in (("\\e", "/e", "\\edit", "/edit")) ) def get_filename(sql: str) -> str | None: - if sql.strip().startswith("\\e ") or sql.strip().startswith("\\edit "): + if sql.strip().startswith(("\\e ", "/e ")) or sql.strip().startswith(("\\edit ", "/edit ")): command, _, filename = sql.partition(" ") return filename.strip() or None else: @@ -229,6 +228,9 @@ def get_editor_query(sql: str) -> str: """Get the query part of an editor command.""" sql = sql.strip() + if sql in ('\\e', '/e', '\\edit', '/edit'): + return '' + # The reason we can't simply do .strip('\e') is that it strips characters, # not a substring. So it'll strip "e" in the end of the sql also! # Ex: "select * from style\e" -> "select * from styl". @@ -281,7 +283,7 @@ def clip_command(command: str) -> bool: """ # It is possible to have `\clip` or `SELECT * FROM \clip`. So we check # for both conditions. - return command.strip().endswith("\\clip") or command.strip().startswith("\\clip") + return command.strip().endswith("\\clip") or command.strip().startswith(("\\clip", "/clip")) def get_clip_query(sql: str) -> str: @@ -290,7 +292,7 @@ def get_clip_query(sql: str) -> str: # The reason we can't simply do .strip('\clip') is that it strips characters, # not a substring. So it'll strip "c" in the end of the sql also! - pattern = re.compile(r"(^\\clip|\\clip$)") + pattern = re.compile(r"(^\\clip|^/clip|\\clip$)") while pattern.search(sql): sql = pattern.sub("", sql) diff --git a/mycli/packages/special/llm.py b/mycli/packages/special/llm.py index e7786092..56fa7201 100644 --- a/mycli/packages/special/llm.py +++ b/mycli/packages/special/llm.py @@ -227,7 +227,7 @@ def handle_llm( _, command_verbosity, arg = parse_special_command(text) if not LLM_IMPORTED: raise FinishIteration(results=[SQLResult(preamble=NEED_DEPENDENCIES)]) - if arg.strip().lower() in ['', 'help', '?', r'\?']: + if arg.strip().lower() in ['', 'help', '/help', '?', r'\?', '/?']: raise FinishIteration(results=[SQLResult(preamble=USAGE)]) parts = shlex.split(arg) restart = False @@ -286,7 +286,7 @@ def handle_llm( def is_llm_command(command: str) -> bool: cmd, _, _ = parse_special_command(command) - return cmd in ("\\llm", "\\ai") + return cmd in ("\\llm", "/llm", "\\ai", "/ai") def truncate_list_elements(row: list, prompt_field_truncate: int, prompt_section_truncate: int) -> list: diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py index 12a6c7de..9463b46d 100644 --- a/mycli/packages/special/main.py +++ b/mycli/packages/special/main.py @@ -106,7 +106,12 @@ def register_special_command( case_sensitive: bool = False, aliases: list[SpecialCommandAlias] | None = None, ) -> None: + if command.startswith('\\'): + forwardslash_command = '/' + command.removeprefix('\\') + else: + forwardslash_command = '/' + command cmd = command.lower() if not case_sensitive else command + fcmd = forwardslash_command.lower() if not case_sensitive else forwardslash_command COMMANDS[cmd] = SpecialCommand( handler, command, @@ -117,17 +122,36 @@ def register_special_command( case_sensitive=case_sensitive, aliases=aliases, ) + COMMANDS[fcmd] = SpecialCommand( + handler, + command, + usage, + description, + arg_type=arg_type, + hidden=True, + case_sensitive=case_sensitive, + aliases=aliases, + ) if case_sensitive: CASE_SENSITIVE_COMMANDS.add(command) + CASE_SENSITIVE_COMMANDS.add(forwardslash_command) else: CASE_INSENSITIVE_COMMANDS.add(command.lower()) + CASE_INSENSITIVE_COMMANDS.add(forwardslash_command.lower()) aliases = [] if aliases is None else aliases for alias in aliases: + if alias.command.startswith('\\'): + forwardslash_alias_command = '/' + alias.command.removeprefix('\\') + else: + forwardslash_alias_command = '/' + alias.command cmd = alias.command.lower() if not alias.case_sensitive else alias.command + fcmd = forwardslash_alias_command.lower() if not alias.case_sensitive else forwardslash_alias_command if alias.case_sensitive: CASE_SENSITIVE_COMMANDS.add(alias.command) + CASE_SENSITIVE_COMMANDS.add(forwardslash_alias_command) else: CASE_INSENSITIVE_COMMANDS.add(alias.command.lower()) + CASE_INSENSITIVE_COMMANDS.add(forwardslash_alias_command.lower()) COMMANDS[cmd] = SpecialCommand( handler, command, @@ -138,6 +162,16 @@ def register_special_command( hidden=True, aliases=None, ) + COMMANDS[fcmd] = SpecialCommand( + handler, + command, + usage, + description, + arg_type=arg_type, + case_sensitive=alias.case_sensitive, + hidden=True, + aliases=None, + ) def execute(cur: Cursor, sql: str) -> list[SQLResult]: @@ -158,7 +192,7 @@ def execute(cur: Cursor, sql: str) -> list[SQLResult]: # "help is a special case. We want built-in help, not # mycli help here. - if command.lower() == "help" and arg: + if command.lower().startswith(("help", "/help", "\\?", "/?", "?")) and arg: return show_keyword_help(cur=cur, arg=arg) if special_cmd.arg_type == ArgType.NO_QUERY: diff --git a/mycli/packages/sql_utils.py b/mycli/packages/sql_utils.py index c03d5c85..18d095d2 100644 --- a/mycli/packages/sql_utils.py +++ b/mycli/packages/sql_utils.py @@ -431,7 +431,20 @@ def need_completion_refresh(queries: str) -> bool: for query in sqlparse.split(queries): try: first_token = query.split()[0] - if first_token.lower() in ("alter", "create", "use", "\\r", "\\u", "connect", "drop", "rename"): + if first_token.lower() in ( + "alter", + "create", + "use", + "/use", + "\\r", + "\\u", + "/r", + "/u", + "connect", + "/connect", + "drop", + "rename", + ): return True except Exception: continue @@ -447,9 +460,9 @@ def need_completion_reset(queries: str) -> bool: try: tokens = query.split() first_token = tokens[0] - if first_token.lower() in ("use", "\\u"): + if first_token.lower() in ("use", "/use", "\\u", "/u"): return True - if first_token.lower() in ("\\r", "connect") and len(tokens) > 1: + if first_token.lower() in ("\\r", "/r", "connect", "/connect") and len(tokens) > 1: return True except Exception: continue @@ -502,7 +515,7 @@ def classify_sandbox_statement(text: str) -> tuple[str | None, str | None]: return ('quit', None) # \q - if len(tokens) == 2 and types[0] == tt.BACKSLASH and texts[1] == 'Q': + if len(tokens) == 2 and types[0] in (tt.BACKSLASH, tt.SLASH) and texts[1] in ('Q', 'QUIT', 'EXIT'): return ('quit', None) # ALTER USER ... diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 7af25920..c4661b90 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -411,7 +411,7 @@ def run(self, statement: str) -> Generator[SQLResult, None, None]: # Split the sql into separate queries and run each one. # Unless it's saving a favorite query, in which case we # want to save them all together. - if statement.startswith("\\fs"): + if statement.startswith(("\\fs", "/fs")): components: Iterable[str] = [statement] else: components = iocommands.split_queries(statement) diff --git a/test/pytests/test_clibuffer.py b/test/pytests/test_clibuffer.py index d502e009..f777abc1 100644 --- a/test/pytests/test_clibuffer.py +++ b/test/pytests/test_clibuffer.py @@ -61,7 +61,8 @@ def test_multiline_exception_detects_commands_terminators_and_plain_sql( expected: bool, ) -> None: monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: '//') - monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object(), 'exit': object()}) + monkeypatch.setattr(clibuffer, 'CASE_SENSITIVE_COMMANDS', {'Camel'}) + monkeypatch.setattr(clibuffer, 'CASE_INSENSITIVE_COMMANDS', {'help', 'exit'}) assert clibuffer._multiline_exception(text) is expected @@ -85,7 +86,8 @@ def test_multiline_exception_recognizes_non_backslashed_special_commands_with_ge text: str, ) -> None: monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: ';') - monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object(), 'exit': object()}) + monkeypatch.setattr(clibuffer, 'CASE_SENSITIVE_COMMANDS', {'Camel'}) + monkeypatch.setattr(clibuffer, 'CASE_INSENSITIVE_COMMANDS', {'help', 'exit'}) assert clibuffer._multiline_exception(text) is True @@ -107,7 +109,8 @@ def test_cli_is_multiline_uses_buffer_text_when_multiline_mode_is_enabled( monkeypatch.setattr(clibuffer, 'get_app', lambda: app) monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: ';') - monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object()}) + monkeypatch.setattr(clibuffer, 'CASE_SENSITIVE_COMMANDS', {'Camel'}) + monkeypatch.setattr(clibuffer, 'CASE_INSENSITIVE_COMMANDS', {'help'}) multiline_filter = clibuffer.cli_is_multiline(mycli) diff --git a/test/pytests/test_special_iocommands.py b/test/pytests/test_special_iocommands.py index ee8a73ef..00d14d27 100644 --- a/test/pytests/test_special_iocommands.py +++ b/test/pytests/test_special_iocommands.py @@ -174,6 +174,7 @@ def test_editor_command(monkeypatch): assert mycli.packages.special.editor_command(r"hello\edit") assert mycli.packages.special.editor_command(r"\e hello") assert mycli.packages.special.editor_command(r"\edit hello") + assert mycli.packages.special.editor_command('/edit') assert not mycli.packages.special.editor_command(r"HELP \e") assert not mycli.packages.special.editor_command(r"help \edit\g") @@ -182,6 +183,7 @@ def test_editor_command(monkeypatch): assert not mycli.packages.special.editor_command(r"\edithello") assert mycli.packages.special.get_filename(r"\e filename") == "filename" + assert mycli.packages.special.get_editor_query('/edit') == '' if os.name != "nt": assert mycli.packages.special.open_external_editor(sql=r"select 1") == ('select 1', None) diff --git a/test/pytests/test_special_main.py b/test/pytests/test_special_main.py index 3c1b2e77..6c2e620f 100644 --- a/test/pytests/test_special_main.py +++ b/test/pytests/test_special_main.py @@ -120,7 +120,7 @@ def test_register_special_command_tracks_case_insensitive_commands(restore_comma ) assert special_main.CASE_SENSITIVE_COMMANDS == set() - assert special_main.CASE_INSENSITIVE_COMMANDS == {'demo', '\\d'} + assert special_main.CASE_INSENSITIVE_COMMANDS == {'demo', '/demo', '\\d', '/d'} def test_special_command_decorator_registers_case_sensitive_command(restore_commands: None) -> None: @@ -134,8 +134,10 @@ def handler() -> None: assert special_main.COMMANDS['Camel'].handler is handler assert 'Camel' in special_main.CASE_SENSITIVE_COMMANDS + assert '/Camel' in special_main.CASE_SENSITIVE_COMMANDS assert special_main.CASE_INSENSITIVE_COMMANDS == set() assert 'camel' not in special_main.COMMANDS + assert '/camel' not in special_main.COMMANDS def test_execute_raises_when_command_is_missing() -> None: