The current splitDim function only operates on tensors that are split evenly which isn't always the case, e.g. a QKV tensor. This change allows the function to be used for arbitrary splits