diff --git a/app/services/auth.py b/app/services/auth.py index adeb210..add7d9c 100644 --- a/app/services/auth.py +++ b/app/services/auth.py @@ -48,28 +48,33 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt -def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): - with Session(engine) as db: - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"} - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username = payload.get("sub") - if username is None: - raise credentials_exception - token_data = TokenData(username=username) - except InvalidTokenError: +def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], + db: Session = Depends(get_session), + ): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"} + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username = payload.get("sub") + if username is None: raise credentials_exception - user = get_user(db, username=token_data.username) - if user is None: - raise credentials_exception - return user + token_data = TokenData(username=username) + except InvalidTokenError: + raise credentials_exception + user = get_user(db, username=token_data.username) + if user is None: + raise credentials_exception + return user -def auth_is_admin(token: str = Depends(oauth2_scheme)): - user = get_current_user(token=token) +def auth_is_admin( + token: str = Depends(oauth2_scheme), + db: Session = Depends(get_session), + ): + user = get_current_user(token=token, db=db) if not user.is_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/test/test_services/test_auth.py b/test/test_services/test_auth.py index 9b490b8..2990134 100644 --- a/test/test_services/test_auth.py +++ b/test/test_services/test_auth.py @@ -117,7 +117,7 @@ def test_auth_is_admin(db_session, admin_user, regular_user): admin_token = create_access_token(data={"sub": admin_user.name}) # Admin should pass - result = auth_is_admin(token=admin_token) + result = auth_is_admin(token=admin_token, db=db_session) assert result is True # Create token for regular user @@ -125,7 +125,7 @@ def test_auth_is_admin(db_session, admin_user, regular_user): # Regular user should fail with pytest.raises(HTTPException) as exc_info: - auth_is_admin(token=user_token) + auth_is_admin(token=user_token, db=db_session) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN