Skip to content

Commit 9c426b7

Browse files
authored
Merge pull request #14 from Big-Life-Lab/code-fixes
misc CLI fixes
2 parents ac26d5a + 2948229 commit 9c426b7

File tree

7 files changed

+147
-56
lines changed

7 files changed

+147
-56
lines changed

docs/tech-spec.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ Options:
5959
output files. This shows which tables and columns are selected, and how
6060
many rows each filter returns.
6161

62+
- `-q`, `--quiet`:
63+
64+
don't log to STDOUT
65+
6266
One or multiple sharable output files will be created in the chosen output
6367
directory according to the chosen output format and organization(s). Each
6468
output file will have the input filename followed by a postfix with the org

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ dynamic = ["dependencies"]
1818
"Homepage" = "https://github.com/Big-Life-Lab/PHES-ODM-sharing"
1919
"Bug Tracker" = "https://github.com/Big-Life-Lab/PHES-ODM-sharing/issues"
2020

21+
[project.scripts]
22+
odm-share = "odm_sharing.tools.share:main"
23+
2124
[build-system]
2225
requires = ["hatchling", "hatch-requirements-txt"]
2326
build-backend = "hatchling.build"

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
SQLAlchemy==2.0.29
2+
numpy==1.24.4
23
openpyxl==3.1.2
34
pandas==2.0.3
45
pyfunctional==1.5.0

src/odm_sharing/private/cons.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,60 @@
1-
from typing import List, cast
1+
import logging
2+
from pathlib import Path
3+
from typing import List, Set
24

35
import pandas as pd
46
import sqlalchemy as sa
57

68

7-
Connection = object # opaque data-source connection handle
9+
Connection = sa.engine.Engine
810

911

1012
class DataSourceError(Exception):
1113
pass
1214

1315

14-
def _connect_excel(path: str, tables: List[str]) -> Connection:
15-
''':raises OSError:'''
16-
# copies excel data to in-memory db, to abstract everything as a db
17-
print('importing excel workbook')
18-
table_whitelist = set(tables)
19-
db = sa.create_engine('sqlite://', echo=False)
16+
def _create_memory_db() -> sa.engine.Engine:
17+
return sa.create_engine('sqlite://', echo=False)
18+
19+
20+
def _write_table_to_db(db: sa.engine.Engine, table: str, df: pd.DataFrame
21+
) -> None:
22+
logging.info(f'- table {table}')
23+
df.to_sql(table, db, index=False, if_exists='replace')
24+
25+
26+
def _connect_csv(path: str) -> Connection:
27+
'''copies file data to in-memory db
28+
29+
:raises OSError:'''
30+
logging.info('importing csv file')
31+
table = Path(path).stem
32+
db = _create_memory_db()
33+
df = pd.read_csv(path)
34+
_write_table_to_db(db, table, df)
35+
return db
36+
37+
38+
def _connect_excel(path: str, table_whitelist: Set[str]) -> Connection:
39+
'''copies file data to in-memory db
40+
41+
:raises OSError:'''
42+
logging.info('importing excel workbook')
43+
db = _create_memory_db()
2044
xl = pd.ExcelFile(path)
2145
included_tables = set(map(str, xl.sheet_names)) & table_whitelist
2246
for table in included_tables:
23-
print(f'- table {table}')
2447
df = xl.parse(sheet_name=table)
25-
df.to_sql(table, db, index=False, if_exists='replace')
26-
return cast(Connection, db)
48+
_write_table_to_db(db, table, df)
49+
return db
2750

2851

2952
def _connect_db(url: str) -> Connection:
3053
''':raises sa.exc.OperationalError:'''
3154
return sa.create_engine(url)
3255

3356

