Puzzler 1

The split_stack function listed below takes a 1D tensor that logically represents a set of independent features. It decomposes (“splits”) the 1D tensor into the individual features, applies layer normalization and tanh to each one, reassembles (“stacks”) the individual results into a single 3D tensor, and then passes that tensor through a linear layer and a ReLU.

import torch

def split_stack(input_data, batch_size, input_dim):
    # Each chunk holds one feature: batch_size * input_dim contiguous elements.
    split_input_data = torch.split(input_data, input_dim * batch_size)
    layer_norm = torch.nn.LayerNorm(input_dim, eps=0.0, elementwise_affine=False, device='cuda:0')
    norm_features = []

    for features in split_input_data:
        features_view = features.view([batch_size, input_dim])
        norm_features_view = layer_norm(features_view)
        tanh_norm_features = torch.tanh(norm_features_view)
        norm_features.append(tanh_norm_features)

    # Reassemble into one (num_inputs, batch_size, input_dim) tensor.
    linear_input = torch.stack(norm_features)
    # Create the layer on the same device as the input to avoid a device mismatch.
    linear_layer = torch.nn.Linear(input_dim, 1, device='cuda:0')
    linear_output = linear_layer(linear_input)
    return torch.nn.functional.relu(linear_output)

num_inputs, batch_size, input_dim = 16, 1024, 256
input_data = torch.rand(num_inputs * batch_size * input_dim, device='cuda:0')

The following code is functionally identical to the code above (up to the random initialization of the linear layer), but it views the input as a single 3D tensor and operates on it directly, avoiding the calls to split and stack.

def combined(input_data, num_inputs, batch_size, input_dim):
    # View the flat input as (num_inputs, batch_size, input_dim) without copying.
    features3d = input_data.view(num_inputs, batch_size, input_dim)
    layer_norm = torch.nn.LayerNorm(input_dim, eps=0.0, elementwise_affine=False, device='cuda:0')
    norm_features = layer_norm(features3d)

    tanh_norm_features = torch.tanh(norm_features).view(num_inputs, batch_size, input_dim)
    # Create the layer on the same device as the input to avoid a device mismatch.
    linear_layer = torch.nn.Linear(input_dim, 1, device='cuda:0')
    linear_output = linear_layer(tanh_norm_features)
    return torch.nn.functional.relu(linear_output)

num_inputs, batch_size, input_dim = 16, 1024, 256
input_data = torch.rand(num_inputs * batch_size * input_dim, device='cuda:0')
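
The claim that the two versions compute the same thing can be checked on a small tensor. The sketch below is only an illustration: the helper names split_then_stack and single_view are hypothetical, the sizes are tiny, it runs on the CPU, and it leaves out the linear layer because each function creates its own randomly initialized one.

```python
import torch

def split_then_stack(flat, batch_size, input_dim):
    # Hypothetical helper mirroring split_stack up to (and including) tanh.
    layer_norm = torch.nn.LayerNorm(input_dim, eps=0.0, elementwise_affine=False)
    chunks = torch.split(flat, batch_size * input_dim)
    outs = [torch.tanh(layer_norm(c.view(batch_size, input_dim))) for c in chunks]
    return torch.stack(outs)

def single_view(flat, num_inputs, batch_size, input_dim):
    # Hypothetical helper mirroring combined up to (and including) tanh.
    layer_norm = torch.nn.LayerNorm(input_dim, eps=0.0, elementwise_affine=False)
    return torch.tanh(layer_norm(flat.view(num_inputs, batch_size, input_dim)))

# Small sizes chosen for a quick check; the puzzler uses 16, 1024, 256.
num_inputs, batch_size, input_dim = 4, 8, 16
flat = torch.rand(num_inputs * batch_size * input_dim)

a = split_then_stack(flat, batch_size, input_dim)
b = single_view(flat, num_inputs, batch_size, input_dim)
matches = torch.allclose(a, b, atol=1e-6)
print(a.shape, b.shape, matches)
```

LayerNorm normalizes over the last dimension only, so applying it to each 2D chunk or to the whole 3D view yields the same values.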

Empirically, combined is much faster than split_stack. As a concrete instance, with 16 input features, a batch size of 1024, and an input dimension of 256, combined is 8.4 times faster than split_stack in terms of start-to-finish time on a 40 GB A100 GPU. Why?
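
The post does not say how the start-to-finish times were measured. One plausible way, sketched here under that assumption, is to time several calls with a wall clock and synchronize the GPU before reading it, since CUDA kernels are launched asynchronously. The helper time_fn and its warmup and iters parameters are hypothetical, and the demo call uses a toy CPU-friendly workload so the sketch runs without a GPU.

```python
import time
import torch

def time_fn(fn, *args, warmup=3, iters=10):
    """Hypothetical helper: mean wall-clock seconds per call of fn(*args)."""
    use_cuda = torch.cuda.is_available()
    # Warm-up calls absorb one-time costs such as lazy initialization.
    for _ in range(warmup):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for all queued kernels to complete
    return (time.perf_counter() - start) / iters

# Toy demonstration; replace with split_stack / combined and their arguments.
elapsed = time_fn(torch.tanh, torch.rand(1024))
print(f"{elapsed * 1e6:.1f} microseconds per call")
```

With the functions above, a speedup ratio would then be time_fn(split_stack, input_data, batch_size, input_dim) divided by time_fn(combined, input_data, num_inputs, batch_size, input_dim).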

Puzzler 2

We use PyTorch 2 to create compiled versions of the functions above:

compiled_split_stack = torch.compile(split_stack)
compiled_combined = torch.compile(combined)

Empirically, compiled_split_stack is 2.5 times faster and compiled_combined is 10.3 times faster than the original (uncompiled) split_stack in terms of start-to-finish time on a 40 GB A100 GPU. Why?

PyTorch Profiler trace available here.
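
For readers who want to capture a comparable trace themselves, here is a minimal sketch using torch.profiler. It is not the setup used to produce the trace mentioned above: the helper capture_trace, the toy workload, and the output path are all assumptions made for illustration.

```python
import os
import tempfile
import torch
from torch.profiler import profile, ProfilerActivity

def capture_trace(fn, *args, path):
    """Hypothetical helper: profile one call of fn(*args) and save a Chrome trace."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # include GPU work launched by fn
    prof.export_chrome_trace(path)
    return prof

def toy_workload(x):
    # Stand-in for split_stack or combined so the sketch runs without a GPU.
    return torch.tanh(x).sum()

trace_path = os.path.join(tempfile.mkdtemp(), "trace.json")
prof = capture_trace(toy_workload, torch.rand(1024), path=trace_path)
summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(summary)
```

The saved JSON file can be opened in a Chrome-trace viewer to inspect the individual operator calls and, on a GPU, the kernel launches.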

See answer and discussion