|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +from datetime import datetime, timedelta |
| 4 | +from typing import Any, Optional, Union |
| 5 | + |
| 6 | +from fastapi import Depends, HTTPException, status, Request |
| 7 | +from fastapi.security import OAuth2PasswordBearer |
| 8 | +from jose import jwt |
| 9 | +from passlib.context import CryptContext |
| 10 | +from pydantic import ValidationError |
| 11 | +from sqlalchemy.orm import Session |
| 12 | + |
| 13 | +from backend.app.core.conf import settings |
| 14 | +from backend.app.crud import user_crud |
| 15 | +from backend.app.datebase.db_mysql import get_db |
| 16 | +from backend.app.model import User |
| 17 | + |
| 18 | +pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto') # 密码加密 |
| 19 | + |
| 20 | +oauth2_schema = OAuth2PasswordBearer(tokenUrl='/v1/login') # 指明客户端请求token的地址 |
| 21 | + |
| 22 | +headers = {"WWW-Authenticate": "Bearer"} # 异常返回规范 |
| 23 | + |
| 24 | + |
| 25 | +def get_hash_password(password: str) -> str: |
| 26 | + """使用hash算法加密密码 """ |
| 27 | + return pwd_context.hash(password) |
| 28 | + |
| 29 | + |
| 30 | +def verity_password(plain_password: str, hashed_password: str) -> bool: |
| 31 | + """ |
| 32 | + 密码校验 |
| 33 | + :param plain_password: 要验证的密码 |
| 34 | + :param hashed_password: 要比较的hash密码 |
| 35 | + :return: 比较密码之后的结果 |
| 36 | + """ |
| 37 | + return pwd_context.verify(plain_password, hashed_password) |
| 38 | + |
| 39 | + |
| 40 | +def create_access_token(data: Union[int, Any], expires_delta: Optional[timedelta] = None) -> str: |
| 41 | + """ |
| 42 | + 生成加密 token |
| 43 | + :param data: 传进来的值 |
| 44 | + :param expires_delta: 增加的到期时间 |
| 45 | + :return: 加密token |
| 46 | + """ |
| 47 | + if expires_delta: |
| 48 | + expires = datetime.utcnow() + expires_delta |
| 49 | + else: |
| 50 | + expires = datetime.utcnow() + timedelta(settings.ACCESS_TOKEN_EXPIRE_MINUTES) |
| 51 | + to_encode = {"exp": expires, "sub": str(data)} |
| 52 | + encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, settings.ALGORITHM) |
| 53 | + return encoded_jwt |
| 54 | + |
| 55 | + |
| 56 | +async def get_current_user(db: Session = Depends(get_db), token: str = Depends(oauth2_schema)) -> User: |
| 57 | + """ |
| 58 | + 通过token获取当前用户 |
| 59 | + :param db: |
| 60 | + :param token: |
| 61 | + :return: |
| 62 | + """ |
| 63 | + credentials_exception = HTTPException( |
| 64 | + status.HTTP_401_UNAUTHORIZED, |
| 65 | + detail="无法验证凭据", |
| 66 | + headers={"WWW-Authenticate": "Bearer"}, |
| 67 | + ) |
| 68 | + try: |
| 69 | + # 解密token |
| 70 | + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) |
| 71 | + user_id = payload.get('sub') |
| 72 | + if not user_id: |
| 73 | + raise credentials_exception |
| 74 | + except (jwt.JWTError, ValidationError): |
| 75 | + raise credentials_exception |
| 76 | + user = user_crud.get_user_by_id(db, user_id) |
| 77 | + return user |
| 78 | + |
| 79 | + |
| 80 | +async def get_current_is_superuser(user: User = Depends(get_current_user)) -> bool: |
| 81 | + """ |
| 82 | + 通过token验证当前用户权限 |
| 83 | + :param user: |
| 84 | + :return: |
| 85 | + """ |
| 86 | + is_superuser = user.is_superuser |
| 87 | + if not is_superuser: |
| 88 | + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='用户权限不足', headers=headers) |
| 89 | + return is_superuser |
0 commit comments