34-
def connect(data_source: str, tables: List[str] = []) -> Connection:
57+
def connect(data_source: str, tables: Set[str] = set()) -> Connection:
3558
'''
3659
connects to a data source and returns the connection
3760
@@ -41,7 +64,9 @@ def connect(data_source: str, tables: List[str] = []) -> Connection:
4164
:raises DataSourceError:
4265
'''
4366
try:
44-
if data_source.endswith('.xlsx'):
67+
if data_source.endswith('.csv'):
68+
return _connect_csv(data_source)
69+
elif data_source.endswith('.xlsx'):
4570
return _connect_excel(data_source, tables)
4671
else:
4772
return _connect_db(data_source)
@@ -51,16 +76,15 @@ def connect(data_source: str, tables: List[str] = []) -> Connection:
5176

5277
def get_dialect_name(c: Connection) -> str:
5378
'''returns the name of the dialect used for the connection'''
54-
return cast(sa.engine.Engine, c).dialect.name
79+
return c.dialect.name
5580

5681

5782
def exec(c: Connection, sql: str, sql_args: List[str] = []) -> pd.DataFrame:
5883
'''executes sql with args on connection
5984
6085
:raises DataSourceError:
6186
'''
62-
db = cast(sa.engine.Engine, c)
6387
try:
64-
return pd.read_sql_query(sql, db, params=tuple(sql_args))
88+
return pd.read_sql_query(sql, c, params=tuple(sql_args))
6589
except sa.exc.OperationalError as e:
6690
raise DataSourceError(str(e))

src/odm_sharing/private/stdext.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
class StrEnum(str, Enum):
77
'''shim for python < 3.11
88
9-
Provides a ``__str__()`` function that returns the enum string-value, which
10-
is useful for printing the value or comparing it with another string.
9+
Gives the enum's assigned string value when converted to string, which is
10+
useful for printing the value or comparing it with another string.
1111
1212
See https://docs.python.org/3.11/library/enum.html#enum.StrEnum
1313
'''
14-
pass
14+
def __str__(self) -> str:
15+
return str(self.value)
1516

1617

1718
class StrValueEnum(StrEnum):

src/odm_sharing/sharing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Dict, List, Tuple
2+
from typing import Dict, List, Set, Tuple
33

44
import pandas as pd
55
from functional import seq
@@ -29,7 +29,7 @@ def parse(schema_path: str, orgs: List[str] = []) -> OrgTableQueries:
2929
return queries.generate(tree)
3030

3131

32-
def connect(data_source: str, tables: List[str] = []) -> Connection:
32+
def connect(data_source: str, tables: Set[str] = set()) -> Connection:
3333
'''returns a connection object that can be used together with a query
3434
object to retrieve data from `data_source`
3535
@@ -67,7 +67,7 @@ def get_columns(c: Connection, tq: TableQuery
6767
if tq.columns:
6868
return (tq.select_rule_id, tq.columns)
6969
else:
70-
dialect = queries.SqlDialect(cons.get_dialect_name(c))
70+
dialect = queries.parse_sql_dialect(cons.get_dialect_name(c))
7171
sql = queries.get_column_sql(tq, dialect)
7272
columns = cons.exec(c, sql).columns.array.tolist()
7373
return (tq.select_rule_id, columns)

tools/share.py renamed to src/odm_sharing/tools/share.py

Lines changed: 92 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import contextlib
2+
import logging
23
import os
4+
import sys
35
from enum import Enum
46
from os import linesep
57
from pathlib import Path
6-
from typing import Dict, List, Optional, Set, TextIO
8+
from typing import Dict, List, Optional, Set, TextIO, Union
79
from typing_extensions import Annotated
810

911
import pandas as pd
1012
import typer
11-
import sqlalchemy as sa
1213
from tabulate import tabulate
1314
from functional import seq
1415

@@ -24,6 +25,7 @@
2425

2526
class OutFmt(str, Enum):
2627
'''output format'''
28+
AUTO = 'auto'
2729
CSV = 'csv'
2830
EXCEL = 'excel'
2931

@@ -41,10 +43,23 @@ class OutFmt(str, Enum):
4143
creating sharable output files. This shows which tables and columns are
4244
selected, and how many rows each filter returns.'''
4345

