Index error in pose_utils.py save_image function on my own dataset #32

Open
2017494 opened this issue Nov 3, 2020 · 1 comment

2017494 commented Nov 3, 2020

I started training on my own dataset and hit a size mismatch error while concatenating in the save_image function. The cloth parse PNG has only 1 channel, so it is repeated to 3 channels: (batchsize, 1, 256, 192) => (batchsize, 3, 256, 192).
However, the channel axis is not then moved to the end. PyTorch tensors are channels-first, but the NumPy/PIL saving code expects channels-last, like the other images being concatenated. The concatenation therefore fails with a size mismatch, for example concat([batchsize, 256, 192, 3], [batchsize, 3, 256, 192]).

I added a permute call myself, but the error should not occur in the first place.
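A minimal NumPy sketch of the mismatch described above, assuming a batch of 2 and the 256x192 size from the issue. NumPy's repeat and transpose stand in here for the tensor's repeat and permute; the helper side_by_side is hypothetical and only mimics the per-sample concatenation along the width axis.

```python
import numpy as np

# Hypothetical shapes matching the issue: batch of 2, 256x192 images.
N, H, W = 2, 256, 192

# An RGB image already converted to channels-last (N, H, W, 3).
rgb = np.zeros((N, H, W, 3), dtype=np.uint8)

# A single-channel cloth-parse map in channels-first layout (N, 1, H, W).
parse = np.zeros((N, 1, H, W), dtype=np.uint8)

# Repeating the channel alone gives (N, 3, H, W): still channels-first.
parse_rgb = parse.repeat(3, axis=1)

def side_by_side(a, b):
    """Concatenate sample 0 of each batch along the width axis (axis=1)."""
    return np.concatenate([a[0], b[0]], axis=1)

try:
    side_by_side(rgb, parse_rgb)  # (256, 192, 3) vs (3, 256, 192)
    mismatch = False
except ValueError:
    mismatch = True

# Moving the channel axis to the end, (N, 3, H, W) -> (N, H, W, 3),
# is the NumPy equivalent of tensor.permute(0, 2, 3, 1).
parse_rgb_last = parse_rgb.transpose(0, 2, 3, 1)
row = side_by_side(rgb, parse_rgb_last)

print(mismatch)   # True
print(row.shape)  # (256, 384, 3)
```

Only the channel position differs between the two arrays, which is why a single permute after the repeat is enough to make the shapes line up.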

Code sample:

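import numpy as np
from PIL import Image

# all parsing_vis and mask cannot use tensor type
def save_img(images, path):
    img = []
    assert len(images) > 0

    # Convert every entry to a channels-last (N, H, W, 3) NumPy array.
    # .detach() lets tensors that require grad be converted to NumPy.
    for i in range(len(images)):
        if isinstance(images[i], np.ndarray):
            if images[i].shape[3] == 1:
                # (N, H, W, 1) -> (N, H, W, 3)
                images[i] = images[i].repeat(3, axis=3)
            elif images[i].shape[3] != 3:
                # treat as channels-first: (N, C, H, W) -> (N, H, W, C)
                images[i] = images[i].transpose((0, 2, 3, 1))
        else:
            if images[i].shape[1] == 1:
                # (N, 1, H, W) -> (N, 3, H, W)
                images[i] = images[i].repeat(1, 3, 1, 1)
                # line changed below: move channels last, (N, 3, H, W) -> (N, H, W, 3)
                images[i] = images[i].permute(0, 2, 3, 1).detach().cpu().numpy()
            elif images[i].shape[1] == 3:
                # channels last, then map [-1, 1] to [0, 255]
                images[i] = ((images[i].permute(0, 2, 3, 1).contiguous().detach().cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8)
            else:
                images[i] = images[i].detach().cpu().numpy()

    # For each sample, place all images side by side (axis=1 is width).
    for i in range(len(images[0])):
        img.append(np.concatenate([image[i] for image in images], axis=1))

    # Stack the per-sample rows vertically (axis=0 is height).
    img = np.concatenate(img, axis=0)

    image = Image.fromarray(img.astype(np.uint8))
    image.save(path)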
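The last two concatenations in save_img build one grid image: one row per sample and one column per input. A small NumPy sketch of the resulting shapes, using made-up inputs (three groups of 4 samples, each already channels-last) rather than real model outputs:

```python
import numpy as np

# Hypothetical inputs: three image groups (e.g. source, parse map, output),
# each already converted to channels-last uint8 with shape (N, H, W, 3).
# Group k is filled with the value k so its position in the grid is visible.
N, H, W = 4, 256, 192
images = [np.full((N, H, W, 3), k, dtype=np.uint8) for k in range(3)]

# 1) For each sample, put its images side by side (axis=1 is width).
rows = [np.concatenate([group[i] for group in images], axis=1) for i in range(N)]

# 2) Stack the per-sample rows on top of each other (axis=0 is height).
grid = np.concatenate(rows, axis=0)

print(rows[0].shape)  # (256, 576, 3)
print(grid.shape)     # (1024, 576, 3)
```

Step 1 only works if every group's per-sample slice has the same height and channel count, which is why each entry has to be converted to (N, H, W, 3) before this point.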

2017494 commented Nov 3, 2020

Could you also explain what this line does: images[i] = ((images[i].permute(0,2,3,1).contiguous().cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8)
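A NumPy sketch of what that line appears to do, assuming the 3-channel tensors hold values in [-1, 1] (for example a tanh output, or data normalized with mean 0.5 and std 0.5; this is an assumption, not confirmed by the repo). permute(0,2,3,1) moves channels last, contiguous() copies the permuted view into contiguous memory, and cpu() is needed because numpy() only works on CPU tensors. The arithmetic is then a linear rescale:

```python
import numpy as np

# Assumption: the network output lies in [-1, 1], laid out as (N, C, H, W).
out = np.array([-1.0, 0.0, 1.0], dtype=np.float32).reshape(1, 3, 1, 1)

# Step 1: (N, C, H, W) -> (N, H, W, C), the NumPy equivalent of
# tensor.permute(0, 2, 3, 1); PIL expects channels last.
hwc = out.transpose(0, 2, 3, 1)

# Step 2: x * 0.5 + 0.5 maps [-1, 1] to [0, 1]; * 255 maps to [0, 255].
# astype(np.uint8) then truncates the fractional part (127.5 -> 127).
pixels = ((hwc * 0.5 + 0.5) * 255).astype(np.uint8)

print(hwc.shape)                 # (1, 1, 1, 3)
print(pixels[0, 0, 0].tolist())  # [0, 127, 255]
```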
