import torch
import torch.nn.functional as F
# Demo: bilinearly sample an image at sub-pixel (x, y) locations using
# F.grid_sample. x indexes the width axis and y the height axis, in pixel
# units (with align_corners=True, coordinate 0 is the centre of the first pixel).

# Initialize the tensor: (batch, channels, height, width)
bs, nc, h, w = 1, 3, 256, 256
image_tensor = torch.randn(bs, nc, h, w)

# Floating-point sample coordinates, one entry per point (N = 3 here)
x = torch.tensor([5.6, 20.3, 150.8])
y = torch.tensor([128.4, 50.7, 200.9])

# Map pixel coordinates to grid_sample's normalized range [-1, 1].
# With align_corners=True, -1 and +1 refer to the centres of the first and
# last pixels, so pixel index p maps to p / (size - 1) * 2 - 1. This formula
# must stay in sync with the align_corners flag passed to grid_sample below.
x_normalized = (x / (w - 1)) * 2 - 1
y_normalized = (y / (h - 1)) * 2 - 1

# grid_sample takes a grid of shape (batch, H_out, W_out, 2) whose last axis
# is ordered (x, y). Lay the N points out as a 1 x N output "image", giving
# shape (1, 1, N, 2). The grid's batch size (1) must match image_tensor's (bs).
grid = torch.stack((x_normalized, y_normalized), dim=-1)
grid = grid.unsqueeze(0).unsqueeze(0)

# Bilinear sampling. Points outside the image are blended with zeros
# (grid_sample's default padding_mode='zeros').
interpolated_pixel_values = F.grid_sample(image_tensor, grid, mode='bilinear', align_corners=True)

# The result has shape (bs, nc, 1, N). Drop the batch and H_out axes by index
# rather than with a bare squeeze(): squeeze() would also collapse the channel
# or point axis whenever nc == 1 or N == 1, and the transpose would then fail.
# Final shape: (N, nc), i.e. one row of channel values per sampled point.
interpolated_pixel_values = interpolated_pixel_values[0, :, 0, :].transpose(0, 1)
print(interpolated_pixel_values)