Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
## Unreleased

### Features

- Add a `dot` table format for exporting query results as Graphviz DOT.

### Bug Fixes

- Expand `~` in configured log file paths before opening the log.
Expand Down
2 changes: 1 addition & 1 deletion litecli/liteclirc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ log_level = INFO
# Table format. Possible values:
# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl,
# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira,
# vertical, tsv, csv.
# vertical, tsv, csv, dot.
# Recommended: ascii
table_format = ascii

Expand Down
46 changes: 37 additions & 9 deletions litecli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .key_bindings import cli_bindings
from .lexer import LiteCliLexer
from .packages import special
from .packages.dot_output import format_dot_output
from .packages.filepaths import dir_path_exists
from .packages.prompt_utils import confirm, confirm_destructive_query
from .packages.special.main import NO_QUERY
Expand All @@ -60,6 +61,7 @@ def _load_sqlite3() -> Any:
_sqlite3 = _load_sqlite3()
OperationalError = _sqlite3.OperationalError
sqlite_version = _sqlite3.sqlite_version
LOCAL_OUTPUT_FORMATS = ("dot",)

# Query tuples are used for maintaining history
Query = namedtuple("Query", ["query", "successful", "mutating"])
Expand Down Expand Up @@ -89,7 +91,11 @@ def __init__(
self.multi_line = c["main"].as_bool("multi_line")
self.key_bindings = c["main"]["key_bindings"]
special.set_favorite_queries(self.config)
self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
self.local_format_name: str | None = None
config_table_format = c["main"]["table_format"]
self.formatter = TabularOutputFormatter(format_name="ascii" if config_table_format in LOCAL_OUTPUT_FORMATS else config_table_format)
if config_table_format in LOCAL_OUTPUT_FORMATS:
self.local_format_name = config_table_format
# self.formatter.litecli = self, ty raises unresolved-attribute, hence use dynamic assignment
setattr(self.formatter, "litecli", self)
self.syntax_style = c["main"]["syntax_style"]
Expand Down Expand Up @@ -137,7 +143,7 @@ def __init__(

# Initialize completer.
self.completer = SQLCompleter(
supported_formats=self.formatter.supported_formats,
supported_formats=self.supported_table_formats(),
keyword_casing=keyword_casing,
)
self._completer_lock = threading.Lock()
Expand Down Expand Up @@ -188,13 +194,31 @@ def register_special_commands(self) -> None:
case_sensitive=True,
)

def supported_table_formats(self) -> list[str]:
supported_formats = list(self.formatter.supported_formats)
for format_name in LOCAL_OUTPUT_FORMATS:
if format_name not in supported_formats:
supported_formats.append(format_name)
return supported_formats

def current_table_format(self) -> str:
return self.local_format_name or self.formatter.format_name

def set_table_format(self, format_name: str) -> None:
if format_name in LOCAL_OUTPUT_FORMATS:
self.local_format_name = format_name
return

self.formatter.format_name = format_name
self.local_format_name = None

def change_table_format(self, arg: str, **_: Any) -> Generator[tuple[None, None, None, str], None, None]:
try:
self.formatter.format_name = arg
self.set_table_format(arg)
yield (None, None, None, "Changed table format to {}".format(arg))
except ValueError:
msg = "Table format {} not recognized. Allowed formats:".format(arg)
for table_type in self.formatter.supported_formats:
for table_type in self.supported_table_formats():
msg += "\n\t{}".format(table_type)
yield (None, None, None, msg)

Expand Down Expand Up @@ -839,7 +863,8 @@ def run_query(self, query: str, new_line: bool = True) -> None:
click.echo(line, nl=new_line)

def format_output(self, title: Any, cur: Any, headers: Any, expanded: bool = False, max_width: int | None = None) -> Iterable[str]:
expanded = expanded or self.formatter.format_name == "vertical"
format_name = self.current_table_format()
expanded = expanded or format_name == "vertical"
output_iter: Iterable[str] = []

output_kwargs = {
Expand All @@ -854,6 +879,9 @@ def format_output(self, title: Any, cur: Any, headers: Any, expanded: bool = Fal
output_iter = itertools.chain(output_iter, [title])

if cur:
if format_name == "dot":
return itertools.chain(output_iter, format_dot_output(cur, headers or []))

column_types = None
if hasattr(cur, "description"):
column_types = [str(col) for col in cur.description]
Expand Down Expand Up @@ -972,9 +1000,9 @@ def cli(
if execute:
try:
if csv:
litecli.formatter.format_name = "csv"
litecli.set_table_format("csv")
elif not table:
litecli.formatter.format_name = "tsv"
litecli.set_table_format("tsv")

litecli.run_query(execute)
exit(0)
Expand All @@ -999,9 +1027,9 @@ def cli(
new_line = True

if csv:
litecli.formatter.format_name = "csv"
litecli.set_table_format("csv")
elif not table:
litecli.formatter.format_name = "tsv"
litecli.set_table_format("tsv")

litecli.run_query(stdin_text, new_line=new_line)
exit(0)
Expand Down
45 changes: 45 additions & 0 deletions litecli/packages/dot_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import Any, Iterable


def _dot_value(value: Any) -> str:
if value is None:
return "NULL"
return str(value)


def _dot_quote(value: Any) -> str:
text = _dot_value(value)
return '"' + text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n").replace("\r", "\\r") + '"'


def format_dot_output(rows: Iterable[Iterable[Any]], headers: Iterable[str]) -> Iterable[str]:
"""Format one-column results as nodes and multi-column results as edges."""
header_names = list(headers)

yield "digraph result {"

if header_names:
yield " // Columns: {}".format(", ".join(header_names))

for row in rows:
row_values = list(row)
if not row_values:
continue

if len(row_values) == 1:
yield " {};".format(_dot_quote(row_values[0]))
continue

label = ""
if len(row_values) > 2:
label_values = []
for index, value in enumerate(row_values[2:], start=2):
column_name = header_names[index] if index < len(header_names) else "column{}".format(index + 1)
label_values.append("{}={}".format(column_name, _dot_value(value)))
label = " [label={}]".format(_dot_quote(", ".join(label_values)))

yield " {} -> {}{};".format(_dot_quote(row_values[0]), _dot_quote(row_values[1]), label)

yield "}"
2 changes: 1 addition & 1 deletion tests/liteclirc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ log_level = INFO
# Table format. Possible values:
# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl,
# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira,
# vertical, tsv, csv.
# vertical, tsv, csv, dot.
# Recommended: ascii
table_format = ascii

Expand Down
39 changes: 39 additions & 0 deletions tests/test_dot_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from litecli.packages.dot_output import format_dot_output


def test_dot_output_formats_edges():
rows = [("orders", "customers"), ("line_items", "orders")]
headers = ["child", "parent"]

assert list(format_dot_output(rows, headers)) == [
"digraph result {",
" // Columns: child, parent",
' "orders" -> "customers";',
' "line_items" -> "orders";',
"}",
]


def test_dot_output_formats_nodes_and_escapes_values():
rows = [('a"b',), ("line\nbreak",), (None,)]

assert list(format_dot_output(rows, ["name"])) == [
"digraph result {",
" // Columns: name",
' "a\\"b";',
' "line\\nbreak";',
' "NULL";',
"}",
]


def test_dot_output_uses_extra_columns_as_edge_label():
rows = [("a", "b", "foreign key")]
headers = ["source", "target", "relation"]

assert list(format_dot_output(rows, headers)) == [
"digraph result {",
" // Columns: source, target, relation",
' "a" -> "b" [label="relation=foreign key"];',
"}",
]
13 changes: 13 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,19 @@ def test_batch_mode_csv(executor):
assert expected in "".join(result.output)


def test_dot_table_format_is_supported():
m = LiteCli(liteclirc=default_config_file)

assert "dot" in m.supported_table_formats()
assert list(m.change_table_format("dot")) == [(None, None, None, "Changed table format to dot")]
assert list(m.format_output(None, [("orders", "customers")], ["source", "target"])) == [
"digraph result {",
" // Columns: source, target",
' "orders" -> "customers";',
"}",
]


def test_help_strings_end_with_periods():
"""Make sure click options have help text that end with a period."""
for param in cli.params:
Expand Down
Loading