How to write a unit test for session_factory in FastAPI

95 views Asked by At

I have this code:

src/core/container.py

from repositories import UserRepository

from services import UserService
from core.database import Database
from dependency_injector import containers, providers
from sqlalchemy.engine import URL
from configs import config


class Container(containers.DeclarativeContainer):
    """Dependency-injection container for the database, repository and service.

    Declares the providers that build the ``Database`` singleton and hand its
    ``session`` context-manager factory to ``UserRepository``, which in turn is
    injected into ``UserService``.
    """

    # Modules that the container wires its providers into.
    wiring_config = containers.WiringConfiguration(modules=["apis.user"])

    # Connection URL assembled from the application settings.
    db_url = URL.create(
        drivername=config.DRIVER_NAME,
        username=config.DB_USER,
        password=config.DB_PASSWORD,
        host=config.DB_HOST,
        database=config.DB_NAME,
        port=config.DB_PORT,
    )

    # A single Database instance is shared by everything the container builds.
    db = providers.Singleton(Database, db_url=db_url)

    # The repository receives the bound ``Database.session`` method as its
    # session factory, so it opens sessions against whatever ``db`` provides.
    user_repository = providers.Factory(UserRepository, session_factory=db.provided.session)

    user_service = providers.Factory(UserService, user_repository=user_repository)

src/core/database.py

from contextlib import AbstractContextManager, contextmanager
from typing import Any, Callable

from sqlalchemy import create_engine, orm
from sqlalchemy.ext.declarative import as_declarative, declared_attr
from sqlalchemy.orm import Session
from configs import config


@as_declarative()
class BaseModel:
    """Declarative base class that the ORM models inherit from."""

    id: Any
    __name__: str

    # Generate __tablename__ automatically
    @declared_attr
    def __tablename__(cls) -> str:
        """Return the table name: the model class name in lower case."""
        return cls.__name__.lower()


class Database:
    """Own the SQLAlchemy engine and hand out sessions bound to it."""

    def __init__(self, db_url: str) -> None:
        """Create the engine and a scoped session factory for ``db_url``.

        Args:
            db_url: Database connection URL passed straight to
                ``create_engine``.
        """
        self._engine = create_engine(db_url, echo=False)
        self._session_factory = orm.scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self._engine,
            ),
        )

    def create_database(self) -> None:
        """Create every table registered on ``BaseModel`` using this engine."""
        BaseModel.metadata.create_all(self._engine)

    @contextmanager
    def session(self) -> Callable[..., AbstractContextManager[Session]]:
        """Yield a session, rolling back on error and always closing it.

        Because of ``@contextmanager`` the call returns a context manager;
        the ``Session`` itself is only available inside a ``with`` block.
        """
        session: Session = self._session_factory()
        session.expire_on_commit = False
        try:
            yield session
        except BaseException:
            # Same reach as a bare ``except``: undo pending work for any
            # failure, then let the original exception propagate.
            session.rollback()
            raise
        finally:
            session.close()

From what I have read online, the best practice for unit tests is to create a separate database for testing purposes and run the unit tests against it, so that they won't affect the real database.

Also, almost every example that I can find online says that I need to override the get_db() function so that the database connection points to the test database instead. But in my current code, it looks like my team is using a session_factory instead of the get_db() approach.

So here is what I have: I'm trying to write a unit test for it by overriding the "Session":

src/conf_test_db.py

from fastapi import Depends
from sqlalchemy import create_engine, orm
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy_utils import create_database, drop_database
from contextlib import AbstractContextManager, contextmanager
from typing import Callable
from configs import config
from models.base import Base
from server import app
from core.database import Database

# Connection settings for the test run: credentials, host and port come from
# the application config, while the database name is config.DB_TEST_NAME.
DATABASE_USERNAME = config.DB_USER
DATABASE_PASSWORD = config.DB_PASSWORD
DATABASE_HOST = config.DB_HOST
DATABASE_PORT = config.DB_PORT
DATABASE_NAME = config.DB_TEST_NAME

# URL of the test database, assembled from the settings above.
SQLALCHEMY_DATABASE_URL = (
    f"postgresql://{DATABASE_USERNAME}:{DATABASE_PASSWORD}"
    f"@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_NAME}"
)

