Hello guys i'm implementing an avl tree and trying to self balance the tree but the root node does not balance. It only balances the left and right subtree. when i input 50 40 30 10, the root note should be 40 but instead it remains 50 which is wrong. Please could someone take a look at the code to what i did wrong. Thanks
Below is my code:
public void insert(T data){
    AVLNode<T> newNode = new AVLNode<T>(data);
    if(isEmpty()){
        root = newNode;
    }
    else{
        insert(root, newNode);
    }
}
private void insert(AVLNode<T> root, AVLNode<T> newNode){
 if(newNode.getData().compareTo(root.getData())<0){
        if(root.getLeftChild()!=null){
        AVLNode<T> leftNodes = root.getLeftChild();
        insert(leftNodes, newNode);
        root.setLeftChild(rebalance(leftNodes));
        }
        else{
            root.setLeftChild(newNode);
        }
    }
    else if(newNode.getData().compareTo(root.getData())>0){
        if(root.getRightChild()!=null){
        AVLNode<T> rightNodes = root.getRightChild();
        insert(rightNodes, newNode);
        root.setRightChild(rebalance(rightNodes));
    }
        else{
            root.setRightChild(newNode);
        }
    }
    else{
          root.setData(newNode.getData());
    }
    updateHeight(root); 
}
//re-balances the tree.
private AVLNode<T> rebalance(AVLNode<T> root){
    int difference = balance(root);
    if (difference > 1){
        if(balance(root.getLeftChild())>0){
            root = rotateRight(root);
        }
        else{
            root = rotateLeftRight(root);
        }
    }
    else if(difference < -1){
        if(balance(root.getRightChild())<0){
            root = rotateLeft(root);
        }
        else{
            root = rotateRightLeft(root);
        }
    }
    return root;
}
//updates the height of the tree.
public void updateHeight(AVLNode<T> root){
    if((root.getLeftChild()==null) && (root.getRightChild()!=null)){
        root.setHeight(root.getRightChild().getHeight()+1);
    }
    else if((root.getLeftChild() !=null)&&(root.getRightChild()==null)){
        root.setHeight(root.getLeftChild().getHeight()+1);
    }
    else
        root.setHeight(Math.max(getHeight(root.getLeftChild()), getHeight(root.getRightChild()))+1);
}
private int balance(AVLNode<T> root) {
    return getHeight(root.getLeftChild())-getHeight(root.getRightChild());
}
//single left left rotation of the tree
private AVLNode<T> rotateLeft(AVLNode<T> root){
    AVLNode<T> NodeA = root.getRightChild();
    root.setRightChild(NodeA.getLeftChild());
    NodeA.setLeftChild(root);
    root.setHeight(Math.max(getHeight(root.getLeftChild()), getHeight(root.getRightChild()))+1); //updates the height
    NodeA.setHeight(Math.max(getHeight(NodeA.getLeftChild()), getHeight(NodeA.getRightChild()))+1);
    return NodeA;
}
//single right right rotation of the tree
private AVLNode<T> rotateRight(AVLNode<T> root){
    AVLNode<T> NodeA = root.getLeftChild();
    root.setLeftChild(NodeA.getRightChild());
    NodeA.setRightChild(root);
    root.setHeight(Math.max(getHeight(root.getLeftChild()), getHeight(root.getRightChild()))+1); //updates the height of the AVL tree
    NodeA.setHeight(Math.max(getHeight(NodeA.getLeftChild()), getHeight(NodeA.getRightChild()))+1);
    return NodeA;
}
//a double rotation. Left right rotation
private AVLNode<T> rotateLeftRight(AVLNode<T> root){
    AVLNode<T> nodeA = root.getLeftChild();
    root.setLeftChild(rotateLeft(nodeA));
    return rotateRight(root);
}
//a right left rotation
private AVLNode<T> rotateRightLeft(AVLNode<T> root){
    AVLNode<T> nodeA = root.getRightChild();
    root.setRightChild(rotateRight(nodeA));
    return rotateLeft(root);
}