Skip to content

Commit 47f5b96

Browse files
authored
Merge branch 'master' into async-slack-tests
2 parents faa4763 + 6abfc7c commit 47f5b96

File tree

7 files changed

+72
-61
lines changed

7 files changed

+72
-61
lines changed

changelog/4458.misc.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Rename ``kwargs`` to ``additional_arguments`` in ``rasa.train`` and ``rasa.core.train`` to make the name of the
2+
argument less confusing.

rasa/cli/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_core(args: argparse.Namespace) -> None:
8585
stories=stories,
8686
endpoints=endpoints,
8787
output=output,
88-
kwargs=vars(args),
88+
additional_arguments=vars(args),
8989
)
9090

9191
else:

rasa/cli/train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def train(args: argparse.Namespace) -> Optional[Text]:
7373
force_training=args.force,
7474
fixed_model_name=args.fixed_model_name,
7575
persist_nlu_training_data=args.persist_nlu_data,
76-
kwargs=extract_additional_arguments(args),
76+
additional_arguments=extract_additional_arguments(args),
7777
)
7878

7979

@@ -92,7 +92,7 @@ def train_core(
9292
story_file = get_validated_path(
9393
args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
9494
)
95-
kwargs = extract_additional_arguments(args)
95+
additional_arguments = extract_additional_arguments(args)
9696

9797
# Policies might be a list for the compare training. Do normal training
9898
# if only list item was passed.
@@ -109,12 +109,14 @@ def train_core(
109109
output=output,
110110
train_path=train_path,
111111
fixed_model_name=args.fixed_model_name,
112-
kwargs=kwargs,
112+
additional_arguments=additional_arguments,
113113
)
114114
else:
115115
from rasa.core.train import do_compare_training
116116

117-
loop.run_until_complete(do_compare_training(args, story_file, kwargs))
117+
loop.run_until_complete(
118+
do_compare_training(args, story_file, additional_arguments)
119+
)
118120

119121

120122
def train_nlu(

rasa/core/train.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ async def train(
2929
dump_stories: bool = False,
3030
policy_config: Optional[Union[Text, Dict]] = None,
3131
exclusion_percentage: int = None,
32-
kwargs: Optional[Dict] = None,
32+
additional_arguments: Optional[Dict] = None,
3333
):
3434
from rasa.core.agent import Agent
3535
from rasa.core import config, utils
@@ -38,8 +38,8 @@ async def train(
3838
if not endpoints:
3939
endpoints = AvailableEndpoints()
4040

41-
if not kwargs:
42-
kwargs = {}
41+
if not additional_arguments:
42+
additional_arguments = {}
4343

4444
policies = config.load(policy_config)
4545

@@ -51,8 +51,8 @@ async def train(
5151
policies=policies,
5252
)
5353

54-
data_load_args, kwargs = utils.extract_args(
55-
kwargs,
54+
data_load_args, additional_arguments = utils.extract_args(
55+
additional_arguments,
5656
{
5757
"use_story_concatenation",
5858
"unique_last_num_states",
@@ -64,7 +64,7 @@ async def train(
6464
training_data = await agent.load_data(
6565
training_resource, exclusion_percentage=exclusion_percentage, **data_load_args
6666
)
67-
agent.train(training_data, **kwargs)
67+
agent.train(training_data, **additional_arguments)
6868
agent.persist(output_path, dump_stories)
6969

7070
return agent
@@ -78,7 +78,7 @@ async def train_comparison_models(
7878
policy_configs: Optional[List] = None,
7979
runs: int = 1,
8080
dump_stories: bool = False,
81-
kwargs: Optional[Dict] = None,
81+
additional_arguments: Optional[Dict] = None,
8282
):
8383
"""Train multiple models for comparison of policies"""
8484
from rasa import model
@@ -114,7 +114,7 @@ async def train_comparison_models(
114114
train_path,
115115
policy_config=policy_config,
116116
exclusion_percentage=percentage,
117-
kwargs=kwargs,
117+
additional_arguments=additional_arguments,
118118
dump_stories=dump_stories,
119119
),
120120
model.model_fingerprint(file_importer),
@@ -148,14 +148,14 @@ async def do_compare_training(
148148
):
149149
_, no_stories = await asyncio.gather(
150150
train_comparison_models(
151-
story_file,
152-
args.domain,
153-
args.out,
154-
args.percentages,
155-
args.config,
156-
args.runs,
157-
args.dump_stories,
158-
additional_arguments,
151+
story_file=story_file,
152+
domain=args.domain,
153+
output_path=args.out,
154+
exclusion_percentages=args.percentages,
155+
policy_configs=args.config,
156+
runs=args.runs,
157+
dump_stories=args.dump_stories,
158+
additional_arguments=additional_arguments,
159159
),
160160
get_no_of_stories(args.stories, args.domain),
161161
)

rasa/test.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,21 @@ def test(
4040
nlu_data: Text,
4141
endpoints: Optional[Text] = None,
4242
output: Text = DEFAULT_RESULTS_PATH,
43-
kwargs: Optional[Dict] = None,
43+
additional_arguments: Optional[Dict] = None,
4444
):
45-
if kwargs is None:
46-
kwargs = {}
45+
if additional_arguments is None:
46+
additional_arguments = {}
4747

48-
test_core(model, stories, endpoints, output, kwargs)
49-
test_nlu(model, nlu_data, output, kwargs)
48+
test_core(model, stories, endpoints, output, additional_arguments)
49+
test_nlu(model, nlu_data, output, additional_arguments)
5050

5151

5252
def test_core(
5353
model: Optional[Text] = None,
5454
stories: Optional[Text] = None,
5555
endpoints: Optional[Text] = None,
5656
output: Text = DEFAULT_RESULTS_PATH,
57-
kwargs: Optional[Dict] = None,
57+
additional_arguments: Optional[Dict] = None,
5858
):
5959
import rasa.core.test
6060
import rasa.core.utils as core_utils
@@ -64,8 +64,8 @@ def test_core(
6464

6565
_endpoints = core_utils.AvailableEndpoints.read_endpoints(endpoints)
6666

67-
if kwargs is None:
68-
kwargs = {}
67+
if additional_arguments is None:
68+
additional_arguments = {}
6969

7070
if output:
7171
io_utils.create_directory(output)
@@ -87,7 +87,7 @@ def test_core(
8787
"Rasa model and provide it via the '--model' argument."
8888
)
8989

90-
use_e2e = kwargs["e2e"] if "e2e" in kwargs else False
90+
use_e2e = additional_arguments.get("e2e", False)
9191

9292
_interpreter = RegexInterpreter()
9393
if use_e2e:
@@ -101,7 +101,9 @@ def test_core(
101101

102102
_agent = Agent.load(unpacked_model, interpreter=_interpreter)
103103

104-
kwargs = utils.minimal_kwargs(kwargs, rasa.core.test, ["stories", "agent"])
104+
kwargs = utils.minimal_kwargs(
105+
additional_arguments, rasa.core.test, ["stories", "agent"]
106+
)
105107

106108
loop = asyncio.get_event_loop()
107109
loop.run_until_complete(
@@ -113,7 +115,7 @@ def test_nlu(
113115
model: Optional[Text],
114116
nlu_data: Optional[Text],
115117
output_directory: Text = DEFAULT_RESULTS_PATH,
116-
kwargs: Optional[Dict] = None,
118+
additional_arguments: Optional[Dict] = None,
117119
):
118120
from rasa.nlu.test import run_evaluation
119121
from rasa.model import get_model
@@ -132,7 +134,9 @@ def test_nlu(
132134
nlu_model = os.path.join(unpacked_model, "nlu")
133135

134136
if os.path.exists(nlu_model):
135-
kwargs = utils.minimal_kwargs(kwargs, run_evaluation, ["data_path", "model"])
137+
kwargs = utils.minimal_kwargs(
138+
additional_arguments, run_evaluation, ["data_path", "model"]
139+
)
136140
run_evaluation(nlu_data, nlu_model, output_directory=output_directory, **kwargs)
137141
else:
138142
print_error(
@@ -186,7 +190,10 @@ def compare_nlu_models(
186190

187191

188192
def perform_nlu_cross_validation(
189-
config: Text, nlu: Text, output: Text, kwargs: Optional[Dict[Text, Any]]
193+
config: Text,
194+
nlu: Text,
195+
output: Text,
196+
additional_arguments: Optional[Dict[Text, Any]],
190197
):
191198
import rasa.nlu.config
192199
from rasa.nlu.test import (
@@ -196,12 +203,12 @@ def perform_nlu_cross_validation(
196203
return_entity_results,
197204
)
198205

199-
kwargs = kwargs or {}
200-
folds = int(kwargs.get("folds", 3))
206+
additional_arguments = additional_arguments or {}
207+
folds = int(additional_arguments.get("folds", 3))
201208
nlu_config = rasa.nlu.config.load(config)
202209
data = rasa.nlu.training_data.load_data(nlu)
203210
data = drop_intents_below_freq(data, cutoff=folds)
204-
kwargs = utils.minimal_kwargs(kwargs, cross_validate)
211+
kwargs = utils.minimal_kwargs(additional_arguments, cross_validate)
205212
results, entity_results, response_selection_results = cross_validate(
206213
data, folds, nlu_config, output, **kwargs
207214
)

rasa/train.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def train(
2828
force_training: bool = False,
2929
fixed_model_name: Optional[Text] = None,
3030
persist_nlu_training_data: bool = False,
31-
kwargs: Optional[Dict] = None,
31+
additional_arguments: Optional[Dict] = None,
3232
loop: Optional[asyncio.AbstractEventLoop] = None,
3333
) -> Optional[Text]:
3434
if loop is None:
@@ -43,7 +43,7 @@ def train(
4343
force_training=force_training,
4444
fixed_model_name=fixed_model_name,
4545
persist_nlu_training_data=persist_nlu_training_data,
46-
kwargs=kwargs,
46+
additional_arguments=additional_arguments,
4747
)
4848
)
4949

@@ -56,7 +56,7 @@ async def train_async(
5656
force_training: bool = False,
5757
fixed_model_name: Optional[Text] = None,
5858
persist_nlu_training_data: bool = False,
59-
kwargs: Optional[Dict] = None,
59+
additional_arguments: Optional[Dict] = None,
6060
) -> Optional[Text]:
6161
"""Trains a Rasa model (Core and NLU).
6262
@@ -69,7 +69,7 @@ async def train_async(
6969
fixed_model_name: Name of model to be stored.
7070
persist_nlu_training_data: `True` if the NLU training data should be persisted
7171
with the model.
72-
kwargs: Additional training parameters.
72+
additional_arguments: Additional training parameters.
7373
7474
Returns:
7575
Path of the trained model archive.
@@ -94,7 +94,7 @@ async def train_async(
9494
force_training,
9595
fixed_model_name,
9696
persist_nlu_training_data,
97-
kwargs,
97+
additional_arguments,
9898
)
9999

100100

@@ -118,7 +118,7 @@ async def _train_async_internal(
118118
force_training: bool,
119119
fixed_model_name: Optional[Text],
120120
persist_nlu_training_data: bool,
121-
kwargs: Optional[Dict],
121+
additional_arguments: Optional[Dict],
122122
) -> Optional[Text]:
123123
"""Trains a Rasa model (Core and NLU). Use only from `train_async`.
124124
@@ -130,7 +130,7 @@ async def _train_async_internal(
130130
persist_nlu_training_data: `True` if the NLU training data should be persisted
131131
with the model.
132132
fixed_model_name: Name of model to be stored.
133-
kwargs: Additional training parameters.
133+
additional_arguments: Additional training parameters.
134134
135135
Returns:
136136
Path of the trained model archive.
@@ -162,7 +162,7 @@ async def _train_async_internal(
162162
file_importer,
163163
output=output_path,
164164
fixed_model_name=fixed_model_name,
165-
kwargs=kwargs,
165+
additional_arguments=additional_arguments,
166166
)
167167

168168
new_fingerprint = await model.model_fingerprint(file_importer)
@@ -181,7 +181,7 @@ async def _train_async_internal(
181181
fingerprint_comparison_result=fingerprint_comparison,
182182
fixed_model_name=fixed_model_name,
183183
persist_nlu_training_data=persist_nlu_training_data,
184-
kwargs=kwargs,
184+
additional_arguments=additional_arguments,
185185
)
186186

187187
return model.package_model(
@@ -205,7 +205,7 @@ async def _do_training(
205205
fingerprint_comparison_result: Optional[FingerprintComparisonResult] = None,
206206
fixed_model_name: Optional[Text] = None,
207207
persist_nlu_training_data: bool = False,
208-
kwargs: Optional[Dict] = None,
208+
additional_arguments: Optional[Dict] = None,
209209
):
210210
if not fingerprint_comparison_result:
211211
fingerprint_comparison_result = FingerprintComparisonResult()
@@ -216,7 +216,7 @@ async def _do_training(
216216
output=output_path,
217217
train_path=train_path,
218218
fixed_model_name=fixed_model_name,
219-
kwargs=kwargs,
219+
additional_arguments=additional_arguments,
220220
)
221221
elif fingerprint_comparison_result.should_retrain_nlg():
222222
print_color(
@@ -254,7 +254,7 @@ def train_core(
254254
output: Text,
255255
train_path: Optional[Text] = None,
256256
fixed_model_name: Optional[Text] = None,
257-
kwargs: Optional[Dict] = None,
257+
additional_arguments: Optional[Dict] = None,
258258
) -> Optional[Text]:
259259
loop = asyncio.get_event_loop()
260260
return loop.run_until_complete(
@@ -265,7 +265,7 @@ def train_core(
265265
output=output,
266266
train_path=train_path,
267267
fixed_model_name=fixed_model_name,
268-
kwargs=kwargs,
268+
additional_arguments=additional_arguments,
269269
)
270270
)
271271

@@ -277,7 +277,7 @@ async def train_core_async(
277277
output: Text,
278278
train_path: Optional[Text] = None,
279279
fixed_model_name: Optional[Text] = None,
280-
kwargs: Optional[Dict] = None,
280+
additional_arguments: Optional[Dict] = None,
281281
) -> Optional[Text]:
282282
"""Trains a Core model.
283283
@@ -290,7 +290,7 @@ async def train_core_async(
290290
directory, otherwise in the provided directory.
291291
fixed_model_name: Name of model to be stored.
292292
uncompress: If `True` the model will not be compressed.
293-
kwargs: Additional training parameters.
293+
additional_arguments: Additional training parameters.
294294
295295
Returns:
296296
If `train_path` is given it returns the path to the model archive,
@@ -321,7 +321,7 @@ async def train_core_async(
321321
output=output,
322322
train_path=train_path,
323323
fixed_model_name=fixed_model_name,
324-
kwargs=kwargs,
324+
additional_arguments=additional_arguments,
325325
)
326326

327327

@@ -330,7 +330,7 @@ async def _train_core_with_validated_data(
330330
output: Text,
331331
train_path: Optional[Text] = None,
332332
fixed_model_name: Optional[Text] = None,
333-
kwargs: Optional[Dict] = None,
333+
additional_arguments: Optional[Dict] = None,
334334
) -> Optional[Text]:
335335
"""Train Core with validated training and config data."""
336336

@@ -354,7 +354,7 @@ async def _train_core_with_validated_data(
354354
training_resource=file_importer,
355355
output_path=os.path.join(_train_path, DEFAULT_CORE_SUBDIRECTORY_NAME),
356356
policy_config=config,
357-
kwargs=kwargs,
357+
additional_arguments=additional_arguments,
358358
)
359359
print_color("Core model training completed.", color=bcolors.OKBLUE)
360360

0 commit comments

Comments
 (0)