Why is my image tensor output all zeros after processing with Spatial Transformer Network Pytorch?

I am working on a small project involving a Spatial Transformer Network (STN) to process images. I accidentally uploaded a branch with untested code, and now I'm facing an issue where my image tensor output is all zeros.

Here's the relevant part of my code:

import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path
from model import STModel
from typing import Union
import numpy as np

class STN:
    """
    Class to handle the processing of a single image using a Spatial Transformer Network (STN).

    Args:
        pretrained (Path): Path to the pre-trained model.
    """

    def __init__(self, pretrained: Union[str, Path]) -> None:
        self.pretrained: Path = Path(pretrained)
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model: STModel = STModel().to(self.device)
        self.model.load_state_dict(torch.load(self.pretrained, map_location=self.device))
        self.model.eval()

        self.transform: transforms.Compose = transforms.Compose([
            transforms.Resize((150, 120)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust normalization for single-channel input
        ])

    def process_image(self, input_path: Union[str, Path], output_path: Union[str, Path]) -> None:
        """
        Process a single image using the pre-trained model.

        Args:
            input_path (Union[str, Path]): Path to the input image.
            output_path (Union[str, Path]): Path where the output image will be saved.
        """
        input_path: Path = Path(input_path)
        output_path: Path = Path(output_path)
        image: Image.Image = Image.open(input_path).convert('L')  # Ensure the image is in greyscale
        print(f"Loaded image: {input_path}")
        print(f"Image size: {image.size}")
        print(f"Image mode: {image.mode}")

        input_tensor: torch.Tensor = self.transform(image).unsqueeze(0).to(self.device)
        print(f"Transformed tensor shape: {input_tensor.shape}")
        print(f"Transformed tensor min, max: {input_tensor.min().item()}, {input_tensor.max().item()}")

        with torch.no_grad():
            output_tensor: torch.Tensor = self.model(input_tensor)
        print(f"Output tensor shape: {output_tensor.shape}")
        print(f"Output tensor min, max: {output_tensor.min().item()}, {output_tensor.max().item()}")

        output_array = output_tensor.squeeze().cpu().numpy()

        print(f"Processed and saved output image: {output_path}")
        print(f"Output image content: {output_array}")
        print(f"Output tensor shape: {output_tensor.shape}")

if __name__ == "__main__":
    stn: STN = STN(pretrained="spt_model.pt")
    stn.process_image(
        input_path=Path("dataset/train/aaAGoBxqnJgoEGzD.jpg"),
        output_path=Path("output.jpg")
    )

When I run the code, the output_tensor has min and max values of (0.0, 0.0), and when I print out the output_array, it's naturally also all zeros. Here is a sample of the printed output:

Loaded image: dataset/train/aaAGoBxqnJgoEGzD.jpg
Image size: (150, 120)
Image mode: L
Transformed tensor shape: torch.Size([1, 1, 150, 120])
Transformed tensor min, max: -2.8582..., 1.5927...
Output tensor shape: torch.Size([1, 1, 150, 120])
Output tensor min, max: 0.0, 0.0
Processed and saved output image: output.jpg
Output image content: [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
Output tensor shape: torch.Size([1, 1, 150, 120])

I suspect there might be an issue with how the image is being transformed or how the model is processing the image. Any insights into why this might be happening and how I can fix it would be greatly appreciated.

You can find the complete project code on my GitHub for more context: here.

Answer

This might be a naive answer, but I just looked at the linked code. First, I see that you initialize the weights of your localization layer fc2 to zero in model.py. That your forward pass returns all zeros suggests to me that you're running a freshly initialized instance of the STN, not a trained one.
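
To see why a zero-initialized fc2 matters: if its weights are never overwritten with trained values, the predicted affine parameters theta do not depend on the input image at all. Below is a minimal sketch, assuming a hypothetical two-layer `LocalizationHead` (not the actual model.py code) and the identity bias used in the official PyTorch STN tutorial:

```python
import torch
import torch.nn as nn


class LocalizationHead(nn.Module):
    """Hypothetical stand-in for the fc1/fc2 localization head in model.py."""

    def __init__(self, in_features: int = 32) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, 16)
        self.fc2 = nn.Linear(16, 6)  # predicts the 2x3 affine matrix theta
        # Zero the final weights, as model.py reportedly does for fc2.
        nn.init.zeros_(self.fc2.weight)
        # Identity bias, following the PyTorch STN tutorial (an assumption here).
        with torch.no_grad():
            self.fc2.bias.copy_(torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        hidden = torch.relu(self.fc1(features))
        return self.fc2(hidden).view(-1, 2, 3)


head = LocalizationHead().eval()
with torch.no_grad():
    theta_a = head(torch.randn(1, 32))
    theta_b = head(torch.randn(1, 32))

# With fc2.weight == 0, theta is just the bias, whatever the input.
print(torch.equal(theta_a, theta_b))  # True
```

If the bias were zeroed as well, theta would be the all-zero matrix, and `F.affine_grid` would map every output pixel to the same normalized coordinate (0, 0), i.e. the image center.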

I see that in your code above you load a pretrained model from disk. However, I don't see where your model class defines any behavior for a pretrained argument, and I don't think nn.Module accepts a pretrained path argument the way you've defined it, although it does accept kwargs. Since nothing handles the pretrained argument, it is simply ignored here, so by the time you call process_image, you're working with a newly initialized object.
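
A quick way to test this theory is to check, right before inference, whether the zero-initialized layer is still all zeros in the object you actually pass the image to. This is a minimal sketch: `TinySTNHead` and the `looks_untrained` helper are hypothetical, and only mimic an `fc2` layer zero-initialized the way model.py reportedly does:

```python
import torch
import torch.nn as nn


class TinySTNHead(nn.Module):
    """Hypothetical model whose last layer starts at zero, like fc2 in model.py."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, 6)
        nn.init.zeros_(self.fc2.weight)


def looks_untrained(model: nn.Module) -> bool:
    """True if fc2.weight is still exactly zero, i.e. no trained weights were loaded."""
    return bool(torch.all(model.fc2.weight == 0))


fresh = TinySTNHead()
print(looks_untrained(fresh))  # True

# Simulate a trained checkpoint by giving fc2 non-zero weights.
trained = TinySTNHead()
with torch.no_grad():
    trained.fc2.weight.normal_()

fresh.load_state_dict(trained.state_dict())
print(looks_untrained(fresh))  # False
```

If the equivalent check on your real model still reports True after loading, the trained weights never made it into that instance.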

Consider loading the saved state_dict into your STN model explicitly. I think something like this would work:

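if __name__ == "__main__":
    stn: STN = STN()  # your STN model (an nn.Module), created without a pretrained argument

    save_path = 'spt_model.pt'
    state_dict = torch.load(save_path, map_location="cpu")

    stn.load_state_dict(state_dict)  # call load_state_dict on the instance, not the class
    stn.eval()  # switch to inference mode before processing images
    ...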
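
To confirm the loading step itself, a round trip on a toy model can help: save the state_dict, load it into a fresh instance, and check that both produce identical outputs. This is a sketch with a hypothetical `Net` class standing in for the real model; it also shows that `load_state_dict` reports any missing or unexpected keys:

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical stand-in for the STN model."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


trained = Net()
x = torch.randn(3, 4)

with tempfile.TemporaryDirectory() as tmp:
    save_path = Path(tmp) / "spt_model.pt"
    torch.save(trained.state_dict(), save_path)  # save only the weights

    restored = Net()  # a fresh, randomly initialized instance
    state_dict = torch.load(save_path, map_location="cpu")
    result = restored.load_state_dict(state_dict)  # strict=True by default
    restored.eval()

print(result.missing_keys, result.unexpected_keys)  # [] []
with torch.no_grad():
    print(torch.equal(trained(x), restored(x)))  # True
```

Because strict loading raises an error on mismatched keys, a silent all-zeros output after a successful load would point back at the checkpoint contents or the file being loaded rather than at the loading call.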

Also note a possible typo on line 17 of your main.py arguments: "spt_mode.pt" rather than "spt_model.pt". Hopefully you haven't accidentally saved the model under the wrong name.

© 2024 Dagalaxy. All rights reserved.