# engine = create_engine(SQLALCHEMY_DATABASE_URL)
# TestingSessionLocal = orm.scoped_session(
#     sessionmaker(autocommit=False, autoflush=False, bind=engine)
# )
#
# Base.metadata.drop_all(bind=engine)
# Base.metadata.create_all(bind=engine)
#
#
# def override_session():
#     session: Session = TestingSessionLocal
#     session.expire_on_commit = False
#     try:
#         yield session
#     except Exception:
#         session.rollback()
#         raise
#     finally:
#         session.close()
#
#
# # override db session
# app.dependency_overrides[Database.session] = override_session


class Database:
    """Test-side counterpart of the application ``Database`` class.

    Owns an engine plus a scoped session factory, and can (re)create the
    test database named by ``SQLALCHEMY_DATABASE_URL``.
    """

    def __init__(self, db_url: str) -> None:
        """Create the engine and a scoped session factory for ``db_url``."""
        self._engine = create_engine(db_url, echo=False)
        self._session_factory = orm.scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self._engine,
            ),
        )

    def create_database(self) -> None:
        """Create every table registered on ``Base`` using this engine."""
        Base.metadata.create_all(self._engine)

    def create_test_database(self) -> None:
        """Recreate the test database from scratch and create its tables.

        The database that is dropped and created is always the one named by
        ``SQLALCHEMY_DATABASE_URL``; the tables are created through this
        instance's engine.
        """
        # Local import: the module only imports create_database/drop_database.
        from sqlalchemy_utils import database_exists

        # Drop the existing test database only if it exists, so the very
        # first run (no test database yet) does not fail.
        if database_exists(SQLALCHEMY_DATABASE_URL):
            drop_database(SQLALCHEMY_DATABASE_URL)

        # Create a new test database
        create_database(SQLALCHEMY_DATABASE_URL)

        # Create the tables in the test database
        Base.metadata.create_all(self._engine)

    @contextmanager
    def session(self) -> Callable[..., AbstractContextManager[Session]]:
        """Yield a session, rolling back on error and always closing it.

        Because of ``@contextmanager`` the call returns a context manager;
        the ``Session`` itself is only available inside a ``with`` block.
        """
        session: Session = self._session_factory()
        session.expire_on_commit = False
        try:
            yield session
        except BaseException:
            # Same reach as a bare ``except``: undo pending work for any
            # failure, then let the original exception propagate.
            session.rollback()
            raise
        finally:
            session.close()


def setup_test_database():
    """
    Create a fresh test database and yield a session bound to it.

    Generator intended to be used as a dependency/fixture: the session is
    rolled back on error and closed once the consumer is done. Dropping the
    database afterwards is currently disabled (see the commented line below).
    """
    database = Database(SQLALCHEMY_DATABASE_URL)
    database.create_test_database()

    # Database.session() returns a context manager, not a Session: enter it
    # to obtain the real Session and let its exit handle rollback and close.
    with database.session() as session:
        yield session
    # drop_database(SQLALCHEMY_DATABASE_URL)


# NOTE(review): FastAPI consults dependency_overrides only for callables that
# routes request through Depends(...). Nothing shown here depends on `Session`
# that way; the session reaches the repository through the dependency_injector
# Container (session_factory=db.provided.session). This override is therefore
# presumably never used, and the app keeps talking to the real database.
# Overriding the container's `db` provider with a Database built from
# SQLALCHEMY_DATABASE_URL is the likely fix. TODO confirm against the app setup.
app.dependency_overrides[Session] = setup_test_database

src/tests/test_user.py

from faker import Faker
from fastapi.testclient import TestClient

from ..conf_test_db import app

# Module-level test client wrapping the app imported from conf_test_db.
client = TestClient(app)
# Faker instance used to generate random input data for the tests.
fake = Faker()


def test_register_user():
    """Register a randomly generated user and check the response fields."""
    # Random but well-formed registration payload.
    payload = {
        "email": fake.email(),
        "password": fake.password(length=8, upper_case=True, special_chars=True),
        "first_name": fake.first_name(),
        "last_name": fake.last_name(),
    }

    response = client.post("/v1/users/register", json=payload)

    # The endpoint must accept the registration.
    assert response.status_code == 200

    # Parse the body once and verify that every expected field is present.
    body = response.json()
    expected_fields = (
        "access_token",
        "refresh_token",
        "user_id",
        "user_email",
        "role",
        "avatar_url",
    )
    for field in expected_fields:
        assert field in body

    # The returned e-mail must be the one that was submitted.
    assert body["user_email"] == payload["email"]

The unit test works fine: every test case passes, and the user is added to the database successfully.

The problem is that the database used for testing is still the real database, not the test database.

May I know what I've been doing wrong here?

Thanks.

0

There are 0 answers