Skip to content

Commit 2c16ae8

Browse files
author
Dan Choi
committed
Fix rcf default batch_size
1 parent 5df7840 commit 2c16ae8

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/sagemaker/amazon/randomcutforest.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,12 @@ def create_model(self):
8787

8888
return RandomCutForestModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
8989

90-
def _prepare_for_training(self, records, mini_batch_size=MINI_BATCH_SIZE, job_name=None):
91-
if mini_batch_size != self.MINI_BATCH_SIZE:
92-
raise ValueError("Random Cut Forest uses a fixed mini_batch_size of {}"
93-
.format(self.MINI_BATCH_SIZE))
90+
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
91+
if mini_batch_size is None:
92+
mini_batch_size = self.MINI_BATCH_SIZE
93+
elif mini_batch_size != self.MINI_BATCH_SIZE:
94+
raise ValueError("Random Cut Forest uses a fixed mini_batch_size of {}".format(self.MINI_BATCH_SIZE))
95+
9496
super(RandomCutForest, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
9597

9698

0 commit comments

Comments
 (0)