Trouble backpropagating through a very complicated function in PyTorch - no way to avoid in-place operations

I want to define a loss function based on a complex series of transformations of a neural network's output. The transformations involve fairly intricate logic, like the function below, that I can't see how to express without in-place operations (see the comments):

def get_X_torch(C, c_table):
    """
    Adapted from _zmat_transformation.py, line 57.
    C       : float tensor whose rows are all bonds, then all angles, then all dihedrals
    c_table : int tensor whose rows are all bond_idx, then all angle_idx, then all dihedral_idx;
              blank indices for the beginning of the z-matrix are labelled -9223372036854775807
    """
    X = torch.zeros_like(C, device="cuda:0")  # shape ([b a d], n_atoms)
    n_atoms = X.shape[1]

    # This loop is not vectorizable: the columns of X all influence each other
    # as the loop progresses (it's a nonlinear transformation).
    for j in range(n_atoms):
        B, ref_pos = get_B_torch(X, c_table, j)
        S = get_S_torch(C, j)
        # X[:, j] depends on X's current value as a whole -- this is the tricky step
        X[:, j] = torch.mv(B, S) + get_ref_pos_torch(X, c_table[0, j])
    return X.T
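
To be clear about the failure mode, here is a tiny toy example (completely unrelated to the real helper functions above) that reproduces the same class of error: a tensor that autograd saved for the backward pass gets modified by a later in-place write.

import torch

w = torch.randn(3, requires_grad=True)
x = torch.zeros(3)
x[0] = w[0]          # in-place write into x
y = x * x            # the multiplication saves x for its backward pass
x[1] = w[1]          # second in-place write modifies the saved tensor
y.sum().backward()   # RuntimeError: one of the variables needed for gradient
                     # computation has been modified by an inplace operation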

The training snippet looks like this. I need to build up clash_loss iteratively with my_function_script, a wrapper around the function above. Since I'm not doing clash_loss += (an in-place add), that part should be fine, so I think the problem is in the complicated logic above. The error message says the gradient can't be computed because of an in-place operation somewhere along the path.


reconstructed_angles = torch.atan2(internal_data_batch_reconstructed[:, 0:304],
                                   internal_data_batch_reconstructed[:, 304:])
if clash_mode is True:
    clash_loss = torch.tensor(0.0, requires_grad=True, device="cuda")
    bonds = torch.tensor(init_z_mat["bond"].values, device="cuda", requires_grad=True)
    angles = torch.tensor(init_z_mat["angle"].values * (torch.pi / 180), device="cuda", requires_grad=True)

    for i in range(batch_size):
        print(i + 1, batch_size)
        C = torch.stack((bonds, angles, reconstructed_angles[i]))
        # very complicated function, but written in pure PyTorch; it can't seem
        # to avoid in-place operations (see the other snippet)
        xyz = my_function_script(C, construction_table)
        temp_loss = 1.0 if get_clash_loss(xyz) > 0.0 else 0.0
        clash_loss = clash_loss + temp_loss
    total_loss = clash_loss
    total_loss.backward()  # <--- this fails
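
For what it's worth, I can enable anomaly detection before running the loop above to get a longer traceback that points at the forward operation whose backward fails (set_detect_anomaly is the standard PyTorch switch; it only helps locate the problem, it doesn't fix it):

torch.autograd.set_detect_anomaly(True)  # backward() now also prints the forward-pass
                                         # stack trace of the op whose gradient fails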

Is there anything I can do to make this chain of logic differentiable so that clash_loss.backward() works? Manually deriving the gradient is completely impossible for such a complex set of functions...

I tried rewriting the loop with copies and without any obvious in-place edits (see below), but it still doesn't work.

Xs = [X]
for j in range(n_atoms):
    B, ref_pos = get_B_torch(Xs[-1], c_table, j)
    S = get_S_torch(C, j)
    first = torch.mv(B, S)
    second = get_ref_pos_torch(Xs[-1], c_table[0, j])
    # rebuild X with column j replaced instead of assigning into it in place
    # (note: the left slice has to be 0:j, not 0:j - 1, or column j - 1 gets dropped)
    Xcopy = torch.cat((Xs[-1][:, 0:j], (first + second).reshape((-1, 1)), Xs[-1][:, j + 1:]), -1)
    Xs = Xs + [Xcopy]

return Xs[-1].T
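
Another variant I sketched out (under the assumption that the helper functions only ever read columns with index < j, which I believe is true for a z-matrix) accumulates the finished columns in a Python list and rebuilds a zero-padded X on every iteration, so nothing is ever written into an existing tensor:

def get_X_torch_no_inplace(C, c_table):
    # sketch only -- same assumed signatures for get_B_torch / get_S_torch /
    # get_ref_pos_torch as in the original function
    n_atoms = C.shape[1]
    cols = []  # finished (3, 1) columns of X, in order
    for j in range(n_atoms):
        # rebuild a full-width X from the columns computed so far plus zero
        # padding, so the helpers see the same shape as before
        pad = torch.zeros(C.shape[0], n_atoms - j, device=C.device, dtype=C.dtype)
        X = torch.cat(cols + [pad], dim=1)
        B, ref_pos = get_B_torch(X, c_table, j)
        S = get_S_torch(C, j)
        cols.append((torch.mv(B, S) + get_ref_pos_torch(X, c_table[0, j])).reshape(-1, 1))
    return torch.cat(cols, dim=1).T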