Skip to content

Commit 2f65c43

Browse files
committed
Import modules when init sqldriver driver
This loads modules necessary for the sqldriver drivers on init, with a popup if module is not found.
1 parent 3527c37 commit 2f65c43

File tree

2 files changed

+98
-41
lines changed

2 files changed

+98
-41
lines changed

examples/MSAccess_examples/install_java.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@
66
run. This also serves as an example to automatically download a local Java installation
77
for your own projects.
88
"""
9-
import jdk
109
import os
1110
import pysimplesql as ss
1211
import PySimpleGUI as sg
1312
import subprocess
1413

14+
try:
15+
import jdk
16+
except ModuleNotFoundError:
17+
sg.popup_error("You must `pip install install-jdk` to use this example")
18+
exit(0)
19+
1520

1621
# -------------------------------------------------
1722
# ROUTINES TO INSTALL JAVA IF USER DOES NOT HAVE IT

pysimplesql/pysimplesql.py

+92-40
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
from time import sleep, time # threaded popup
6767
from typing import Callable, Dict, List, Optional, Tuple, Type, TypedDict, Union # docs
6868

69-
import jpype # pip install JPype1
7069
import PySimpleGUI as sg
7170

7271
# Wrap optional imports so that pysimplesql can be imported as a single file if desired:
@@ -91,38 +90,6 @@
9190
}
9291
# fmt: on
9392

94-
# Load database backends if present
95-
supported_databases = ["SQLite3", "MySQL", "PostgreSQL", "Flatfile", "Sqlserver"]
96-
failed_modules = 0
97-
try:
98-
import sqlite3
99-
except ModuleNotFoundError:
100-
failed_modules += 1
101-
try:
102-
import mysql.connector # mysql-connector-python
103-
except ModuleNotFoundError:
104-
failed_modules += 1
105-
try:
106-
import psycopg2
107-
import psycopg2.extras
108-
except ModuleNotFoundError:
109-
failed_modules += 1
110-
try:
111-
import csv
112-
except ModuleNotFoundError:
113-
failed_modules += 1
114-
try:
115-
import pyodbc
116-
except ModuleNotFoundError:
117-
failed_modules += 1
118-
119-
if failed_modules == len(supported_databases):
120-
RuntimeError(
121-
f"You muse have at least one of the following databases installed to use "
122-
f"PySimpleSQL:\n{', '.join(supported_databases)} "
123-
)
124-
125-
12693
logger = logging.getLogger(__name__)
12794

12895
# ---------------------------
@@ -5325,6 +5292,11 @@ class LanguagePack:
53255292
# Quick Editor
53265293
# ------------------------------------------------------------------------------
53275294
"quick_edit_title": "Quick Edit - {data_key}",
5295+
# ------------------------------------------------------------------------------
5296+
# Error when importing module for driver
5297+
# ------------------------------------------------------------------------------
5298+
"import_module_failed_title": "Problem importing module",
5299+
"import_module_failed": "Unable to import module neccessary for {name}\nException: {exception}\n\nTry `pip install {requires}`", # fmt: skip # noqa: E501
53285300
}
53295301
"""
53305302
Default LanguagePack.
@@ -6140,6 +6112,7 @@ class SQLDriver:
61406112
def __init__(
61416113
self,
61426114
name: str,
6115+
requires: List[str],
61436116
placeholder="%s",
61446117
table_quote="",
61456118
column_quote="",
@@ -6153,6 +6126,7 @@ def __init__(
61536126
# Be sure to call super().__init__() in derived class!
61546127
self.con = None
61556128
self.name = name
6129+
self.requires = requires
61566130
self._check_reserved_keywords = True
61576131
self.win_pb = ProgressBar(
61586132
lang.sqldriver_init.format_map(LangFormat(name=name)), 100
@@ -6177,6 +6151,17 @@ def __init__(
61776151
# override this in derived __init__() (defaults to single quotes)
61786152
self.quote_value_char = value_quote
61796153

6154+
def import_failed(self, exception) -> None:
6155+
popup = Popup()
6156+
requires = ", ".join(self.requires)
6157+
popup.ok(
6158+
lang.import_module_failed_title,
6159+
lang.import_module_failed.format_map(
6160+
LangFormat(name=self.name, requires=requires, exception=exception)
6161+
),
6162+
)
6163+
exit(0)
6164+
61806165
def check_reserved_keywords(self, value: bool) -> None:
61816166
"""
61826167
SQLDrivers can check to make sure that field names respect their own reserved
@@ -6639,11 +6624,14 @@ def __init__(
66396624
):
66406625
super().__init__(
66416626
name="SQLite",
6627+
requires=["sqlite3"],
66426628
placeholder="?",
66436629
table_quote='"',
66446630
column_quote='"',
66456631
)
66466632

6633+
self.import_required_modules()
6634+
66476635
new_database = False
66486636
if db_path is not None:
66496637
logger.info(f"Opening database: {db_path}")
@@ -6672,6 +6660,13 @@ def __init__(
66726660
self.db_path = db_path
66736661
self.win_pb.close()
66746662

6663+
def import_required_modules(self):
6664+
global sqlite3 # noqa PLW0603
6665+
try:
6666+
import sqlite3
6667+
except ModuleNotFoundError as e:
6668+
self.import_failed(e)
6669+
66756670
def connect(self, database):
66766671
self.con = sqlite3.connect(database)
66776672

@@ -6840,9 +6835,13 @@ def __init__(
68406835
# First up the SQLite driver that we derived from
68416836
super().__init__(":memory:") # use an in-memory database
68426837

6843-
# Store our Flatfile-specific information
6838+
# Change Sqlite Sqldriver init set values to Flatfile-specific
68446839
self.name = "Flatfile"
6840+
self.requires = ["csv,sqlite3"]
68456841
self.placeholder = "?" # update
6842+
6843+
self.import_required_modules()
6844+
68466845
self.connect(":memory:")
68476846
self.file_path = file_path
68486847
self.delimiter = delimiter
@@ -6907,6 +6906,15 @@ def __init__(
69076906
self.commit() # commit them all at the end
69086907
self.win_pb.close()
69096908

6909+
def import_required_modules(self):
6910+
global csv # noqa PLW0603
6911+
global sqlite3 # noqa PLW0603
6912+
try:
6913+
import csv
6914+
import sqlite3
6915+
except ModuleNotFoundError as e:
6916+
self.import_failed(e)
6917+
69106918
def save_record(
69116919
self, dataset: DataSet, changed_row: dict, where_clause: str = None
69126920
) -> ResultSet:
@@ -6959,9 +6967,11 @@ class Mysql(SQLDriver):
69596967
def __init__(
69606968
self, host, user, password, database, sql_script=None, sql_commands=None
69616969
):
6962-
super().__init__(name="MySQL")
6970+
super().__init__(name="MySQL", requires=["mysql-connector-python"])
69636971

6964-
self.name = "MySQL"
6972+
self.import_required_modules()
6973+
6974+
self.name = "MySQL" # is this redundant?
69656975
self.host = host
69666976
self.user = user
69676977
self.password = password
@@ -6982,6 +6992,13 @@ def __init__(
69826992

69836993
self.win_pb.close()
69846994

6995+
def import_required_modules(self):
6996+
global mysql # noqa PLW0603
6997+
try:
6998+
import mysql.connector
6999+
except ModuleNotFoundError as e:
7000+
self.import_failed(e)
7001+
69857002
def connect(self, retries=3):
69867003
attempt = 0
69877004
while attempt < retries:
@@ -7158,7 +7175,11 @@ def __init__(
71587175
sql_commands=None,
71597176
sync_sequences=False,
71607177
):
7161-
super().__init__(name="Postgres", table_quote='"')
7178+
super().__init__(
7179+
name="Postgres", requires=["psycopg2", "psycopg2.extras"], table_quote='"'
7180+
)
7181+
7182+
self.import_required_modules()
71627183

71637184
self.host = host
71647185
self.user = user
@@ -7213,6 +7234,14 @@ def __init__(
72137234
self.execute_script(sql_script)
72147235
self.win_pb.close()
72157236

7237+
def import_required_modules(self):
7238+
global psycopg2 # noqa PLW0603
7239+
try:
7240+
import psycopg2
7241+
import psycopg2.extras
7242+
except ModuleNotFoundError as e:
7243+
self.import_failed(e)
7244+
72167245
def connect(self, retries=3):
72177246
attempt = 0
72187247
while attempt < retries:
@@ -7413,9 +7442,13 @@ class Sqlserver(SQLDriver):
74137442
def __init__(
74147443
self, host, user, password, database, sql_script=None, sql_commands=None
74157444
):
7416-
super().__init__(name="Sqlserver", table_quote='"', placeholder="?")
7445+
super().__init__(
7446+
name="Sqlserver", requires=["pyodbc"], table_quote='"', placeholder="?"
7447+
)
74177448

7418-
self.name = "Sqlserver"
7449+
self.import_required_modules()
7450+
7451+
self.name = "Sqlserver" # is this redundant?
74197452
self.host = host
74207453
self.user = user
74217454
self.password = password
@@ -7435,6 +7468,13 @@ def __init__(
74357468

74367469
self.win_pb.close()
74377470

7471+
def import_required_modules(self):
7472+
global pyodbc # noqa PLW0603
7473+
try:
7474+
import pyodbc
7475+
except ModuleNotFoundError as e:
7476+
self.import_failed(e)
7477+
74387478
def connect(self, retries=3, timeout=3):
74397479
attempt = 0
74407480
while attempt < retries:
@@ -7598,13 +7638,25 @@ class MSAccess(SQLDriver):
75987638
"""
75997639

76007640
def __init__(self, database_file):
7601-
super().__init__(name="MSAccess", table_quote="[]", placeholder="?")
7641+
super().__init__(
7642+
name="MSAccess", requires=["Jype1"], table_quote="[]", placeholder="?"
7643+
)
7644+
7645+
self.import_required_modules()
7646+
76027647
self.database_file = database_file
76037648
self.con = self.connect()
76047649

76057650
import os
76067651
import sys
76077652

7653+
def import_required_modules(self):
7654+
global jpype # noqa PLW0603
7655+
try:
7656+
import jpype # pip install JPype1
7657+
except ModuleNotFoundError as e:
7658+
self.import_failed(e)
7659+
76087660
def connect(self):
76097661
# Get the path to the 'lib' folder
76107662
current_path = os.path.dirname(os.path.abspath(__file__))

0 commit comments

Comments
 (0)