Commit 242d2e2

Stephen Hoover authored
ENH Register models trained outside of Civis Platform (civisanalytics#242)
If you train a scikit-learn compatible estimator outside of Civis Platform, you can use this to upload it to Civis Platform and prepare it for scoring with CivisML. There's a new Custom Script which will introspect metadata necessary for CivisML and make itself appear sufficiently like a CivisML training job that it can be used as input to a scoring job.
1 parent ee2947f commit 242d2e2

File tree

3 files changed

+187
-8
lines changed

civis/ml/_model.py

+150-3
@@ -42,8 +42,11 @@
4242
9112: 9113, # v1.1
4343
8387: 9113, # v1.0
4444
7020: 7021, # v0.5
45+
11028: 10616, # v2.2 registration CHANGE ME
4546
}
4647
_CIVISML_TEMPLATE = None # CivisML training template to use
48+
REGISTRATION_TEMPLATES = [11028, # v2.2 CHANGE ME
49+
]
4750

4851

4952
class ModelError(RuntimeError):
@@ -631,10 +634,10 @@ class ModelPipeline:
631634
See :func:`~civis.resources._resources.Scripts.post_custom` for
632635
further documentation about email and URL notification.
633636
dependencies : array, optional
634-
List of packages to install from PyPI or git repository (i.e., Github
637+
List of packages to install from PyPI or git repository (e.g., Github
635638
or Bitbucket). If a private repo is specified, please include a
636639
``git_token_name`` argument as well (see below). Make sure to pin
637-
dependencies to a specific version, since dependecies will be
640+
dependencies to a specific version, since dependencies will be
638641
reinstalled during every training and predict job.
639642
git_token_name : str, optional
640643
Name of remote git API token stored in Civis Platform as the password
@@ -713,6 +716,8 @@ def _get_template_ids(self, client):
713716
global _CIVISML_TEMPLATE
714717
if _CIVISML_TEMPLATE is None:
715718
for t_id in sorted(_PRED_TEMPLATES)[::-1]:
719+
if t_id in REGISTRATION_TEMPLATES:
720+
continue
716721
try:
717722
# Check that we can access the template
718723
client.templates.get_scripts(id=t_id)
@@ -783,6 +788,147 @@ def __setstate__(self, state):
783788
template_ids = self._get_template_ids(self._client)
784789
self.train_template_id, self.predict_template_id = template_ids
785790

791+
@classmethod
792+
def register_pretrained_model(cls, model, dependent_variable=None,
793+
features=None, primary_key=None,
794+
model_name=None, dependencies=None,
795+
git_token_name=None,
796+
skip_model_check=False, verbose=False,
797+
client=None):
798+
"""Use a fitted scikit-learn model with CivisML scoring
799+
800+
Use this function to set up your own fitted scikit-learn-compatible
801+
Estimator object for scoring with CivisML. This function will
802+
upload your model to Civis Platform and store enough metadata
803+
about it that you can subsequently use it with a CivisML scoring job.
804+
805+
The only required input is the model itself, but you are strongly
806+
recommended to also provide a list of feature names. Without a list
807+
of feature names, CivisML will have to assume that your scoring
808+
table contains only the features needed for scoring (perhaps also
809+
with a primary key column), all in the correct order.
810+
811+
Parameters
812+
----------
813+
model : sklearn.base.BaseEstimator or int
814+
The model object. This must be a fitted scikit-learn compatible
815+
Estimator object, or else the integer Civis File ID of a
816+
pickle or joblib-serialized file which stores such an object.
817+
dependent_variable : string or List[str], optional
818+
The dependent variable of the training dataset.
819+
For a multi-target problem, this should be a list of
820+
column names of dependent variables.
821+
features : string or List[str], optional
822+
A list of column names of features which were used for training.
823+
These will be used to ensure that tables input for prediction
824+
have the correct features in the correct order.
825+
primary_key : string, optional
826+
The unique ID (primary key) of the scoring dataset.
827+
model_name : string, optional
828+
The name of the Platform registration job. It will have
829+
" Predict" added to become the Script title for predictions.
830+
dependencies : array, optional
831+
List of packages to install from PyPI or git repository (e.g.,
832+
GitHub or Bitbucket). If a private repo is specified, please
833+
include a ``git_token_name`` argument as well (see below).
834+
Make sure to pin dependencies to a specific version, since
835+
dependencies will be reinstalled during every predict job.
836+
git_token_name : str, optional
837+
Name of remote git API token stored in Civis Platform as
838+
the password field in a custom platform credential.
839+
Used only when installing private git repositories.
840+
skip_model_check : bool, optional
841+
If you're sure that your model will work with CivisML, but it
842+
will fail the comprehensive verification, set this to True.
843+
verbose : bool, optional
844+
If True, supply debug outputs in Platform logs and make
845+
prediction child jobs visible.
846+
client : :class:`~civis.APIClient`, optional
847+
If not provided, an :class:`~civis.APIClient` object will be
848+
created from the :envvar:`CIVIS_API_KEY`.
849+
850+
Returns
851+
-------
852+
:class:`~civis.ml.ModelPipeline`
853+
854+
Examples
855+
--------
856+
This example assumes that you already have training data
857+
``X`` and ``y``, where ``X`` is a :class:`~pandas.DataFrame`.
858+
>>> from civis.ml import ModelPipeline
859+
>>> from sklearn.linear_model import Lasso
860+
>>> est = Lasso().fit(X, y)
861+
>>> model = ModelPipeline.register_pretrained_model(
862+
... est, 'concrete', features=X.columns)
863+
>>> model.predict(table_name='my.table', database_name='my-db')
864+
"""
865+
client = client or APIClient()
866+
867+
if isinstance(dependent_variable, six.string_types):
868+
dependent_variable = [dependent_variable]
869+
if isinstance(features, six.string_types):
870+
features = [features]
871+
if isinstance(dependencies, six.string_types):
872+
dependencies = [dependencies]
873+
if not model_name:
874+
model_name = ("Pretrained {} model for "
875+
"CivisML".format(model.__class__.__name__))
876+
model_name = model_name[:255] # Max size is 255 characters
877+
878+
if isinstance(model, (int, float, six.string_types)):
879+
model_file_id = int(model)
880+
else:
881+
tempdir = tempfile.mkdtemp()
882+
try:
883+
fout = os.path.join(tempdir, 'model_for_civisml.pkl')
884+
joblib.dump(model, fout, compress=3)
885+
with open(fout, 'rb') as _fout:
886+
# NB: Using the name "estimator.pkl" means that
887+
# CivisML doesn't need to copy this input to a file
888+
# with a different name.
889+
model_file_id = cio.file_to_civis(_fout, 'estimator.pkl',
890+
client=client)
891+
finally:
892+
shutil.rmtree(tempdir)
893+
894+
args = {'MODEL_FILE_ID': str(model_file_id),
895+
'SKIP_MODEL_CHECK': skip_model_check,
896+
'DEBUG': verbose}
897+
if dependent_variable is not None:
898+
args['TARGET_COLUMN'] = ' '.join(dependent_variable)
899+
if features is not None:
900+
args['FEATURE_COLUMNS'] = ' '.join(features)
901+
if dependencies is not None:
902+
args['DEPENDENCIES'] = ' '.join(dependencies)
903+
if git_token_name:
904+
creds = find(client.credentials.list(),
905+
name=git_token_name,
906+
type='Custom')
907+
if len(creds) != 1:
908+
raise ValueError("Unique credential with name '{}' for "
909+
"remote git hosting service not found!"
910+
.format(git_token_name))
911+
args['GIT_CRED'] = creds[0].id
912+
913+
template_id = max(REGISTRATION_TEMPLATES)
914+
container = client.scripts.post_custom(
915+
from_template_id=template_id,
916+
name=model_name,
917+
arguments=args)
918+
log.info('Created custom script %s.', container.id)
919+
920+
run = client.scripts.post_custom_runs(container.id)
921+
log.debug('Started job %s, run %s.', container.id, run.id)
922+
923+
fut = ModelFuture(container.id, run.id, client=client,
924+
poll_on_creation=False)
925+
fut.result()
926+
log.info('Model registration complete.')
927+
928+
mp = ModelPipeline.from_existing(fut.job_id, fut.run_id, client)
929+
mp.primary_key = primary_key
930+
return mp
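
The upload step in the method above accepts either a fitted estimator or an existing file ID. The following is a minimal sketch of that dispatch, not the library's implementation. It assumes Python 3, uses `pickle` from the standard library in place of `joblib`, and takes a hypothetical `upload` callable that stands in for `cio.file_to_civis`. It also creates the temporary directory before entering `try`, so that the cleanup in `finally` never refers to an undefined name.

```python
import os
import pickle
import shutil
import tempfile


def resolve_model_file_id(model, upload):
    """Return a file ID for ``model``, uploading it first if necessary.

    ``upload`` is a hypothetical callable taking (binary file object, name)
    and returning an integer file ID; it is a stand-in for a real uploader.
    """
    if isinstance(model, (int, str)):
        # The caller already has a file ID (possibly given as a string).
        return int(model)

    # Create the directory outside ``try`` so ``finally`` can always
    # remove it.
    tempdir = tempfile.mkdtemp()
    try:
        path = os.path.join(tempdir, 'model_for_civisml.pkl')
        with open(path, 'wb') as fout:
            pickle.dump(model, fout)  # the diff uses joblib.dump here
        with open(path, 'rb') as fin:
            # The diff uploads under the fixed name "estimator.pkl".
            return upload(fin, 'estimator.pkl')
    finally:
        shutil.rmtree(tempdir)


# Illustration with a fake uploader that records what it receives.
uploaded = {}


def fake_upload(fobj, name):
    uploaded[name] = fobj.read()
    return 42  # hypothetical file ID


existing_id = resolve_model_file_id('17', fake_upload)
new_id = resolve_model_file_id({'coef': [1.0, 2.0]}, fake_upload)
```

Passing the uploader in as an argument keeps the sketch runnable without network access; the method in the diff calls the Civis API directly instead.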
931+
786932
@classmethod
787933
def from_existing(cls, train_job_id, train_run_id='latest', client=None):
788934
"""Create a :class:`ModelPipeline` object from existing model IDs
@@ -887,7 +1033,8 @@ def from_existing(cls, train_job_id, train_run_id='latest', client=None):
8871033
'prediction code. Prediction will either fail '
8881034
'immediately or succeed.'
8891035
% (train_job_id, __version__), RuntimeWarning)
890-
p_id = max(_PRED_TEMPLATES.values())
1036+
p_id = max([v for k, v in _PRED_TEMPLATES.items()
1037+
if k not in REGISTRATION_TEMPLATES])
8911038
klass.predict_template_id = p_id
8921039

8931040
return klass
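
The two filtering changes in this file keep registration templates out of the normal template selection. A minimal sketch of that logic follows, using only the template IDs visible in the first hunk; the real `_PRED_TEMPLATES` mapping has further entries that the diff does not show, and the helper function names here are hypothetical.

```python
# Template IDs copied from the visible part of the diff. The real mapping
# contains more entries that the hunk does not show.
PRED_TEMPLATES = {
    9112: 9113,    # v1.1
    8387: 9113,    # v1.0
    7020: 7021,    # v0.5
    11028: 10616,  # v2.2 registration
}
REGISTRATION_TEMPLATES = [11028]


def training_candidates(pred_templates, registration_templates):
    """Return training template IDs, newest first, skipping registration."""
    return [t_id for t_id in sorted(pred_templates, reverse=True)
            if t_id not in registration_templates]


def fallback_predict_template(pred_templates, registration_templates):
    """Return the newest prediction template paired with a training template."""
    return max(v for k, v in pred_templates.items()
               if k not in registration_templates)


candidates = training_candidates(PRED_TEMPLATES, REGISTRATION_TEMPLATES)
fallback = fallback_predict_template(PRED_TEMPLATES, REGISTRATION_TEMPLATES)
unfiltered = max(PRED_TEMPLATES.values())
```

With these entries, `unfiltered` is the prediction template paired with the registration script, which is why the plain `max(_PRED_TEMPLATES.values())` call had to be replaced once a registration template was added to the mapping.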

civis/ml/tests/test_model.py

+8-5
@@ -39,6 +39,10 @@
3939
from civis.ml import _model
4040

4141

42+
LATEST_TRAIN_TEMPLATE = 10582
43+
LATEST_PRED_TEMPLATE = 10583
44+
45+
4246
def setup_client_mock(script_id=-10, run_id=100, state='succeeded',
4347
run_outputs=None):
4448
"""Return a Mock set up for use in testing container scripts
@@ -682,7 +686,7 @@ def test_modelpipeline_init_newest():
682686
mp = _model.ModelPipeline(LogisticRegression(), 'test', etl=etl,
683687
client=mock_client)
684688
assert mp.etl == etl
685-
assert mp.train_template_id == max(_model._PRED_TEMPLATES)
689+
assert mp.train_template_id == LATEST_TRAIN_TEMPLATE
686690
# clean up
687691
_model._CIVISML_TEMPLATE = None
688692

@@ -787,16 +791,15 @@ def test_modelpipeline_classmethod_constructor_defaults(
787791
def test_modelpipeline_classmethod_constructor_future_train_version():
788792
# Test handling attempts to restore a model created with a newer
789793
# version of CivisML.
790-
current_max_template = max(_model._PRED_TEMPLATES)
791-
cont = container_response_stub(current_max_template + 1000)
794+
cont = container_response_stub(LATEST_TRAIN_TEMPLATE + 1000)
792795
mock_client = mock.Mock()
793796
mock_client.scripts.get_containers.return_value = cont
794797
mock_client.credentials.get.return_value = Response({'name': 'Token'})
795798

796799
# test everything is working fine
797800
with pytest.warns(RuntimeWarning):
798801
mp = _model.ModelPipeline.from_existing(1, 1, client=mock_client)
799-
exp_p_id = _model._PRED_TEMPLATES[current_max_template]
802+
exp_p_id = _model._PRED_TEMPLATES[LATEST_TRAIN_TEMPLATE]
800803
assert mp.predict_template_id == exp_p_id
801804

802805

@@ -892,7 +895,7 @@ def test_modelpipeline_train_df(mock_ccr, mock_stash, mp_setup):
892895
train_data = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
893896
assert 'res' == mp.train(train_data)
894897
mock_stash.assert_called_once_with(
895-
train_data, max(_model._PRED_TEMPLATES.keys()), client=mock.ANY)
898+
train_data, LATEST_TRAIN_TEMPLATE, client=mock.ANY)
896899
assert mp.train_result_ == 'res'
897900

898901

docs/source/ml.rst

+29
@@ -68,6 +68,9 @@ or by providing your own scikit-learn
6868
Note that whichever option you choose, CivisML will pre-process your
6969
data using either its default ETL, or ETL that you provide (see :ref:`custom-etl`).
7070

71+
If you have already trained a scikit-learn model outside of Civis Platform,
72+
you can register it with Civis Platform as a CivisML model so that you can
73+
make predictions with it using CivisML. See :ref:`model-registration` for details.
7174

7275
Pre-Defined Models
7376
------------------
@@ -359,6 +362,32 @@ for solving a problem. For example:
359362
train = [model.train(table_name='schema.name', database_name='My DB') for model in models]
360363
aucs = [tr.metrics['roc_auc'] for tr in train] # Code blocks here
361364
365+
.. _model-registration:
366+
367+
Registering Models Trained Outside of Civis
368+
===========================================
369+
370+
Instead of using CivisML to train your model, you may train any
371+
scikit-learn-compatible model outside of Civis Platform and use
372+
:meth:`civis.ml.ModelPipeline.register_pretrained_model` to register it
373+
as a CivisML model in Civis Platform. This will let you use Civis Platform
374+
to make predictions using your model, either to take advantage of distributed
375+
predictions on large datasets, or to create predictions as part of
376+
a workflow or service in Civis Platform.
377+
378+
When registering a model trained outside of Civis Platform, you are
379+
strongly advised to provide an ordered list of feature names used
380+
for training. This will allow CivisML to ensure that tables of data
381+
input for predictions have the correct features in the correct order.
382+
If your model has more than one output, you should also provide a list
383+
of output names so that CivisML knows how many outputs to expect and
384+
how to name them in the resulting table of model predictions.
385+
386+
If your model uses dependencies which aren't part of the default CivisML
387+
execution environment, you must provide them to the ``dependencies``
388+
parameter of the :meth:`~civis.ml.ModelPipeline.register_pretrained_model`
389+
function, just as with the :class:`~civis.ml.ModelPipeline` constructor.
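
To make the role of the feature and output names concrete, here is a minimal sketch of how the method in this commit turns them into arguments for the registration script. It mirrors the argument-building code in the diff but is not the library's code: it assumes Python 3 (plain `str` instead of `six.string_types`), the function name `build_registration_args` is hypothetical, and the column names in the usage lines are made up for illustration.

```python
def build_registration_args(model_file_id, dependent_variable=None,
                            features=None, dependencies=None,
                            skip_model_check=False, verbose=False):
    """Assemble the argument dict passed to the registration script."""

    def as_list(value):
        # A single name may be given as a bare string.
        if isinstance(value, str):
            return [value]
        return None if value is None else list(value)

    args = {'MODEL_FILE_ID': str(model_file_id),
            'SKIP_MODEL_CHECK': skip_model_check,
            'DEBUG': verbose}
    optional = {'TARGET_COLUMN': as_list(dependent_variable),
                'FEATURE_COLUMNS': as_list(features),
                'DEPENDENCIES': as_list(dependencies)}
    for key, values in optional.items():
        if values is not None:
            # Names are passed as one space-separated string.
            args[key] = ' '.join(values)
    return args


# Hypothetical file ID and column names, for illustration only.
args = build_registration_args(123, 'concrete',
                               features=['cement', 'water', 'age'])
minimal = build_registration_args(123)
```

Because the names are joined with spaces, a column name that itself contains a space could not be told apart from two separate names once joined, so this sketch is only unambiguous for names without spaces.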
390+
362391

363392
Object reference
364393
================
