abhanupr
Hi,
I think the issue is with the input shape and the axis you’re concatenating on (in pytorch_concat.py):
- The input shape for the dummy tensor should be (1, 3, 400, 640), i.e. (batch size, channels, height, width):
X = torch.ones((1, 3, 400, 640), dtype=torch.float16)
- You’re concatenating on the wrong axis. Use dim 3 (width) instead of dim 1 (channels):
return torch.cat((img1, img2), dim=3)
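Here is a minimal sketch of the two fixes together. It assumes the model in pytorch_concat.py just concatenates two NCHW image tensors. `ConcatWidth` is a hypothetical name, not the original class. The sketch prints the resulting shapes so you can compare dim 3 with dim 1:

```python
import torch
import torch.nn as nn


class ConcatWidth(nn.Module):
    """Hypothetical stand-in for the model in pytorch_concat.py."""

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        # NCHW layout: dim 0 = batch, 1 = channels, 2 = height, 3 = width.
        # Concatenating on dim 3 places the two images side by side.
        return torch.cat((img1, img2), dim=3)


# Dummy input with the corrected (batch, channels, height, width) shape.
X = torch.ones((1, 3, 400, 640), dtype=torch.float16)
model = ConcatWidth().eval()

with torch.no_grad():
    out = model(X, X)

print(out.shape)  # torch.Size([1, 3, 400, 1280])

# For comparison, dim 1 stacks along channels instead of width.
print(torch.cat((X, X), dim=1).shape)  # torch.Size([1, 6, 400, 640])
```

With dim 3 the width doubles (640 → 1280) and the channel count stays at 3. With dim 1 you get a 6-channel tensor, which is usually not what a downstream 3-channel consumer expects.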
This should fix the problem!