I have the following model_fn, in which I'm trying to load two .pt checkpoint files with torch.load. SageMaker loads the first file, but it fails on the second one. I've also tried deploying only the second file, and that fails too. The model seems to load when it is downloaded from the source, but when I download it from the source myself and have the endpoint fetch it from S3, it fails again. Both files are in model_dir.
Here is my model_fn:
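```python
import os

import torch
from doctr.models import crnn_vgg16_bn, db_resnet50, ocr_predictor


def model_fn(model_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    print("Model Loading...")

    # Detection model: build the architecture without downloading weights,
    # then load the local checkpoint from model_dir.
    det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
    det_model_path = os.path.join(model_dir, "db_resnet50-ac60cadc.pt")
    det_params = torch.load(det_model_path, map_location=device)
    det_model.load_state_dict(det_params)

    print("loading second file")
    # Recognition model: this is the step that fails.
    # pretrained=True would download the published weights at construction time,
    # and the result must be bound to reco_model (it is used below), not model.
    reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
    rec_model_path = os.path.join(model_dir, "crnn_vgg16_bn-9762b0b0.pt")
    reco_params = torch.load(rec_model_path, map_location=device)
    reco_model.load_state_dict(reco_params)

    # Wrap the two loaded models in a single end-to-end OCR predictor.
    model = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False)
    model.to(device=device)
    print("model_loaded")
    return model
```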
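When a checkpoint loads locally but not from S3, it can help to confirm what was actually extracted into model_dir before calling torch.load. The following is a minimal sketch under stated assumptions: the helper name locate_checkpoint is hypothetical, and the two failure modes it checks for (a file nested in a subfolder of the archive, or a small Git LFS pointer file saved in place of the real weights) are guesses at common causes, not a confirmed diagnosis of this issue.

```python
import os


def locate_checkpoint(model_dir, filename):
    """Hypothetical helper: find `filename` anywhere under model_dir.

    Raises FileNotFoundError listing what was actually extracted, and
    ValueError if the file looks like a Git LFS pointer instead of weights.
    """
    for root, _dirs, files in os.walk(model_dir):
        if filename in files:
            path = os.path.join(root, filename)
            with open(path, "rb") as f:
                head = f.read(64)
            # Assumption: a Git LFS pointer is a short text file starting with
            # this prefix, so torch.load would fail on it.
            if head.startswith(b"version https://git-lfs"):
                raise ValueError(f"{path} is a Git LFS pointer, not model weights")
            return path
    # Not found: report every extracted file, relative to model_dir.
    found = sorted(
        os.path.relpath(os.path.join(root, name), model_dir)
        for root, _dirs, files in os.walk(model_dir)
        for name in files
    )
    raise FileNotFoundError(f"{filename} not found in {model_dir}; contents: {found}")
```

Inside model_fn, `locate_checkpoint(model_dir, "crnn_vgg16_bn-9762b0b0.pt")` could stand in for the plain `os.path.join(...)` call, so the endpoint logs show the directory contents if the file is missing or nested.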
Can you try the approach demonstrated in this notebook? Heads up: it uses async inference and does not use HuggingFaceModel. However, you can use a similar approach for real-time inference and still download the models from HuggingFace to deploy them to SageMaker.
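Whichever way the weights are downloaded, SageMaker extracts model.tar.gz into model_dir, so both checkpoints need to sit at the root of the archive rather than inside a subfolder. A minimal sketch, assuming the two .pt files have already been downloaded locally; the function name package_model and the paths are hypothetical.

```python
import os
import tarfile


def package_model(checkpoint_paths, output_path="model.tar.gz"):
    """Hypothetical helper: bundle checkpoints into a model.tar.gz.

    Each file is added with arcname=basename, i.e. at the archive root,
    so after extraction it sits directly in model_dir.
    """
    with tarfile.open(output_path, "w:gz") as tar:
        for path in checkpoint_paths:
            if not os.path.isfile(path):
                raise FileNotFoundError(path)
            tar.add(path, arcname=os.path.basename(path))
    return output_path
```

Running `tar -tzf model.tar.gz` on the result should list both .pt files with no leading directory; the archive can then be uploaded to S3 and passed as the model's `model_data`.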