
Load model in pytorch c++ #406


Closed
fselka opened this issue May 26, 2021 · 0 comments


fselka commented May 26, 2021

Hi! I'm trying to load the models in C++ too, but I'm having the following error:

```
terminate called after throwing an instance of 'torch::jit::ErrorReport'
  what():
aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor):
Expected at most 12 arguments but found 13 positional arguments.

Serialized   File "code/__torch__/torch/nn/modules/conv.py", line 8
    def forward(self: __torch__.torch.nn.modules.conv.Conv2d,
    input: Tensor) -> Tensor:
      input0 = torch._convolution(input, self.weight, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1, False, False, True, True)
               ~~~~~~~~~~~~~~~~~~ <--- HERE
      return input0
```

I used the model trained in the cars example (Jupyter notebook):

```python
import torch
import segmentation_models_pytorch as smp

ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['car']
ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
```

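Then I traced the model and saved it with TorchScript as follows:

```python
import torch

input_img = torch.rand(1, 3, 256, 256).to(DEVICE)  # example input used for tracing
best_model = torch.load('./best_model.pth')  # whole model object saved during training
best_model.eval()
traced_script_module = torch.jit.trace(best_model, input_img)
traced_script_module.save('./best_model.pt')  # same result as torch.jit.save(traced_script_module, ...)
```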
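As a sanity check before moving to C++, the export can be round-tripped in Python: trace, save, reload, and compare against the eager model. If this passes but the C++ load still fails, the problem is more likely in the runtime than in the file. This is a minimal sketch only; `TinySegNet` and `trace_save_and_check` are hypothetical names, and the one-layer model (using the same stride and padding as the conv in the error trace) stands in for the FPN so it runs without downloading encoder weights.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for the smp.FPN model: one conv layer plus sigmoid."""

    def __init__(self) -> None:
        super().__init__()
        # Same stride/padding/no-bias setup as the conv shown in the error trace.
        self.conv = nn.Conv2d(3, 1, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.conv(x))


def trace_save_and_check(model: nn.Module, example: torch.Tensor, path: str) -> bool:
    """Trace `model`, save it to `path`, reload it, and compare with the eager output."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        torch.jit.save(traced, path)
        reloaded = torch.jit.load(path, map_location="cpu")
        expected = model(example)
        actual = reloaded(example)
    return torch.allclose(expected, actual, atol=1e-5)


if __name__ == "__main__":
    torch.manual_seed(0)
    example = torch.rand(1, 3, 256, 256)
    with tempfile.TemporaryDirectory() as tmp:
        ok = trace_save_and_check(TinySegNet(), example, os.path.join(tmp, "model.pt"))
    print("round-trip outputs match:", ok)
```

For the real model, `best_model` and `input_img` from the snippet above would take the place of `TinySegNet()` and `example` (moved to the same device).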

When I load the model in C++ with `torch::jit::load(model_path.string());`, it fails with the error above.
I'm using PyTorch 1.7.1.
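The serialized call passes 13 arguments to `torch._convolution`, while the schema reported by the loader accepts at most 12. That suggests (but does not prove) the file was written by a newer PyTorch than the LibTorch doing the loading. One way to check, sketched under assumptions: record `torch.__version__` inside the archive via the `_extra_files` argument of `torch.jit.save`, read it back, and compare it with the LibTorch release. `VERSION_KEY`, the helper names, and the hard-coded `libtorch_version` are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical name for the extra file stored inside the TorchScript archive.
VERSION_KEY = "exported_with_torch_version.txt"


def save_with_version(module: torch.jit.ScriptModule, path: str) -> None:
    """Save a traced/scripted module and record the exporting PyTorch version."""
    torch.jit.save(module, path, _extra_files={VERSION_KEY: str(torch.__version__)})


def read_export_version(path: str) -> str:
    """Read back the version string written by save_with_version."""
    extra_files = {VERSION_KEY: ""}  # torch.jit.load fills in the value
    torch.jit.load(path, map_location="cpu", _extra_files=extra_files)
    value = extra_files[VERSION_KEY]
    # Depending on the PyTorch release, the value comes back as bytes or str.
    return value.decode("utf-8") if isinstance(value, bytes) else value


def major_minor(version: str) -> tuple:
    """'1.7.1+cu110' -> (1, 7); local suffixes such as '+cu110' are dropped."""
    parts = version.split("+")[0].split(".")
    return int(parts[0]), int(parts[1])


if __name__ == "__main__":
    traced = torch.jit.trace(nn.Conv2d(3, 1, 3).eval(), torch.rand(1, 3, 32, 32))
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        save_with_version(traced, path)
        exported = read_export_version(path)
    libtorch_version = "1.7.1"  # assumption: the LibTorch release linked by the C++ app
    print("exported with PyTorch", exported)
    if major_minor(exported) > major_minor(libtorch_version):
        print("archive is newer than the LibTorch runtime")
```

If the exporting version turns out to be newer, the likely remedies are to export with the same PyTorch release as the LibTorch used in C++, or to upgrade LibTorch to match. LibTorch's `torch::jit::load` also has an overload taking an `ExtraFilesMap`, which should allow reading the same entry from C++.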