46+
QUIET_DESC = 'Don\'t log to STDOUT.'
47+
48+
# default cli args
49+
DEBUG_DEFAULT = False
50+
ORGS_DEFAULT = []
51+
OUTDIR_DEFAULT = './'
52+
OUTFMT_DEFAULT = OutFmt.AUTO
53+
QUIET_DEFAULT = False
4454

4555
app = typer.Typer(pretty_exceptions_show_locals=False)
4656

4757

58+
def error(msg: str) -> None:
59+
print(msg, file=sys.stderr)
60+
logging.error(msg)
61+
62+
4863
def write_line(file: TextIO, text: str = '') -> None:
4964
'''writes a line to STDOUT and file'''
5065
print(text)
@@ -109,57 +124,76 @@ def get_tables(org_queries: sh.queries.OrgTableQueries) -> Set[str]:
109124
return result
110125

111126

112-
def gen_filename(org: str, table: str, ext: str) -> str:
113-
# <org>[-<table>].<ext>
114-
return org + (f'-{table}' if table else '') + f'.{ext}'
127+
def gen_filename(in_name: str, org: str, table: str, ext: str) -> str:
128+
if in_name == table or not table:
129+
# this avoids duplicating the table name when both input and output is
130+
# CSV
131+
return f'{in_name}-{org}.{ext}'
132+
else:
133+
return f'{in_name}-{org}-{table}.{ext}'
115134

116135

117-
def get_debug_writer(debug: bool) -> TextIO:
136+
def get_debug_writer(debug: bool) -> Union[TextIO, contextlib.nullcontext]:
118137
# XXX: this function is only used for brewity with the below `with` clause
119138
if debug:
120139
return open('debug.txt', 'w')
121140
else:
122141
return contextlib.nullcontext()
123142

124143

125-
def get_excel_writer(debug: bool, org: str, outdir: str, outfmt: OutFmt
126-
) -> Optional[pd.ExcelWriter]:
144+
def get_excel_writer(in_name, debug: bool, org: str, outdir: str,
145+
outfmt: OutFmt) -> Optional[pd.ExcelWriter]:
127146
if not debug and outfmt == OutFmt.EXCEL:
128-
filename = gen_filename(org, '', 'xlsx')
129-
print('writing ' + filename)
147+
filename = gen_filename(in_name, org, '', 'xlsx')
148+
logging.info('writing ' + filename)
130149
excel_path = os.path.join(outdir, filename)
131150
return pd.ExcelWriter(excel_path)
151+
else:
152+
return None
132153

133154

