|
42 | 42 | 9112: 9113, # v1.1
|
43 | 43 | 8387: 9113, # v1.0
|
44 | 44 | 7020: 7021, # v0.5
|
| 45 | + 11028: 10616, # v2.2 registration CHANGE ME |
45 | 46 | }
|
46 | 47 | _CIVISML_TEMPLATE = None # CivisML training template to use
|
| 48 | +REGISTRATION_TEMPLATES = [11028, # v2.2 CHANGE ME |
| 49 | + ] |
47 | 50 |
|
48 | 51 |
|
49 | 52 | class ModelError(RuntimeError):
|
@@ -631,10 +634,10 @@ class ModelPipeline:
|
631 | 634 | See :func:`~civis.resources._resources.Scripts.post_custom` for
|
632 | 635 | further documentation about email and URL notification.
|
633 | 636 | dependencies : array, optional
|
634 |
| - List of packages to install from PyPI or git repository (i.e., Github |
| 637 | + List of packages to install from PyPI or git repository (e.g., Github |
635 | 638 | or Bitbucket). If a private repo is specified, please include a
|
636 | 639 | ``git_token_name`` argument as well (see below). Make sure to pin
|
637 |
| - dependencies to a specific version, since dependecies will be |
| 640 | + dependencies to a specific version, since dependencies will be |
638 | 641 | reinstalled during every training and predict job.
|
639 | 642 | git_token_name : str, optional
|
640 | 643 | Name of remote git API token stored in Civis Platform as the password
|
@@ -713,6 +716,8 @@ def _get_template_ids(self, client):
|
713 | 716 | global _CIVISML_TEMPLATE
|
714 | 717 | if _CIVISML_TEMPLATE is None:
|
715 | 718 | for t_id in sorted(_PRED_TEMPLATES)[::-1]:
|
| 719 | + if t_id in REGISTRATION_TEMPLATES: |
| 720 | + continue |
716 | 721 | try:
|
717 | 722 | # Check that we can access the template
|
718 | 723 | client.templates.get_scripts(id=t_id)
|
@@ -783,6 +788,147 @@ def __setstate__(self, state):
|
783 | 788 | template_ids = self._get_template_ids(self._client)
|
784 | 789 | self.train_template_id, self.predict_template_id = template_ids
|
785 | 790 |
|
| 791 | + @classmethod |
| 792 | + def register_pretrained_model(cls, model, dependent_variable=None, |
| 793 | + features=None, primary_key=None, |
| 794 | + model_name=None, dependencies=None, |
| 795 | + git_token_name=None, |
| 796 | + skip_model_check=False, verbose=False, |
| 797 | + client=None): |
| 798 | + """Use a fitted scikit-learn model with CivisML scoring |
| 799 | +
|
| 800 | + Use this function to set up your own fitted scikit-learn-compatible |
| 801 | + Estimator object for scoring with CivisML. This function will |
| 802 | + upload your model to Civis Platform and store enough metadata |
| 803 | + about it that you can subsequently use it with a CivisML scoring job. |
| 804 | +
|
| 805 | + The only required input is the model itself, but you are strongly |
| 806 | + recommended to also provide a list of feature names. Without a list |
| 807 | + of feature names, CivisML will have to assume that your scoring |
| 808 | + table contains only the features needed for scoring (perhaps also |
| 809 | + with a primary key column), in all in the correct order. |
| 810 | +
|
| 811 | + Parameters |
| 812 | + ---------- |
| 813 | + model : sklearn.base.BaseEstimator or int |
| 814 | + The model object. This must be a fitted scikit-learn compatible |
| 815 | + Estimator object, or else the integer Civis File ID of a |
| 816 | + pickle or joblib-serialized file which stores such an object. |
| 817 | + dependent_variable : string or List[str], optional |
| 818 | + The dependent variable of the training dataset. |
| 819 | + For a multi-target problem, this should be a list of |
| 820 | + column names of dependent variables. |
| 821 | + features : string or List[str], optional |
| 822 | + A list of column names of features which were used for training. |
| 823 | + These will be used to ensure that tables input for prediction |
| 824 | + have the correct features in the correct order. |
| 825 | + primary_key : string, optional |
| 826 | + The unique ID (primary key) of the scoring dataset |
| 827 | + model_name : string, optional |
| 828 | + The name of the Platform registration job. It will have |
| 829 | + " Predict" added to become the Script title for predictions. |
| 830 | + dependencies : array, optional |
| 831 | + List of packages to install from PyPI or git repository (e.g., |
| 832 | + GitHub or Bitbucket). If a private repo is specified, please |
| 833 | + include a ``git_token_name`` argument as well (see below). |
| 834 | + Make sure to pin dependencies to a specific version, since |
| 835 | + dependencies will be reinstalled during every predict job. |
| 836 | + git_token_name : str, optional |
| 837 | + Name of remote git API token stored in Civis Platform as |
| 838 | + the password field in a custom platform credential. |
| 839 | + Used only when installing private git repositories. |
| 840 | + skip_model_check : bool, optional |
| 841 | + If you're sure that your model will work with CivisML, but it |
| 842 | + will fail the comprehensive verification, set this to True. |
| 843 | + verbose : bool, optional |
| 844 | + If True, supply debug outputs in Platform logs and make |
| 845 | + prediction child jobs visible. |
| 846 | + client : :class:`~civis.APIClient`, optional |
| 847 | + If not provided, an :class:`~civis.APIClient` object will be |
| 848 | + created from the :envvar:`CIVIS_API_KEY`. |
| 849 | +
|
| 850 | + Returns |
| 851 | + ------- |
| 852 | + :class:`~civis.ml.ModelPipeline` |
| 853 | +
|
| 854 | + Examples |
| 855 | + -------- |
| 856 | + This example assumes that you already have training data |
| 857 | + ``X`` and ``y``, where ``X`` is a :class:`~pandas.DataFrame`. |
| 858 | + >>> from civis.ml import ModelPipeline |
| 859 | + >>> from sklearn.linear_model import Lasso |
| 860 | + >>> est = Lasso().fit(X, y) |
| 861 | + >>> model = ModelPipeline.register_pretrained_model( |
| 862 | + ... est, 'concrete', features=X.columns) |
| 863 | + >>> model.predict(table_name='my.table', database_name='my-db') |
| 864 | + """ |
| 865 | + client = client or APIClient() |
| 866 | + |
| 867 | + if isinstance(dependent_variable, six.string_types): |
| 868 | + dependent_variable = [dependent_variable] |
| 869 | + if isinstance(features, six.string_types): |
| 870 | + features = [features] |
| 871 | + if isinstance(dependencies, six.string_types): |
| 872 | + dependencies = [dependencies] |
| 873 | + if not model_name: |
| 874 | + model_name = ("Pretrained {} model for " |
| 875 | + "CivisML".format(model.__class__.__name__)) |
| 876 | + model_name = model_name[:255] # Max size is 255 characters |
| 877 | + |
| 878 | + if isinstance(model, (int, float, six.string_types)): |
| 879 | + model_file_id = int(model) |
| 880 | + else: |
| 881 | + try: |
| 882 | + tempdir = tempfile.mkdtemp() |
| 883 | + fout = os.path.join(tempdir, 'model_for_civisml.pkl') |
| 884 | + joblib.dump(model, fout, compress=3) |
| 885 | + with open(fout, 'rb') as _fout: |
| 886 | + # NB: Using the name "estimator.pkl" means that |
| 887 | + # CivisML doesn't need to copy this input to a file |
| 888 | + # with a different name. |
| 889 | + model_file_id = cio.file_to_civis(_fout, 'estimator.pkl', |
| 890 | + client=client) |
| 891 | + finally: |
| 892 | + shutil.rmtree(tempdir) |
| 893 | + |
| 894 | + args = {'MODEL_FILE_ID': str(model_file_id), |
| 895 | + 'SKIP_MODEL_CHECK': skip_model_check, |
| 896 | + 'DEBUG': verbose} |
| 897 | + if dependent_variable is not None: |
| 898 | + args['TARGET_COLUMN'] = ' '.join(dependent_variable) |
| 899 | + if features is not None: |
| 900 | + args['FEATURE_COLUMNS'] = ' '.join(features) |
| 901 | + if dependencies is not None: |
| 902 | + args['DEPENDENCIES'] = ' '.join(dependencies) |
| 903 | + if git_token_name: |
| 904 | + creds = find(client.credentials.list(), |
| 905 | + name=git_token_name, |
| 906 | + type='Custom') |
| 907 | + if len(creds) > 1: |
| 908 | + raise ValueError("Unique credential with name '{}' for " |
| 909 | + "remote git hosting service not found!" |
| 910 | + .format(git_token_name)) |
| 911 | + args['GIT_CRED'] = creds[0].id |
| 912 | + |
| 913 | + template_id = max(REGISTRATION_TEMPLATES) |
| 914 | + container = client.scripts.post_custom( |
| 915 | + from_template_id=template_id, |
| 916 | + name=model_name, |
| 917 | + arguments=args) |
| 918 | + log.info('Created custom script %s.', container.id) |
| 919 | + |
| 920 | + run = client.scripts.post_custom_runs(container.id) |
| 921 | + log.debug('Started job %s, run %s.', container.id, run.id) |
| 922 | + |
| 923 | + fut = ModelFuture(container.id, run.id, client=client, |
| 924 | + poll_on_creation=False) |
| 925 | + fut.result() |
| 926 | + log.info('Model registration complete.') |
| 927 | + |
| 928 | + mp = ModelPipeline.from_existing(fut.job_id, fut.run_id, client) |
| 929 | + mp.primary_key = primary_key |
| 930 | + return mp |
| 931 | + |
786 | 932 | @classmethod
|
787 | 933 | def from_existing(cls, train_job_id, train_run_id='latest', client=None):
|
788 | 934 | """Create a :class:`ModelPipeline` object from existing model IDs
|
@@ -887,7 +1033,8 @@ def from_existing(cls, train_job_id, train_run_id='latest', client=None):
|
887 | 1033 | 'prediction code. Prediction will either fail '
|
888 | 1034 | 'immediately or succeed.'
|
889 | 1035 | % (train_job_id, __version__), RuntimeWarning)
|
890 |
| - p_id = max(_PRED_TEMPLATES.values()) |
| 1036 | + p_id = max([v for k, v in _PRED_TEMPLATES.items() |
| 1037 | + if k not in REGISTRATION_TEMPLATES]) |
891 | 1038 | klass.predict_template_id = p_id
|
892 | 1039 |
|
893 | 1040 | return klass
|
|
0 commit comments