134-
@app.command()
135-
def main(
136-
schema: str = typer.Argument(default=..., help=SCHEMA_DESC),
137-
input: str = typer.Argument(default='', help=INPUT_DESC),
138-
orgs: List[str] = typer.Option(default=[], help=ORGS_DESC),
139-
outfmt: OutFmt = typer.Option(default=OutFmt.EXCEL, help=OUTFMT_DESC),
140-
outdir: str = typer.Option(default='./', help=OUTDIR_DESC),
141-
debug: Annotated[bool, typer.Option("-d", "--debug",
142-
help=DEBUG_DESC)] = False,
155+
def infer_outfmt(path: str) -> Optional[OutFmt]:
156+
'''returns None when not recognized'''
157+
(_, ext) = os.path.splitext(path)
158+
if ext == '.csv':
159+
return OutFmt.CSV
160+
elif ext == '.xlsx':
161+
return OutFmt.EXCEL
162+
163+
164+
def share(
165+
schema: str,
166+
input: str,
167+
orgs: List[str] = ORGS_DEFAULT,
168+
outfmt: OutFmt = OUTFMT_DEFAULT,
169+
outdir: str = OUTDIR_DEFAULT,
170+
debug: bool = DEBUG_DEFAULT,
143171
) -> None:
144172
schema_path = schema
145-
filename = Path(schema_path).name
173+
schema_filename = Path(schema_path).name
174+
in_name = Path(input).stem
175+
176+
if outfmt == OutFmt.AUTO:
177+
fmt = infer_outfmt(input)
178+
if not fmt:
179+
error('unable to infer output format from input file')
180+
return
181+
outfmt = fmt
146182

147-
print(f'loading schema {qt(filename)}')
183+
logging.info(f'loading schema {qt(schema_filename)}')
148184
try:
149185
ruleset = rules.load(schema_path)
150-
ruletree = trees.parse(ruleset, orgs, filename)
186+
ruletree = trees.parse(ruleset, orgs, schema_filename)
151187
org_queries = queries.generate(ruletree)
152188
table_filter = get_tables(org_queries)
153189
except rules.ParseError:
154190
# XXX: error messages are already printed at this point
155-
exit(1)
191+
return
156192

157193
# XXX: only tables found in the schema are considered in the data source
158-
print(f'connecting to {qt(input)}')
194+
logging.info(f'connecting to {qt(input)}')
159195
con = sh.connect(input, table_filter)
160196

161-
if debug:
162-
print()
163197
# one debug file per run
164198
with get_debug_writer(debug) as debug_file:
165199
for org, table_queries in org_queries.items():
@@ -172,26 +206,50 @@ def main(
172206
org_data[table] = sh.get_data(con, tq)
173207

174208
# one excel file per org
175-
excel_file = get_excel_writer(debug, org, outdir, outfmt)
209+
excel_file = get_excel_writer(in_name, debug, org, outdir, outfmt)
176210
try:
177211
for table, data in org_data.items():
178212
if outfmt == OutFmt.CSV:
179-
filename = gen_filename(org, table, 'csv')
180-
print('writing ' + filename)
181-
data.to_csv(os.path.join(outdir, filename))
213+
filename = gen_filename(in_name, org, table, 'csv')
214+
logging.info('writing ' + filename)
215+
path = os.path.join(outdir, filename)
216+
data.to_csv(path, index=False)
182217
elif outfmt == OutFmt.EXCEL:
183-
print(f'- {qt(table)}')
218+
logging.info(f'- {qt(table)}')
184219
data.to_excel(excel_file, sheet_name=table)
185220
else:
186221
assert False, f'format {outfmt} not impl'
187222
except IndexError:
188223
# XXX: this is thrown from excel writer when nothing is written
189-
exit('failed to write output, most likely due to empty input')
224+
error('failed to write output, most likely due to empty input')
225+
return
190226
finally:
191227
if excel_file:
192228
excel_file.close()
193-
print('done')
229+
logging.info('done')
194230

195231

196-
if __name__ == '__main__':
232+
@app.command()
233+
def main_cli(
234+
schema: str = typer.Argument(default=..., help=SCHEMA_DESC),
235+
input: str = typer.Argument(default='', help=INPUT_DESC),
236+
orgs: List[str] = typer.Option(default=ORGS_DEFAULT, help=ORGS_DESC),
237+
outfmt: OutFmt = typer.Option(default=OUTFMT_DEFAULT, help=OUTFMT_DESC),
238+
outdir: str = typer.Option(default=OUTDIR_DEFAULT, help=OUTDIR_DESC),
239+
debug: Annotated[bool, typer.Option("-d", "--debug",
240+
help=DEBUG_DESC)] = DEBUG_DEFAULT,
241+
quiet: Annotated[bool, typer.Option("-q", "--quiet",
242+
help=QUIET_DESC)] = QUIET_DEFAULT,
243+
) -> None:
244+
if not quiet:
245+
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
246+
share(schema, input, orgs, outfmt, outdir, debug)
247+
248+
249+
def main():
250+
# runs main_cli
197251
app()
252+
253+
254+
if __name__ == '__main__':
255+
main()

0 commit comments

Comments
 (0)