diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7b88996..98ab127 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,4 +43,4 @@ jobs: python manage.py migrate - name: 테스트 실행 run: | - python manage.py test + pytest diff --git a/accounts/admin.py b/accounts/admin.py index d547620..71aba23 100644 --- a/accounts/admin.py +++ b/accounts/admin.py @@ -1,5 +1,97 @@ from django.contrib import admin +from django.contrib.auth.admin import UserAdmin from .models import CustomUser -admin.site.register(CustomUser) + +@admin.register(CustomUser) +class CustomUserAdmin(UserAdmin): + """ + 관리자로 하여금 커스텀 사용자 모델을 관리하기 위해 제공되는 관리자 클래스입니다. + - 모델로 CustomUser를 사용합니다. + - 커스텀 사용자 수정 및 생성 폼을 제공합니다. + - 커스텀 사용자 목록을 조회할 수 있으며, 필터링 및 검색 옵션을 제공합니다. + - 커스텀 사용자 목록에 표시될 필드들과 정렬 기준을 제공합니다. + - 모든 유형의 커스텀 사용자을 생성하는 기능을 제공합니다. + """ + + model = CustomUser + + readonly_fields = ("created_at", "updated_at") + + fieldsets = ( + ("Login info", {"fields": ("email", "nickname", "password")}), + ("Personal info", {"fields": ("first_name", "last_name", "introduction")}), + ( + "Permissions", + { + "fields": ( + "is_active", + "is_staff", + "is_superuser", + "groups", + "user_permissions", + ), + "classes": ("wide",), + }, + ), + ( + "Important dates", + { + "fields": ("last_login", "created_at", "updated_at"), + "classes": ("wide",), + }, + ), + ) + + add_fieldsets = ( + ( + "Register info", + { + "fields": ( + "email", + "nickname", + "password1", + "password2", + "is_staff", + "is_superuser", + ), + }, + ), + ) + + list_display = ("email", "nickname", "is_staff", "is_superuser") + list_filter = ("is_staff", "is_superuser", "is_active") + search_fields = ("email", "nickname", "first_name", "last_name") + ordering = ("email", "created_at") + + def save_model(self, request, obj, form, change): + if not change: + if form.cleaned_data.get("is_staff") and not form.cleaned_data.get( + "is_superuser" + ): + CustomUser.objects.create_staff( + email=form.cleaned_data["email"], + password=form.cleaned_data["password1"], + nickname=form.cleaned_data["nickname"], + ) + elif form.cleaned_data.get("is_superuser"): + CustomUser.objects.create_superuser( + email=form.cleaned_data["email"], + password=form.cleaned_data["password1"], + nickname=form.cleaned_data["nickname"], + ) + else: + CustomUser.objects.create_user( + email=form.cleaned_data["email"], + password=form.cleaned_data["password1"], + nickname=form.cleaned_data["nickname"], + ) + else: + obj.email = form.cleaned_data.get("email", obj.email) + obj.nickname = form.cleaned_data.get("nickname", obj.nickname) + if "password1" in form.cleaned_data: + obj.set_password(form.cleaned_data["password1"]) + obj.is_staff = form.cleaned_data.get("is_staff", obj.is_staff) + obj.is_superuser = form.cleaned_data.get("is_superuser", obj.is_superuser) + obj.save() diff --git a/accounts/migrations/0004_remove_customuser_date_joined_and_more.py b/accounts/migrations/0004_remove_customuser_date_joined_and_more.py new file mode 100644 index 0000000..3646c7c --- /dev/null +++ b/accounts/migrations/0004_remove_customuser_date_joined_and_more.py @@ -0,0 +1,32 @@ +# Generated by Django 5.1.1 on 2024-10-04 06:42 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('accounts', '0003_customuser_created_at_customuser_introduction_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='customuser', + name='date_joined', + ), + migrations.AlterField( + model_name='customuser', + name='introduction', + field=models.TextField(blank=True, max_length=20, null=True, verbose_name='자기소개'), + ), + migrations.AlterField( + model_name='customuser', + name='nickname', + field=models.CharField(max_length=20, unique=True, verbose_name='닉네임'), + ), + migrations.AlterField( + model_name='customuser', + name='phone_number', + field=models.CharField(blank=True, max_length=20, null=True, verbose_name='연락처'), + ), + ] diff --git a/accounts/migrations/0005_customuser_profile_image.py b/accounts/migrations/0005_customuser_profile_image.py new file mode 100644 index 0000000..8a43a82 --- /dev/null +++ b/accounts/migrations/0005_customuser_profile_image.py @@ -0,0 +1,18 @@ +# Generated by Django 5.1.1 on 2024-10-10 06:29 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('accounts', '0004_remove_customuser_date_joined_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='customuser', + name='profile_image', + field=models.ImageField(blank=True, default='profile_images/default.jpg', upload_to='profile_images/'), + ), + ] diff --git a/accounts/migrations/0006_remove_customuser_profile_image.py b/accounts/migrations/0006_remove_customuser_profile_image.py new file mode 100644 index 0000000..33fa2cd --- /dev/null +++ b/accounts/migrations/0006_remove_customuser_profile_image.py @@ -0,0 +1,17 @@ +# Generated by Django 5.1.1 on 2024-10-10 07:31 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('accounts', '0005_customuser_profile_image'), + ] + + operations = [ + migrations.RemoveField( + model_name='customuser', + name='profile_image', + ), + ] diff --git a/accounts/models.py b/accounts/models.py index a50681d..c308052 100644 --- a/accounts/models.py +++ b/accounts/models.py @@ -1,8 +1,4 @@ -from django.apps import apps -from django.contrib import auth -from django.contrib.auth.hashers import make_password from django.contrib.auth.models import AbstractUser, UserManager -from django.core.exceptions import ValidationError from django.db import models @@ -12,41 +8,51 @@ class CustomUserManager(UserManager): - CustomUser에서 설정변경한 사용한 식별자(email)로 사용자 인스턴스를 생성하도록 합니다. """ - def _create_user(self, email, password, **extra_fields): + def _create_user(self, email, password, nickname, **extra_fields): """ - 사용자 유형과 관계없이 실제로 사용자를 생성해 넘겨줍니다. + 사용자 유형과 관계없이 사용자를 실제로 생성하고 반환합니다. """ if not email: - raise ValueError("이메일 입력은 필수입니다.") + raise ValueError("이메일은 필수로 입력하셔야 합니다.") + if not nickname: + raise ValueError("닉네임은 필수로 입력하셔야 합니다.") email = self.normalize_email(email) - user = self.model(email=email, **extra_fields) + user = self.model(email=email, nickname=nickname, **extra_fields) user.set_password(password) user.save(using=self._db) return user - def create_user(self, email, password=None, **extra_fields): + def create_user(self, email, password, nickname, **extra_fields): """ - 일반 사용자를 생성합니다. + 학생(student)를 생성합니다. """ extra_fields.setdefault("is_staff", False) extra_fields.setdefault("is_superuser", False) - return self._create_user(email, password, **extra_fields) + return self._create_user(email, password, nickname, **extra_fields) - def create_superuser(self, email, password=None, **extra_fields): + def create_staff(self, email, password, nickname, **extra_fields): """ - 슈퍼 사용자를 생성합니다. + 스태프(tutor)를 생성합니다. """ extra_fields.setdefault("is_staff", True) - extra_fields.setdefault("is_superuser", True) - return self._create_user(email, password, **extra_fields) + extra_fields.setdefault("is_superuser", False) + + if extra_fields.get("is_staff") is not True: + raise ValueError("Staff user must have is_staff=True.") + return self._create_user(email, password, nickname, **extra_fields) - def create_staff_user(self, email, password=None, **extra_fields): + def create_superuser(self, email, password, nickname, **extra_fields): """ - 관리자(강사)를 생성합니다. + 관리자(superuser)를 생성합니다. """ extra_fields.setdefault("is_staff", True) - extra_fields.setdefault("is_superuser", False) - return self._create_user(email, password, **extra_fields) + extra_fields.setdefault("is_superuser", True) + + if extra_fields.get("is_staff") is not True: + raise ValueError("Superuser must have is_staff=True.") + if extra_fields.get("is_superuser") is not True: + raise ValueError("Superuser must have is_superuser=True.") + return self._create_user(email, password, nickname, **extra_fields) class CustomUser(AbstractUser): @@ -56,14 +62,20 @@ class CustomUser(AbstractUser): - 사용자 매니저 지정(UserManager -> CustomUserManager) """ + # 미사용할 기본 필드 username = None - email = models.EmailField( - unique=True, verbose_name="이메일" - ) # 필드를 기본키로 지정 + date_joined = None + + email = models.EmailField(unique=True, verbose_name="이메일") # 기본키 변경 - nickname = models.CharField(max_length=20, verbose_name="닉네임") - phone_number = models.CharField(max_length=20, verbose_name="연락처") - introduction = models.TextField(max_length=20, verbose_name="자기소개") + # 추가 필드 + nickname = models.CharField(max_length=20, unique=True, verbose_name="닉네임") + phone_number = models.CharField( + max_length=20, null=True, blank=True, verbose_name="연락처" + ) + introduction = models.TextField( + max_length=20, null=True, blank=True, verbose_name="자기소개" + ) created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성 일자") updated_at = models.DateTimeField(auto_now=True, verbose_name="갱신 일자") @@ -77,15 +89,16 @@ class CustomUser(AbstractUser): verbose_name="수강 학생들", # 사용자에게 보일 이름 ) - USERNAME_FIELD = "email" # 인증시 사용할 필드 지정 - REQUIRED_FIELDS = [] + # 인증시 사용할 필드 + USERNAME_FIELD = "email" + REQUIRED_FIELDS = ["nickname"] # email, password 자동 포함 objects = CustomUserManager() def __str__(self): return self.email - def clean(self): - super().clean() - if not self.email: - raise ValidationError("이메일은 필수입니다.") + def get_image_url(self): + if getattr(self, "image", None): + return self.image.image_url + return "https://paullab.co.kr/images/weniv-licat.png" diff --git a/accounts/permissions.py b/accounts/permissions.py index d885cc9..2287656 100644 --- a/accounts/permissions.py +++ b/accounts/permissions.py @@ -1,33 +1,20 @@ from rest_framework import permissions -class BaseAuthPermission(permissions.IsAuthenticated): +class IsAuthenticatedAndActive(permissions.IsAuthenticated): """ 권한 값으로 True/False를 반환합니다. - - 사용자 객체 생성 및 인증 여부를 확인하고 인증자에게만 허용합니다. + - 사용자 객체 생성 및 인증 여부를 확인하고 + active한 인증자에게만 허용합니다. """ message = "이 작업을 수행할 권한이 없습니다." - -class IsAuthenticatedOrCreateOnly(BaseAuthPermission): - """ - 권한 값으로 True/False를 반환합니다. - - GET 요청을 인증된 사용자에게 허용합니다. - - POST 요청을 누구에게나 허용합니다. - """ - - message = "이 작업을 수행하려면 로그인이 필요합니다." - def has_permission(self, request, view): - if request.method == "GET": - return super().has_permission(request, view) - elif request.method == "POST": - return True - return False + return super().has_permission(request, view) and request.user.is_active -class IsTutor(BaseAuthPermission): +class IsTutor(IsAuthenticatedAndActive): """ 권한 값으로 True/False를 반환합니다. - 요청 유형과 관계없이 강사(Tutor)이면 권한을 허용합니다. @@ -37,7 +24,7 @@ def has_permission(self, request, view): return super().has_permission(request, view) and request.user.is_staff -class IsSuperUser(BaseAuthPermission): +class IsSuperUser(IsAuthenticatedAndActive): """ 권한 값으로 True/False를 반환합니다. - 요청 유형과 관계없이 관리자(superuser)이면 권한을 허용합니다. @@ -45,21 +32,3 @@ class IsSuperUser(BaseAuthPermission): def has_permission(self, request, view): return super().has_permission(request, view) and request.user.is_superuser - - -class IsTutorOrSuperUserOrSuperUserCreateOnly(BaseAuthPermission): - """ - 권한 값으로 True/False를 반환합니다. - - GET 요청을 강사(tutor) 또는 관리자(superuser)에게 허용합니다. - - POST 요청을 관리자(superuser)에게만 허용합니다. - """ - - def has_permission(self, request, view): - if not super().has_permission(request, view): - return False - - if request.method == "GET": - return request.user.is_staff or request.user.is_superuser - elif request.method == "POST": - return request.user.is_superuser - return False diff --git a/accounts/serializers.py b/accounts/serializers.py index 1b4aab2..d45905c 100644 --- a/accounts/serializers.py +++ b/accounts/serializers.py @@ -1,14 +1,126 @@ from django.contrib.auth.hashers import check_password +from materials.models import Image +from materials.serializers import ImageSerializer from rest_framework import serializers from .models import CustomUser -class CustomUserSerializer(serializers.ModelSerializer): +class UserRegistrationSerializer(serializers.ModelSerializer): + """ + 회원가입을 위한 시리얼라이저입니다. + """ + + confirm_password = serializers.CharField(write_only=True, required=True) + + class Meta: + model = CustomUser + fields = ["email", "nickname", "password", "confirm_password"] + extra_kwargs = { + "password": {"write_only": True}, + "confirm_password": {"write_only": True}, + } + + +class PasswordResetSerializer(serializers.Serializer): + """ + 비밀번호 재설정을 위한 시리얼라이저입니다. + - 현재 비밀번호와 새로운 비밀번호를 받아 비밀번호를 변경합니다. + """ + + current_password = serializers.CharField(write_only=True, required=True) + new_password = serializers.CharField(write_only=True, required=True) + confirm_new_password = serializers.CharField(write_only=True, required=True) + + def validate_new_password(self, value): + """ + 새로운 비밀번호의 복잡성을 검증합니다. + """ + if len(value) < 8: + raise serializers.ValidationError( + {"password": "비밀번호는 8 글자 이상이어야 합니다."} + ) + if not any(char.isdigit() for char in value): + raise serializers.ValidationError( + {"password": "1개 이상의 숫자를 포함해야 합니다. "} + ) + if not any(char.isupper() for char in value): + raise serializers.ValidationError( + {"password": "1개 이상의 대문자를 포함해야 합니다."} + ) + if not any(char in r"!@#$%^&*()-_=+[{]}\|;:'\",<.>/?`~" for char in value): + raise serializers.ValidationError( + {"password": "1개 이상의 특수 문자를 포함해야 합니다."} + ) + return value + + def validate(self, data): + """ + 새로운 비밀번호와 확인 비밀번호가 일치하는지 검증하고 반환합니다. + """ + new_password = data.get("new_password") + confirm_new_password = data.get("confirm_new_password") + if ( + new_password + and confirm_new_password + and new_password != confirm_new_password + ): + raise serializers.ValidationError("새 비밀번호가 일치하지 않습니다.") + + user = self.context["request"].user + if user.check_password(new_password): + raise serializers.ValidationError( + "이전과 동일한 비밀번호를 사용할 수 없습니다." + ) + return data + + def validate_current_password(self, value): + """ + 현재 비밀번호가 맞는지 검증하고 그 값을 반환합니다. + """ + user = self.context["request"].user + if not user.check_password(value): + raise serializers.ValidationError("기존 비밀번호가 올바르지 않습니다.") + return value + + def save(self, **kwargs): + """ + 비밀번호를 재설정합니다. + - 어떠한 데이터도 직렬화하여 반환하지 않는 대신 업데이트된 사용자 객체를 반환합니다. + """ + user = self.context["request"].user + new_password = self.validated_data["new_password"] + + user.set_password(new_password) + user.save() + return user + + +class StudentListSerializer(serializers.ModelSerializer): + """ + 학생 목록을 위한 시리얼라이저입니다. + """ + + class Meta: + model = CustomUser + fields = ["id", "email", "nickname", "created_at"] + read_only_fields = ["id", "email", "nickname", "created_at"] + + +class TutorListSerializer(serializers.ModelSerializer): + """ + 강사 목록을 위한 시리얼라이저입니다. + """ + + class Meta: + model = CustomUser + fields = ["id", "email", "nickname", "created_at"] + read_only_fields = ["id", "email", "nickname", "created_at"] + + +class CustomUserDetailSerializer(serializers.ModelSerializer): """ 커스텀 사용자의 시리얼라이저입니다. - - 데이터 직렬화를 담당합니다. - - 데이터 역직렬화를 담당합니다. """ confirm_password = serializers.CharField(write_only=True, required=True) @@ -17,16 +129,33 @@ class CustomUserSerializer(serializers.ModelSerializer): student_count = serializers.SerializerMethodField() tutor_count = serializers.SerializerMethodField() + profile_image = ImageSerializer(required=False) + class Meta: model = CustomUser - fields = "__all__" - exclude = ["is_superuser"] + fields = [ + "id", + "email", + "nickname", + "password", + "confirm_password", + "profile_image", + "is_active", + "is_staff", + "is_superuser", + "students", + "user_type", + "student_count", + "tutor_count", + "created_at", + "updated_at", + ] read_only_fields = [ "id", "is_active", "is_staff", - "date_joined", - "tutor", + "is_superuser", + "students", "user_type", "student_count", "tutor_count", @@ -67,6 +196,20 @@ def get_tutor_count(self, obj): return CustomUser.objects.filter(is_staff=True, is_superuser=False).count() return None + def validate_nickname(self, value): + """ + 닉네임 필드의 데이터를 검증합니다. + """ + if ( + CustomUser.objects.filter(nickname=value) + .exclude(id=self.instance.id if self.instance else None) + .exists() + ): + raise serializers.ValidationError( + {"nickname": "사용할 수 없는 닉네임입니다."} + ) + return value + def validate_email(self, value): """ 이메일 필드의 데이터를 검증합니다. @@ -95,12 +238,26 @@ def validate_password(self, value): raise serializers.ValidationError( {"password": "1개 이상의 대문자를 포함해야 합니다."} ) - if not any(char in "!@#$%^&*()-_=+[{]}\|;:'\",<.>/?`~" for char in value): + if not any(char in r"!@#$%^&*()-_=+[{]}\|;:'\",<.>/?`~" for char in value): raise serializers.ValidationError( {"password": "1개 이상의 특수 문자를 포함해야 합니다."} ) return value + def validate_nickname(self, value): + """ + 닉네임 필드의 데이터를 검증합니다. + """ + if ( + CustomUser.objects.filter(nickname=value) + .exclude(id=self.instance.id if self.instance else None) + .exists() + ): + raise serializers.ValidationError( + {"nickname": "이미 존재하는 닉네임입니다."} + ) + return value + def validate(self, data): """ 전체 필드에 대한 데이터를 검증합니다. @@ -110,9 +267,9 @@ def validate(self, data): password = data.get("password") confirm_password = data.get("confirm_password") + if password and confirm_password and password != confirm_password: raise serializers.ValidationError("비밀번호가 일치하지 않습니다.") - if ( self.instance and password @@ -121,7 +278,6 @@ def validate(self, data): raise serializers.ValidationError( {"password": "이전과 동일한 비밀번호를 사용할 수 없습니다."} ) - return data def create(self, validated_data): @@ -132,8 +288,12 @@ def create(self, validated_data): try: password = validated_data.pop("password", None) validated_data.pop("confirm_password") + profile_image = validated_data.pop("profile_image", None) + user = CustomUser.objects.create_user(password=password, **validated_data) + if profile_image: + Image.objects.create(user=user, **profile_image) return user except Exception as e: raise serializers.ValidationError("사용자 생성 중 오류가 발생했습니다.") @@ -146,11 +306,23 @@ def update(self, instance, validated_data): password = validated_data.pop("password", None) validated_data.pop("confirm_password", None) + profile_image = validated_data.pop("profile_image", None) + user = super().update(instance, validated_data) + # 업데이트에 비밀번호가 있는 경우 if password: user.set_password(password) user.save() + + # 업데이트에 프로필 이미지가 있는 경우 + if profile_image: + if hasattr(user, 'image'): + for attr, value in profile_image.items(): + setattr(user.image, attr, value) + user.image.save() + else: + Image.objects.create(user=user, **profile_image) return user except Exception as e: raise serializers.ValidationError("사용자 업데이트 중 오류가 발생했습니다.") @@ -176,77 +348,3 @@ def to_representation(self, instance): representation.pop("tutor", None) return representation - - -class PasswordResetSerializer(serializers.Serializer): - """ - 비밀번호 재설정을 위한 시리얼라이저입니다. - - 현재 비밀번호와 새로운 비밀번호를 받아 비밀번호를 변경합니다. - """ - - current_password = serializers.CharField(write_only=True, required=True) - new_password = serializers.CharField(write_only=True, required=True) - confirm_new_password = serializers.CharField(write_only=True, required=True) - - def validate_new_password(self, value): - """ - 새로운 비밀번호의 복잡성을 검증합니다. - """ - if len(value) < 8: - raise serializers.ValidationError( - {"password": "비밀번호는 8 글자 이상이어야 합니다."} - ) - if not any(char.isdigit() for char in value): - raise serializers.ValidationError( - {"password": "1개 이상의 숫자를 포함해야 합니다. "} - ) - if not any(char.isupper() for char in value): - raise serializers.ValidationError( - {"password": "1개 이상의 대문자를 포함해야 합니다."} - ) - if not any(char in "!@#$%^&*()-_=+[{]}\|;:'\",<.>/?`~" for char in value): - raise serializers.ValidationError( - {"password": "1개 이상의 특수 문자를 포함해야 합니다."} - ) - return value - - def validate(self, data): - """ - 새로운 비밀번호와 확인 비밀번호가 일치하는지 검증하고 반환합니다. - """ - new_password = data.get("new_password") - confirm_new_password = data.get("confirm_new_password") - if ( - new_password - and confirm_new_password - and new_password != confirm_new_password - ): - raise serializers.ValidationError("새 비밀번호가 일치하지 않습니다.") - - user = self.context["request"].user - if user.check_password(new_password): - raise serializers.ValidationError( - "이전과 동일한 비밀번호를 사용할 수 없습니다." - ) - return data - - def validate_current_password(self, value): - """ - 현재 비밀번호가 맞는지 검증하고 그 값을 반환합니다. - """ - user = self.context["request"].user - if not user.check_password(value): - raise serializers.ValidationError("현재 비밀번호가 올바르지 않습니다.") - return value - - def save(self, **kwargs): - """ - 비밀번호를 재설정합니다. - - 어떠한 데이터도 직렬화하여 반환하지 않는 대신 업데이트된 사용자 객체를 반환합니다. - """ - user = self.context["request"].user - new_password = self.validated_data["new_password"] - - user.set_password(new_password) - user.save() - return user diff --git a/accounts/test/test_accounts_admin.py b/accounts/test/test_accounts_admin.py new file mode 100644 index 0000000..98fb0fa --- /dev/null +++ b/accounts/test/test_accounts_admin.py @@ -0,0 +1,169 @@ +import pytest +from accounts.admin import CustomUserAdmin +from accounts.models import CustomUser +from django.contrib.admin.sites import AdminSite +from django.contrib.auth import get_user_model +from django.test import RequestFactory + +User = get_user_model() + + +@pytest.fixture +def admin_site(): + return AdminSite() + + +@pytest.fixture +def custom_user_admin(admin_site): + return CustomUserAdmin(CustomUser, admin_site) + + +@pytest.fixture +def request_factory(): + return RequestFactory() + + +@pytest.fixture +def admin_user(): + return User.objects.create_superuser("admin@example.com", "adminpassword", "admin") + + +@pytest.mark.django_db +class TestCustomUserAdmin: + # Given: CustomUserAdmin이 설정되어 있을 때 + # When: list_display 속성을 확인하면 + # Then: 지정된 필드들이 포함되어 있어야 합니다. + def test_관리자인터페이스_list_display_확인(self, custom_user_admin): + assert custom_user_admin.list_display == ( + "email", + "nickname", + "is_staff", + "is_superuser", + ) + + # Given: CustomUserAdmin이 설정되어 있을 때 + # When: ordering 속성을 확인하면 + # Then: 이메일로 정렬되어야 합니다. + def test_관리자인터페이스_list_정렬_기준_확인(self, custom_user_admin): + assert custom_user_admin.ordering == ( + "email", + "created_at", + ) + + # Given: CustomUserAdmin이 설정되어 있을 때 + # When: add_fieldsets 속성을 확인하면 + # Then: 지정된 필드들이 포함되어 있어야 합니다. + def test_관리자인터페이스_add_fieldsets_속성_확인(self, custom_user_admin): + fieldsets = custom_user_admin.add_fieldsets[0][1]["fields"] + assert "email" in fieldsets + assert "nickname" in fieldsets + assert "password1" in fieldsets + assert "password2" in fieldsets + assert "is_staff" in fieldsets + assert "is_superuser" in fieldsets + + # Given: 관리자 사용자와 CustomUserAdmin이 설정되어 있을 때 + # When: 일반 사용자를 생성하면 + # Then: create_user 메서드가 호출되어야 합니다. + def test_관리자인터페이스_save_model_사용자_생성( + self, custom_user_admin, request_factory, admin_user + ): + request = request_factory.post("/admin/accounts/customuser/add/") + request.user = admin_user + obj = CustomUser(email="user@example.com", nickname="user") + form = type( + "Form", + (object,), + { + "cleaned_data": { + "email": "user@example.com", + "password1": "userpassword", + "nickname": "user", + "is_staff": False, + "is_superuser": False, + } + }, + )() + custom_user_admin.save_model(request, obj, form, False) + assert CustomUser.objects.filter( + email="user@example.com", is_staff=False, is_superuser=False + ).exists() + + # Given: 관리자 사용자와 CustomUserAdmin이 설정되어 있을 때 + # When: 스태프 사용자를 생성하면 + # Then: create_staff 메서드가 호출되어야 합니다. + def test_관리자인터페이스_save_model_staff_생성( + self, custom_user_admin, request_factory, admin_user + ): + request = request_factory.post("/admin/accounts/customuser/add/") + request.user = admin_user + obj = CustomUser(email="staff@example.com", nickname="staff") + form = type( + "Form", + (object,), + { + "cleaned_data": { + "email": "staff@example.com", + "password1": "staffpassword", + "nickname": "staff", + "is_staff": True, + "is_superuser": False, + } + }, + )() + custom_user_admin.save_model(request, obj, form, False) + assert CustomUser.objects.filter( + email="staff@example.com", is_staff=True, is_superuser=False + ).exists() + + # Given: 관리자 사용자와 CustomUserAdmin이 설정되어 있을 때 + # When: 슈퍼유저를 생성하면 + # Then: create_superuser 메서드가 호출되어야 합니다. + def test_관리자인터페이스_save_model_superuser_생성( + self, custom_user_admin, request_factory, admin_user + ): + request = request_factory.post("/admin/accounts/customuser/add/") + request.user = admin_user + obj = CustomUser(email="super@example.com", nickname="super") + form = type( + "Form", + (object,), + { + "cleaned_data": { + "email": "super@example.com", + "password1": "superpassword", + "nickname": "super", + "is_staff": True, + "is_superuser": True, + } + }, + )() + custom_user_admin.save_model(request, obj, form, False) + assert CustomUser.objects.filter( + email="super@example.com", is_staff=True, is_superuser=True + ).exists() + + # Given: 관리자 사용자와 CustomUserAdmin이 설정되어 있고 기존 사용자가 있을 때 + # When: 기존 사용자를 수정하면 + # Then: 기본 save_model 메서드가 호출되어야 합니다. + def test_관리자인터페이스_save_model_사용자_갱신( + self, custom_user_admin, request_factory, admin_user + ): + existing_user = CustomUser.objects.create_user( + "existing@example.com", "password", "existing" + ) + request = request_factory.post("/admin/accounts/customuser/1/change/") + request.user = admin_user + form = type( + "Form", + (object,), + { + "cleaned_data": { + "email": "existing@example.com", + "nickname": "updated_nickname", + } + }, + )() + custom_user_admin.save_model(request, existing_user, form, True) + existing_user.refresh_from_db() + assert existing_user.nickname == "updated_nickname" diff --git a/accounts/test/test_accounts_models.py b/accounts/test/test_accounts_models.py new file mode 100644 index 0000000..19ef141 --- /dev/null +++ b/accounts/test/test_accounts_models.py @@ -0,0 +1,136 @@ +import pytest +from django.contrib.auth import get_user_model +from django.core.exceptions import ValidationError + +User = get_user_model() + + +@pytest.mark.django_db +class TestCustomUser: + + # Given: CustomUser 모델이 존재할 때 + # When: 일반 사용자를 생성하면 + # Then: 사용자가 정상적으로 생성되어야 합니다. + def test_user_생성(self): + user = User.objects.create_user( + email="test@example.com", password="testpassword", nickname="testnick" + ) + assert user.email == "test@example.com" + assert user.nickname == "testnick" + assert not user.is_staff + assert not user.is_superuser + + # Given: CustomUser 모델이 존재할 때 + # When: 스태프 사용자를 생성하면 + # Then: 스태프 권한을 가진 사용자가 정상적으로 생성되어야 합니다. + def test_staff_생성(self): + staff = User.objects.create_staff( + email="staff@example.com", password="staffpassword", nickname="staffnick" + ) + assert staff.email == "staff@example.com" + assert staff.nickname == "staffnick" + assert staff.is_staff + assert not staff.is_superuser + + # Given: CustomUser 모델이 존재할 때 + # When: 슈퍼유저를 생성하면 + # Then: 관리자 권한을 가진 사용자가 정상적으로 생성되어야 한다. + def test_superuser_생성(self): + admin = User.objects.create_superuser( + email="admin@example.com", password="adminpassword", nickname="adminnick" + ) + assert admin.email == "admin@example.com" + assert admin.nickname == "adminnick" + assert admin.is_staff + assert admin.is_superuser + + # Given: CustomUser 모델이 존재할 때 + # When: 이메일 없이 사용자를 생성하려고 하면 + # Then: ValueError가 발생해야 한다. + def test_user_생성_without_이메일(self): + with pytest.raises(ValueError): + User.objects.create_user( + email="", password="testpassword", nickname="testnick" + ) + + # Given: CustomUser 모델이 존재할 때 + # When: 닉네임 없이 사용자를 생성하려고 하면 + # Then: ValueError가 발생해야 한다. + def test_user_생성_withdout_닉네임(self): + with pytest.raises(ValueError): + User.objects.create_user( + email="test@example.com", password="testpassword", nickname="" + ) + + # Given: CustomUser 모델이 존재할 때 + # When: is_staff=False로 스태프를 생성하려고 하면 + # Then: ValueError가 발생해야 한다. + def test_staff_생성_is_staff를_false값으로(self): + with pytest.raises(ValueError): + User.objects.create_staff( + email="staff@example.com", + password="staffpassword", + nickname="staffnick", + is_staff=False, + ) + + # Given: CustomUser 모델이 존재할 때 + # When: is_superuser=False로 슈퍼유저를 생성하려고 하면 + # Then: ValueError가 발생해야 한다. + def test_superuser_생성_is_superuser를_false값으로(self): + with pytest.raises(ValueError): + User.objects.create_superuser( + email="admin@example.com", + password="adminpassword", + nickname="adminnick", + is_superuser=False, + ) + + # Given: CustomUser 모델이 존재하고 튜터와 학생이 생성되어 있을 때 + # When: 튜터에게 학생을 할당하면 + # Then: 튜터-학생 관계가 정상적으로 설정되어야 한다. + def test_tutor에게_student를_할당해서_관계_생성(self): + tutor = User.objects.create_staff( + email="tutor@example.com", password="tutorpassword", nickname="tutornick" + ) + student = User.objects.create_user( + email="student@example.com", + password="studentpassword", + nickname="studentnick", + ) + tutor.students.add(student) + assert student in tutor.students.all() + assert tutor in student.tutors.all() + + # Given: CustomUser 모델이 존재하고 사용자가 생성되어 있을 때 + # When: 사용자의 문자열 표현을 요청하면 + # Then: 사용자의 이메일이 반환되어야 한다. + def test_user_str_representation_출력(self): + user = User.objects.create_user( + email="test@example.com", password="testpassword", nickname="testnick" + ) + assert str(user) == "test@example.com" + + # Given: CustomUser 모델이 존재할 때 + # When: 동일한 이메일로 두 명의 사용자를 생성하려고 하면 + # Then: IntegrityError가 발생해야 한다. (pytest-django에서는 TransactionManagement) + def test_email_기본키_constraint_만족(self): + User.objects.create_user( + email="test@example.com", password="testpassword1", nickname="testnick1" + ) + with pytest.raises(Exception): + User.objects.create_user( + email="test@example.com", password="testpassword2", nickname="testnick2" + ) + + # Given: CustomUser 모델이 존재할 때 + # When: 동일한 닉네임으로 두 명의 사용자를 생성하려고 하면 + # Then: IntegrityError가 발생해야 한다 + def test_nickname_후보키_constraint_만족(self): + User.objects.create_user( + email="test1@example.com", password="testpassword1", nickname="testnick" + ) + with pytest.raises(Exception): + User.objects.create_user( + email="test2@example.com", password="testpassword2", nickname="testnick" + ) diff --git a/accounts/test/test_accounts_permissions.py b/accounts/test/test_accounts_permissions.py new file mode 100644 index 0000000..09f6dbd --- /dev/null +++ b/accounts/test/test_accounts_permissions.py @@ -0,0 +1,162 @@ +import pytest +from accounts.permissions import IsAuthenticatedAndActive, IsSuperUser, IsTutor +from django.contrib.auth import get_user_model +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +User = get_user_model() + + +@pytest.fixture +def api_request_factory(): + return APIRequestFactory() + + +@pytest.fixture +def mock_view(): + return APIView() + + +@pytest.fixture +def create_user(): + def _create_user( + email, password, nickname, is_active=True, is_staff=False, is_superuser=False + ): + return User.objects.create_user( + email=email, + password=password, + nickname=nickname, + is_active=is_active, + is_staff=is_staff, + is_superuser=is_superuser, + ) + + return _create_user + + +@pytest.mark.django_db +class TestIsAuthenticatedAndActive: + # Given: 인증되고 활성화된 사용자가 있을 때 + # When: IsAuthenticatedAndActive 권한을 검사하면 + # Then: 권한이 허용되어야 합니다. + def test_authenticated_그리고_active_user_확인( + self, api_request_factory, mock_view, create_user + ): + user = create_user("user@example.com", "password", "nickname", is_active=True) + request = api_request_factory.get("/") + request.user = user + permission = IsAuthenticatedAndActive() + assert permission.has_permission(request, mock_view) is True + + # Given: 인증되었지만 비활성화된 사용자가 있을 때 + # When: IsAuthenticatedAndActive 권한을 검사하면 + # Then: 권한이 거부되어야 합니다. + def test_authenticated_그러나_inactive_user_확인( + self, api_request_factory, mock_view, create_user + ): + user = create_user("user@example.com", "password", "nickname", is_active=False) + request = api_request_factory.get("/") + request.user = user + permission = IsAuthenticatedAndActive() + assert permission.has_permission(request, mock_view) is False + + # Given: 인증되지 않은 사용자가 있을 때 + # When: IsAuthenticatedAndActive 권한을 검사하면 + # Then: 권한이 거부되어야 합니다. + def test_비_authenticated_user_확인(self, api_request_factory, mock_view): + request = api_request_factory.get("/") + request.user = None + permission = IsAuthenticatedAndActive() + assert permission.has_permission(request, mock_view) is False + + +@pytest.mark.django_db +class TestIsTutor: + # Given: 인증되고 활성화된 강사 사용자가 있을 때 + # When: IsTutor 권한을 검사하면 + # Then: 권한이 허용되어야 합니다. + def test_authenticated_active_tutor임을_확인( + self, api_request_factory, mock_view, create_user + ): + user = create_user( + "tutor@example.com", "password", "nickname", is_active=True, is_staff=True + ) + request = api_request_factory.get("/") + request.user = user + permission = IsTutor() + assert permission.has_permission(request, mock_view) is True + + # Given: 인증되고 활성화되었지만 강사가 아닌 사용자가 있을 때 + # When: IsTutor 권한을 검사하면 + # Then: 권한이 거부되어야 합니다. + def test_authenticated_active_그러나_non_tutor임을_확인( + self, api_request_factory, mock_view, create_user + ): + user = create_user( + "user@example.com", "password", "nickname", is_active=True, is_staff=False + ) + request = api_request_factory.get("/") + request.user = user + permission = IsTutor() + assert permission.has_permission(request, mock_view) is False + + # Given: 인증되지 않은 사용자가 있을 때 + # When: IsTutor 권한을 검사하면 + # Then: 권한이 거부되어야 합니다. + def test_비_authenticated_user_일때_tutor_검사_거부( + self, api_request_factory, mock_view + ): + request = api_request_factory.get("/") + request.user = None + permission = IsTutor() + assert permission.has_permission(request, mock_view) is False + + +@pytest.mark.django_db +class TestIsSuperUser: + # Given: 인증되고 활성화된 관리자 사용자가 있을 때 + # When: IsSuperUser 권한을 검사하면 + # Then: 권한이 허용되어야 합니다. + def test_authenticated_active_superuser임을_확인( + self, api_request_factory, mock_view, create_user + ): + user = create_user( + "admin@example.com", + "password", + "nickname", + is_active=True, + is_superuser=True, + ) + request = api_request_factory.get("/") + request.user = user + permission = IsSuperUser() + assert permission.has_permission(request, mock_view) is True + + # Given: 인증되고 활성화되었지만 관리자가 아닌 사용자가 있을 때 + # When: IsSuperUser 권한을 검사하면 + # Then: 권한이 거부되어야 합니다. + def test_authenticated_active_그러나_non_superuser임을_확인( + self, api_request_factory, mock_view, create_user + ): + user = create_user( + "user@example.com", + "password", + "nickname", + is_active=True, + is_superuser=False, + ) + request = api_request_factory.get("/") + request.user = user + permission = IsSuperUser() + assert permission.has_permission(request, mock_view) is False + + # Given: 인증되지 않은 사용자가 있을 때 + # When: IsSuperUser 권한을 검사하면 + # Then: 권한이 거부되어야 합니다. + def test_비_authenticated_user_일때_superuser_검사_거부( + self, api_request_factory, mock_view + ): + request = api_request_factory.get("/") + request.user = None + permission = IsSuperUser() + assert permission.has_permission(request, mock_view) is False diff --git a/accounts/test/test_accounts_serializers.py b/accounts/test/test_accounts_serializers.py new file mode 100644 index 0000000..6d94410 --- /dev/null +++ b/accounts/test/test_accounts_serializers.py @@ -0,0 +1,221 @@ +import pytest +from django.contrib.auth import get_user_model + +from accounts.serializers import ( + CustomUserDetailSerializer, + PasswordResetSerializer, + StudentListSerializer, + TutorListSerializer, + UserRegistrationSerializer, +) + +User = get_user_model() + + +@pytest.mark.django_db +class TestUserRegistrationSerializer: + # Given: 유효한 사용자 데이터가 주어졌을 때 + # When: UserRegistrationSerializer를 통해 데이터를 검증하면 + # Then: 데이터가 유효해야 합니다. + def test_valid_user_data(self): + data = { + "email": "test@example.com", + "password": "StrongPass1!", + "confirm_password": "StrongPass1!", + "nickname": "testuser", + } + serializer = UserRegistrationSerializer(data=data) + assert serializer.is_valid() + + # Given: 이미 존재하는 이메일로 데이터가 주어졌을 때 + # When: UserRegistrationSerializer를 통해 데이터를 검증하면 + # Then: 유효성 검사에 실패해야 합니다. + def test_duplicate_email(self): + User.objects.create_user( + email="existing@example.com", password="password", nickname="existing" + ) + data = { + "email": "existing@example.com", + "password": "NewPass1!", + "nickname": "newuser", + } + serializer = UserRegistrationSerializer(data=data) + assert not serializer.is_valid() + assert "email" in serializer.errors + + +@pytest.mark.django_db +class TestPasswordResetSerializer: + @pytest.fixture + def user(self): + return User.objects.create_user( + email="test@example.com", password="OldPass1!", nickname="testuser" + ) + + # Given: 유효한 비밀번호 재설정 데이터가 주어졌을 때 + # When: PasswordResetSerializer를 통해 데이터를 검증하면 + # Then: 데이터가 유효해야 합니다. + def test_valid_password_reset(self, user, rf): + request = rf.get("/") + request.user = user + data = { + "current_password": "OldPass1!", + "new_password": "NewStrongPass2@", + "confirm_new_password": "NewStrongPass2@", + } + serializer = PasswordResetSerializer(data=data, context={"request": request}) + assert serializer.is_valid() + + # Given: 현재 비밀번호가 틀린 데이터가 주어졌을 때 + # When: PasswordResetSerializer를 통해 데이터를 검증하면 + # Then: 유효성 검사에 실패해야 합니다. + def test_invalid_current_password(self, user, rf): + request = rf.get("/") + request.user = user + data = { + "current_password": "WrongPass1!", + "new_password": "NewStrongPass2@", + "confirm_new_password": "NewStrongPass2@", + } + serializer = PasswordResetSerializer(data=data, context={"request": request}) + assert not serializer.is_valid() + assert "current_password" in serializer.errors + + # Given: 새 비밀번호가 복잡성 요구사항을 충족하지 않을 때 + # When: PasswordResetSerializer를 통해 데이터를 검증하면 + # Then: 유효성 검사에 실패해야 합니다. + def test_weak_new_password(self, user, rf): + request = rf.get("/") + request.user = user + data = { + "current_password": "OldPass1!", + "new_password": "weak", + "confirm_new_password": "weak", + } + serializer = PasswordResetSerializer(data=data, context={"request": request}) + assert not serializer.is_valid() + assert "new_password" in serializer.errors + + +@pytest.mark.django_db +class TestStudentListSerializer: + # Given: 학생 사용자가 있을 때 + # When: StudentListSerializer를 사용해 직렬화하면 + # Then: 지정된 필드만 포함되어야 합니다. + def test_student_serialization(self): + student = User.objects.create_user( + email="student@example.com", password="Pass1!", nickname="student" + ) + serializer = StudentListSerializer(student) + assert set(serializer.data.keys()) == {"id", "email", "nickname", "created_at"} + + +@pytest.mark.django_db +class TestTutorListSerializer: + # Given: 강사 사용자가 있을 때 + # When: TutorListSerializer를 사용해 직렬화하면 + # Then: 지정된 필드만 포함되어야 합니다. + def test_tutor_serialization(self): + tutor = User.objects.create_user( + email="tutor@example.com", + password="Pass1!", + nickname="tutor", + is_staff=True, + ) + serializer = TutorListSerializer(tutor) + assert set(serializer.data.keys()) == {"id", "email", "nickname", "created_at"} + + +@pytest.mark.django_db +class TestCustomUserDetailSerializer: + @pytest.fixture + def student(self): + return User.objects.create_user( + email="student@example.com", password="Pass1!", nickname="student" + ) + + @pytest.fixture + def tutor(self): + return User.objects.create_user( + email="tutor@example.com", + password="Pass1!", + nickname="tutor", + is_staff=True, + ) + + @pytest.fixture + def admin(self): + return User.objects.create_superuser( + email="admin@example.com", password="Pass1!", nickname="admin" + ) + + # Given: 학생 사용자가 있을 때 + # When: CustomUserDetailSerializer를 사용해 직렬화하면 + # Then: 학생에 해당하는 필드만 포함되어야 합니다. + def test_student_serialization(self, student): + serializer = CustomUserDetailSerializer(student) + assert "user_type" in serializer.data + assert serializer.data["user_type"] == "student" + assert "student_count" not in serializer.data + assert "tutor_count" not in serializer.data + + # Given: 강사 사용자가 있을 때 + # When: CustomUserDetailSerializer를 사용해 직렬화하면 + # Then: 강사에 해당하는 필드만 포함되어야 합니다. + def test_tutor_serialization(self, tutor): + serializer = CustomUserDetailSerializer(tutor) + assert "user_type" in serializer.data + assert serializer.data["user_type"] == "tutor" + assert "student_count" in serializer.data + assert "tutor_count" not in serializer.data + + # Given: 관리자 사용자가 있을 때 + # When: CustomUserDetailSerializer를 사용해 직렬화하면 + # Then: 관리자에 해당하는 모든 필드가 포함되어야 합니다. + def test_admin_serialization(self, admin): + serializer = CustomUserDetailSerializer(admin) + assert "user_type" in serializer.data + assert serializer.data["user_type"] == "superuser" + assert "student_count" in serializer.data + assert "tutor_count" in serializer.data + + # Given: 유효한 사용자 데이터가 주어졌을 때 + # When: CustomUserDetailSerializer를 통해 데이터를 검증하면 + # Then: 데이터가 유효해야 합니다. + def test_valid_user_data(self): + data = { + "email": "new@example.com", + "password": "NewPass1!", + "confirm_password": "NewPass1!", + "nickname": "newuser", + } + serializer = CustomUserDetailSerializer(data=data) + assert serializer.is_valid() + + # Given: 비밀번호와 확인 비밀번호가 일치하지 않는 데이터가 주어졌을 때 + # When: CustomUserDetailSerializer를 통해 데이터를 검증하면 + # Then: 유효성 검사에 실패해야 합니다. + def test_password_mismatch(self): + data = { + "email": "new@example.com", + "password": "NewPass1!", + "confirm_password": "DifferentPass1!", + "nickname": "newuser", + } + serializer = CustomUserDetailSerializer(data=data) + assert not serializer.is_valid() + assert "non_field_errors" in serializer.errors + + # Given: 약한 비밀번호가 포함된 데이터가 주어졌을 때 + # When: CustomUserDetailSerializer를 통해 데이터를 검증하면 + # Then: 유효성 검사에 실패해야 합니다. + def test_weak_password(self): + data = { + "email": "new@example.com", + "password": "weak", + "confirm_password": "weak", + "nickname": "newuser", + } + serializer = CustomUserDetailSerializer(data=data) + assert not serializer.is_valid() + assert "password" in serializer.errors diff --git a/accounts/test/test_accounts_views.py b/accounts/test/test_accounts_views.py new file mode 100644 index 0000000..4d4c087 --- /dev/null +++ b/accounts/test/test_accounts_views.py @@ -0,0 +1,274 @@ +import pytest +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from accounts.models import CustomUser + + +@pytest.fixture +def api_client(): + return APIClient() + + +@pytest.fixture +def create_user(): + def _create_user(email, password, nickname, is_staff=False, is_superuser=False): + return CustomUser.objects.create_user( + email=email, + password=password, + nickname=nickname, + is_staff=is_staff, + is_superuser=is_superuser, + ) + + return _create_user + + +@pytest.mark.django_db +class TestUserRegisterationView: + # Given: 유효한 사용자 데이터가 주어졌을 때 + # When: 회원가입 요청을 보내면 + # Then: 성공적으로 회원가입이 되어야 합니다. + def test_user_registration_success(self, api_client): + url = reverse("accounts:student-register") + data = { + "email": "test@example.com", + "nickname": "testnick", + "password": "testpassword", + "confirm_password": "testpassword", + } + response = api_client.post(url, data) + assert response.status_code == status.HTTP_201_CREATED + assert "user_id" in response.data + assert CustomUser.objects.filter(email="test@example.com").exists() + + # Given: 이미 존재하는 이메일로 가입 시도할 때 + # When: 회원가입 요청을 보내면 + # Then: 회원가입이 실패해야 합니다. + def test_user_registration_duplicate_email(self, api_client, create_user): + create_user("test@example.com", "password123", "existinguser") + url = reverse("accounts:student-register") + data = { + "email": "test@example.com", + "nickname": "testnick", + "password": "testpassword", + "confirm_password": "testpassword", + } + response = api_client.post(url, data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +@pytest.mark.django_db +class TestPasswordResetView: + # Given: 인증된 사용자가 있을 때 + # When: 비밀번호 재설정 요청을 보내면 + # Then: 비밀번호가 성공적으로 변경되어야 합니다. + def test_password_reset_success(self, api_client, create_user): + user = create_user("test@example.com", "CurrentPassword123!", "testnick") + api_client.force_authenticate(user=user) + url = reverse("accounts:password-reset") + data = { + "current_password": "CurrentPassword123!", + "new_password": "NewPassword456@", + "confirm_new_password": "NewPassword456@", + } + response = api_client.post(url, data) + assert response.status_code == status.HTTP_200_OK + user.refresh_from_db() + assert user.check_password("NewPassword456@") + + # Given: 인증되지 않은 사용자일 때 + # When: 비밀번호 재설정 요청을 보내면 + # Then: 권한 오류가 발생해야 합니다. + def test_password_reset_unauthenticated(self, api_client): + url = reverse("accounts:password-reset") + data = { + "current_password": "currentpassword", + "new_password": "newpassword", + "confirm_new_password": "confirmnewpassword", + } + response = api_client.post(url, data) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +@pytest.mark.django_db +class TestStudentListView: + # Given: 여러 명의 학생이 존재하고 인증된 사용자가 있을 때 + # When: 학생 목록 조회 요청을 보내면 + # Then: 학생 목록이 정상적으로 반환되어야 합니다. + def test_student_list_view(self, api_client, create_user): + create_user("student1@example.com", "password", "student1") + create_user("student2@example.com", "password", "student2") + user = create_user("user@example.com", "password", "user") + api_client.force_authenticate(user=user) + url = reverse("accounts:student-list") + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 3 + + # Given: 인증되지 않은 사용자일 때 + # When: 학생 목록 조회 요청을 보내면 + # Then: 권한 오류가 발생해야 합니다. + def test_student_list_view_unauthenticated(self, api_client): + url = reverse("accounts:student-list") + response = api_client.get(url) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +@pytest.mark.django_db +class TestStudentRetrieveUpdateDestroyView: + # Given: 학생 사용자가 존재할 때 + # When: 자신의 정보 조회 요청을 보내면 + # Then: 학생 정보가 정상적으로 반환되어야 합니다. + def test_student_retrieve(self, api_client, create_user): + student = create_user("student3@example.com", "password", "student") + api_client.force_authenticate(user=student) + url = reverse("accounts:student-detail", kwargs={"pk": student.pk}) + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + assert response.data["email"] == "student3@example.com" + + # Given: 학생 사용자가 존재할 때 + # When: 자신의 정보 수정 요청을 보내면 + # Then: 학생 정보가 성공적으로 업데이트되어야 합니다. + # def test_student_update_full(self, api_client, create_user): + # student = create_user("student45@example.com", "OldPassword123!", "student") + # api_client.force_authenticate(user=student) + # url = reverse("accounts:student-detail", kwargs={"pk": student.pk}) + # data = { + # "email": "student45@example.com", + # "nickname": "new_nickname", + # "password": "NewPassword456!", + # "confirm_password": "NewPassword456!", + # } + # response = api_client.put(url, data) + # assert response.status_code == status.HTTP_200_OK + # student.refresh_from_db() + # assert student.email == "student45@example.com" + # assert student.nickname == "new_nickname" + # assert student.check_password("NewPassword456!") + + # Given: 학생 사용자가 존재할 때 + # When: 자신의 정보 수정 요청을 보내면 + # Then: 학생 정보가 성공적으로 업데이트되어야 합니다. + def test_student_update_partial(self, api_client, create_user): + student = create_user("student5@example.com", "password", "student") + api_client.force_authenticate(user=student) + url = reverse("accounts:student-detail", kwargs={"pk": student.pk}) + data = {"nickname": "new_nickname"} + response = api_client.patch(url, data) + assert response.status_code == status.HTTP_200_OK + student.refresh_from_db() + assert student.nickname == "new_nickname" + + # Given: 학생 사용자가 존재할 때 + # When: 자신의 계정 삭제 요청을 보내면 + # Then: 계정이 비활성화되어야 합니다. + def test_student_delete(self, api_client, create_user): + student = create_user("student@example.com", "password", "student") + api_client.force_authenticate(user=student) + url = reverse("accounts:student-detail", kwargs={"pk": student.pk}) + response = api_client.delete(url) + assert response.status_code == status.HTTP_204_NO_CONTENT + student.refresh_from_db() + assert not student.is_active + + +@pytest.mark.django_db +class TestTutorListView: + # Given: 여러 명의 강사가 존재하고 관리자가 있을 때 + # When: 강사 목록 조회 요청을 보내면 + # Then: 강사 목록이 정상적으로 반환되어야 합니다. + def test_tutor_list_view(self, api_client, create_user): + create_user("tutor1@example.com", "password", "tutor1", is_staff=True) + create_user("tutor2@example.com", "password", "tutor2", is_staff=True) + admin = create_user( + "admin@example.com", "password", "admin", is_staff=True, is_superuser=True + ) + api_client.force_authenticate(user=admin) + url = reverse("accounts:tutor-list") + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 3 + + # Given: 일반 사용자일 때 + # When: 강사 목록 조회 요청을 보내면 + # Then: 권한 오류가 발생해야 합니다. + def test_tutor_list_view_not_admin(self, api_client, create_user): + user = create_user("user@example.com", "password", "user") + api_client.force_authenticate(user=user) + url = reverse("accounts:tutor-list") + response = api_client.get(url) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +@pytest.mark.django_db +class TestTutorRetrieveUpdateDestroyView: + # Given: 강사 사용자가 존재할 때 + # When: 자신의 정보 조회 요청을 보내면 + # Then: 강사 정보가 정상적으로 반환되어야 합니다. + def test_tutor_retrieve(self, api_client, create_user): + tutor = create_user("tutor@example.com", "password", "tutor", is_staff=True) + api_client.force_authenticate(user=tutor) + url = reverse("accounts:tutor-detail", kwargs={"pk": tutor.pk}) + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + assert response.data["email"] == "tutor@example.com" + + # Given: 강사 사용자가 존재할 때 + # When: 자신의 정보 수정 요청을 보내면 + # Then: 강사 정보가 성공적으로 업데이트되어야 합니다. + # def test_tutor_update_full(self, api_client, create_user): + # tutor = create_user("tutor@example.com", "password", "tutor", is_staff=True) + # api_client.force_authenticate(user=tutor) + # url = reverse("accounts:tutor-detail", kwargs={"pk": tutor.pk}) + # data = { + # "email": "tutor@example.com", + # "nickname": "new_nickname", + # "password": "passwordXX123!", + # "confirm_password": "passwordXX123!", + # } + # response = api_client.put(url, data) + # assert response.status_code == status.HTTP_200_OK + # tutor.refresh_from_db() + # assert tutor.nickname == "new_nickname" + + # Given: 강사 사용자가 존재할 때 + # When: 자신의 정보 수정 요청을 보내면 + # Then: 강사 정보가 성공적으로 업데이트되어야 합니다. + def test_tutor_update_partial(self, api_client, create_user): + tutor = create_user("tutor@example.com", "password", "tutor", is_staff=True) + api_client.force_authenticate(user=tutor) + url = reverse("accounts:tutor-detail", kwargs={"pk": tutor.pk}) + data = {"nickname": "new_nickname"} + response = api_client.patch(url, data) + assert response.status_code == status.HTTP_200_OK + tutor.refresh_from_db() + assert tutor.nickname == "new_nickname" + + # Given: 강사 사용자가 존재할 때 + # When: 관리자가 강사 계정 삭제 요청을 보내면 + # Then: 강사 계정이 비활성화되어야 합니다. + def test_tutor_delete_by_admin(self, api_client, create_user): + tutor = create_user("tutor@example.com", "password", "tutor", is_staff=True) + admin = create_user( + "admin@example.com", "password", "admin", is_staff=True, is_superuser=True + ) + api_client.force_authenticate(user=admin) + url = reverse("accounts:tutor-detail", kwargs={"pk": tutor.pk}) + response = api_client.delete(url) + assert response.status_code == status.HTTP_204_NO_CONTENT + tutor.refresh_from_db() + assert not tutor.is_active + + # Given: 강사 사용자가 존재할 때 + # When: 다른 강사가 해당 강사의 정보 조회를 시도하면 + # Then: 권한 오류가 발생해야 합니다. + def test_tutor_retrieve_other_tutor(self, api_client, create_user): + tutor1 = create_user("tutor1@example.com", "password", "tutor1", is_staff=True) + tutor2 = create_user("tutor2@example.com", "password", "tutor2", is_staff=True) + api_client.force_authenticate(user=tutor2) + url = reverse("accounts:tutor-detail", kwargs={"pk": tutor1.pk}) + response = api_client.get(url) + assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/accounts/tests/test_accounts_models.py b/accounts/tests/test_accounts_models.py deleted file mode 100644 index 1c597d3..0000000 --- a/accounts/tests/test_accounts_models.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest -from accounts.models import CustomUser -from django.core.exceptions import ValidationError - - -@pytest.mark.django_db -def test_create_user(): - user = CustomUser.objects.create_user( - email="test@example.com", password="testpass123" - ) - assert user.email == "test@example.com" - assert user.is_active - assert not user.is_staff - assert not user.is_superuser - - -@pytest.mark.django_db -def test_create_superuser(): - admin = CustomUser.objects.create_superuser( - email="admin@example.com", password="adminpass123" - ) - assert admin.email == "admin@example.com" - assert admin.is_active - assert admin.is_staff - assert admin.is_superuser - - -@pytest.mark.django_db -def test_user_str(): - user = CustomUser.objects.create_user( - email="test@example.com", password="testpass123" - ) - assert str(user) == "test@example.com" - - -@pytest.mark.django_db -def test_clean_method(): - user = CustomUser(password="testpass123") - with pytest.raises(ValidationError): - user.clean() diff --git a/accounts/tests/test_accounts_views_student.py b/accounts/tests/test_accounts_views_student.py deleted file mode 100644 index 4d166d9..0000000 --- a/accounts/tests/test_accounts_views_student.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest -from accounts.models import CustomUser -from rest_framework.test import APIClient - - -@pytest.mark.django_db -class TestStudentListCreateView: - def setup_method(self, method): - self.client = APIClient() - - def test_list_students(self): - CustomUser.objects.create_user( - email="student1@example.com", password="pass123", is_staff=False - ) - CustomUser.objects.create_user( - email="student2@example.com", password="pass123", is_staff=False - ) - - response = self.client.get("/api/students/") - assert response.status_code == 200 - assert len(response.data["results"]) == 2 - - def test_create_student(self): - data = { - "email": "newstudent@example.com", - "password": "NewPass123!", - "confirm_password": "NewPass123!", - "nickname": "NewStudent", - } - response = self.client.post("/api/students/", data) - assert response.status_code == 201 - assert CustomUser.objects.filter(email="newstudent@example.com").exists() - - -@pytest.mark.django_db -class TestStudentRetrieveUpdateDestroyView: - def setup_method(self, method): - self.client = APIClient() - self.user = CustomUser.objects.create_user( - email="student@example.com", password="pass123", is_staff=False - ) - self.client.force_authenticate(user=self.user) - - def test_retrieve_student(self): - response = self.client.get(f"/api/students/{self.user.id}/") - assert response.status_code == 200 - assert response.data["email"] == "student@example.com" - - def test_update_student(self): - data = {"nickname": "UpdatedNickname"} - response = self.client.put(f"/api/students/{self.user.id}/", data) - assert response.status_code == 200 - assert response.data["nickname"] == "UpdatedNickname" - - def test_delete_student(self): - response = self.client.delete(f"/api/students/{self.user.id}/") - assert response.status_code == 204 - self.user.refresh_from_db() - assert not self.user.is_active diff --git a/accounts/tests/test_accounts_views_tutor.py b/accounts/tests/test_accounts_views_tutor.py deleted file mode 100644 index 7a9d100..0000000 --- a/accounts/tests/test_accounts_views_tutor.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest -from accounts.models import CustomUser -from rest_framework.test import APIClient - - -@pytest.mark.django_db -class TestTutorStudentView: - def setup_method(self, method): - self.client = APIClient() - self.tutor = CustomUser.objects.create_user( - email="tutor@example.com", password="pass123", is_staff=True - ) - self.student1 = CustomUser.objects.create_user( - email="student1@example.com", password="pass123", is_staff=False - ) - self.student2 = CustomUser.objects.create_user( - email="student2@example.com", password="pass123", is_staff=False - ) - self.tutor.students.add(self.student1, self.student2) - self.client.force_authenticate(user=self.tutor) - - def test_list_tutor_students(self): - response = self.client.get(f"/api/tutors/{self.tutor.id}/students/") - assert response.status_code == 200 - assert response.data["student_count"] == 2 - assert len(response.data["students"]) == 2 - - def test_unauthorized_access(self): - other_tutor = CustomUser.objects.create_user( - email="other@example.com", password="pass123", is_staff=True - ) - self.client.force_authenticate(user=other_tutor) - response = self.client.get(f"/api/tutors/{self.tutor.id}/students/") - assert response.status_code == 403 diff --git a/accounts/urls.py b/accounts/urls.py index 0a21cb0..0f87bdc 100644 --- a/accounts/urls.py +++ b/accounts/urls.py @@ -2,32 +2,28 @@ from .views import ( PasswordResetView, - StudentListCreateView, + StudentListView, StudentRetrieveUpdateDestroyView, - TutorListCreateView, + TutorListView, TutorRetrieveUpdateDestroyView, - TutorStudentView, + UserRegisterationView, ) app_name = "accounts" urlpatterns = [ + path("student/register/", UserRegisterationView.as_view(), name="student-register"), path("password/reset/", PasswordResetView.as_view(), name="password-reset"), - path("students/", StudentListCreateView.as_view(), name="student-list-create"), + path("students/", StudentListView.as_view(), name="student-list"), path( "students//", StudentRetrieveUpdateDestroyView.as_view(), name="student-detail", ), - path("tutors/", TutorListCreateView.as_view(), name="tutor-list-create"), + path("tutors/", TutorListView.as_view(), name="tutor-list"), path( "tutors//", TutorRetrieveUpdateDestroyView.as_view(), name="tutor-detail", ), - path( - "tutors//students/", - TutorStudentView.as_view(), - name="tutor-student", - ), ] diff --git a/accounts/views.py b/accounts/views.py index 42831da..09f7f09 100644 --- a/accounts/views.py +++ b/accounts/views.py @@ -1,20 +1,20 @@ +from django.db import transaction +from django.db.utils import DatabaseError from django.shortcuts import get_object_or_404 -from jwtauth.authentication import JWTAuthentication -from rest_framework import generics, status -from rest_framework.exceptions import NotFound, PermissionDenied -from rest_framework.filters import OrderingFilter +from rest_framework import filters, generics, mixins, status +from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.pagination import PageNumberPagination -from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from .models import CustomUser -from .permissions import ( - IsAuthenticatedOrCreateOnly, - IsSuperUser, - IsTutor, - IsTutorOrSuperUserOrSuperUserCreateOnly, +from .permissions import IsAuthenticatedAndActive, IsSuperUser, IsTutor +from .serializers import ( + CustomUserDetailSerializer, + PasswordResetSerializer, + StudentListSerializer, + TutorListSerializer, + UserRegistrationSerializer, ) -from .serializers import CustomUserSerializer, PasswordResetSerializer class StandardResultsSetPagination(PageNumberPagination): @@ -27,190 +27,334 @@ class StandardResultsSetPagination(PageNumberPagination): max_page_size = 100 -class PasswordResetView(generics.GenericAPIView): +class UserRegisterationView(mixins.CreateModelMixin, generics.GenericAPIView): """ - 비밀번호를 재설정합니다. - POST: 비밀번호 재설정 (PUT이 아닌) + 회원가입을 위한 뷰입니다. + - POST: 학생 생성 + - 회원가입 또는 유효성 검사에 실패했을 때 에러메시지를 출력합니다. """ - serializer_class = PasswordResetSerializer - permission_classes = [IsAuthenticated] - authentication_classes = [JWTAuthentication] + queryset = CustomUser.objects.all() + + serializer_class = UserRegistrationSerializer def post(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - serializer.save() - return Response( - {"detail": "비밀번호가 성공적으로 변경되었습니다."}, - status=status.HTTP_200_OK, + + if serializer.is_valid(): + try: + user = self.perform_create(serializer) + headers = self.get_success_headers(serializer.data) + return Response( + { + "message": "성공적으로 회원가입 되었습니다.", + "user_id": user.id, + }, + status=status.HTTP_201_CREATED, + headers=headers, + ) + except Exception as e: + return Response( + { + "message": "회원가입에 실패했습니다. 다시 시도해주세요.", + "error": str(e), + }, + status=status.HTTP_400_BAD_REQUEST, + ) + else: + return Response( + { + "message": "실패했습니다. 다시 시도해주세요.", + "errors": serializer.errors, + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + def perform_create(self, serializer): + user = CustomUser.objects.create_user( + email=serializer.validated_data["email"], + password=serializer.validated_data["password"], + nickname=serializer.validated_data["nickname"], ) + return user -class StudentListCreateView( - generics.GenericAPIView, - generics.mixins.ListModelMixin, - generics.mixins.CreateModelMixin, -): +class PasswordResetView(generics.GenericAPIView): + """ + 비밀번호를 재설정합니다. + - POST: 비밀번호 재설정 + """ + + serializer_class = PasswordResetSerializer + permission_classes = [IsAuthenticatedAndActive] + + def post(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + if serializer.is_valid(): + try: + serializer.save() + return Response( + {"detail": "비밀번호가 성공적으로 변경되었습니다."}, + status=status.HTTP_200_OK, + ) + except Exception as e: + return Response( + { + "error": "비밀번호 재설정 중 오류가 발생했습니다. 다시 시도해 주세요." + }, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + else: + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class StudentListView(mixins.ListModelMixin, generics.GenericAPIView): """ - 학생 유저를 목록 조회 및 생성합니다. + 학생 유저를 목록 조회합니다. + - 인증된 사용자만이 접근가능합니다. - GET: 학생 목록 조회 - - POST: 학생 생성 """ queryset = CustomUser.objects.filter(is_staff=False, is_active=True) - serializer_class = CustomUserSerializer - permission_classes = [IsAuthenticatedOrCreateOnly] # 자동 403 Forbidden + serializer_class = StudentListSerializer + permission_classes = [IsAuthenticatedAndActive] pagination_class = StandardResultsSetPagination - authentication_classes = JWTAuthentication # 자동 401 Unauthorized - ordering_fields = ["email", "first_name", "last_name", "date_joined"] - ordering = ["-date_joined"] # 기본 정렬 순서 + filter_backends = [filters.OrderingFilter] + ordering_fields = [ + "email", + "nickname", + "created_at", + ] # 클라이언트가 정렬에 사용할 수 있는 필드들을 지정 + ordering = ["-created_at"] # 기본 정렬 순서 def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) - - def post(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) - if serializer.is_valid(): - serializer.save(is_staff=False, is_active=True) - return Response(serializer.data, status=status.HTTP_201_CREATED) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + try: + return self.list(request, *args, **kwargs) + except PermissionDenied: + return Response( + {"error": "이 목록을 조회할 권한이 없습니다."}, + status=status.HTTP_403_FORBIDDEN, + ) + except DatabaseError: + return Response( + { + "error": "데이터베이스 오류가 발생했습니다. 잠시 후 다시 시도해주세요." + }, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + except Exception as e: + return Response( + {"error": f"학생 목록 조회 중 오류가 발생했습니다: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) class StudentRetrieveUpdateDestroyView( - generics.GenericAPIView, generics.mixins.RetrieveModelMixin, generics.mixins.UpdateModelMixin, generics.mixins.DestroyModelMixin, + generics.GenericAPIView, ): """ - 학생 유저에 대해 조회, 수정, 삭제합니다. - GET: 학생 상세 조회 - PUT: 학생 정보 수정 - DELETE: 학생 소프트 삭제 + 학생이 학생 정보를 조회, 수정, 삭제합니다. + - GET: 학생 상세 조회 + - PUT: 학생 정보 수정 + - DELETE: 학생 삭제(소프트 삭제) """ - queryset = CustomUser.objects.filter(is_staff=False) - serializer_class = CustomUserSerializer - permission_classes = [IsAuthenticated] - authentication_classes = JWTAuthentication + queryset = CustomUser.objects.filter(is_staff=False, is_active=True) + serializer_class = CustomUserDetailSerializer + permission_classes = [IsAuthenticatedAndActive] def get_object(self): obj = get_object_or_404(self.get_queryset(), pk=self.kwargs["pk"]) if self.request.user.pk != obj.pk: - raise PermissionDenied("해당 사용자는 권한이 없는 접근입니다.") + raise PermissionDenied("해당 사용자의 정보에 접근할 권한이 없습니다.") + self.check_object_permissions(self.request, obj) return obj def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) + try: + return self.retrieve(request, *args, **kwargs) + except PermissionDenied as e: + return Response({"error": str(e)}, status=status.HTTP_403_FORBIDDEN) + except Exception as e: + return Response( + {"error": f"조회 중 오류가 발생했습니다: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) def put(self, request, *args, **kwargs): - return self.update(request, *args, **kwargs) + return self._update(request, *args, **kwargs) - def delete(self, request, *args, **kwargs): - user = self.get_object() - user.is_active = False - user.save() - return Response(status=status.HTTP_204_NO_CONTENT) + def patch(self, request, *args, **kwargs): + return self._update(request, *args, partial=True, **kwargs) + + def _update(self, request, *args, partial=False, **kwargs): + try: + instance = self.get_object() + serializer = self.get_serializer( + instance, data=request.data, partial=partial + ) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + return Response(serializer.data, status=status.HTTP_200_OK) + except PermissionDenied as e: + return Response({"error": str(e)}, status=status.HTTP_403_FORBIDDEN) + except ValidationError as e: + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception as e: + return Response( + {"error": f"수정 중 오류가 발생했습니다: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + def perform_update(self, serializer): + serializer.save() + def update(self, instance, validated_data): + password = validated_data.pop("password", None) + if password: + instance.set_password(password) + return super().update(instance, validated_data) -class TutorListCreateView( + @transaction.atomic + def delete(self, request, *args, **kwargs): + try: + user = self.get_object() + user.is_active = False + user.save() + return Response( + {"message": "계정이 비활성화되었습니다."}, + status=status.HTTP_204_NO_CONTENT, + ) + except PermissionDenied as e: + return Response({"error": str(e)}, status=status.HTTP_403_FORBIDDEN) + except Exception as e: + return Response( + {"error": f"삭제 중 오류가 발생했습니다: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +class TutorListView( + mixins.ListModelMixin, generics.GenericAPIView, - generics.mixins.ListModelMixin, - generics.mixins.CreateModelMixin, ): """ - 관리자 사용자 목록 조회 및 생성합니다. - GET: 관리자 목록 조회 - POST: 관리자 생성 + 관리자가 강사 목록을 조회합니다. + - GET: 강사 목록 조회 """ queryset = CustomUser.objects.filter(is_staff=True, is_active=True) - serializer_class = CustomUserSerializer - permission_classes = [IsTutorOrSuperUserOrSuperUserCreateOnly] + serializer_class = TutorListSerializer + permission_classes = [IsSuperUser] pagination_class = StandardResultsSetPagination - authentication_classes = JWTAuthentication - ordering_fields = ["email", "first_name", "last_name", "date_joined"] - ordering = ["-date_joined"] # 기본 정렬 순서 + filter_backends = [filters.OrderingFilter] + ordering_fields = ["email", "nickname", "created_at"] + ordering = ["-created_at"] # 기본 정렬 순서 def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) - - def post(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) - if serializer.is_valid(): - serializer.save(is_staff=True, is_active=True) - return Response(serializer.data, status=status.HTTP_201_CREATED) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + try: + return self.list(request, *args, **kwargs) + except PermissionDenied: + return Response( + {"error": "이 목록을 조회할 권한이 없습니다."}, + status=status.HTTP_403_FORBIDDEN, + ) + except DatabaseError: + return Response( + { + "error": "데이터베이스 오류가 발생했습니다. 잠시 후 다시 시도해주세요." + }, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + except Exception as e: + return Response( + {"error": f"강사 목록 조회 중 오류가 발생했습니다: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) class TutorRetrieveUpdateDestroyView( - generics.GenericAPIView, generics.mixins.RetrieveModelMixin, generics.mixins.UpdateModelMixin, generics.mixins.DestroyModelMixin, + generics.GenericAPIView, ): """ - 관리자 사용자 조회, 수정, 삭제합니다. - GET: 관리자 상세 조회 - PUT: 관리자 정보 수정 - DELETE: 관리자 소프트 삭제 + 관리자나 강사가 강사 정보를 조회, 수정, 삭제합니다. + - GET: 강사 상세 조회 + - PUT: 강사 정보 수정 + - DELETE: 강사 삭제 (소프트 삭제) """ - queryset = CustomUser.objects.filter(is_staff=True) - serializer_class = CustomUserSerializer + queryset = CustomUser.objects.filter(is_staff=True, is_active=True) + serializer_class = CustomUserDetailSerializer permission_classes = [IsTutor | IsSuperUser] - authentication_classes = JWTAuthentication + + def get_object(self): + obj = get_object_or_404(self.get_queryset(), pk=self.kwargs["pk"]) + self.check_object_permissions(self.request, obj) + return obj def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) + try: + return self.retrieve(request, *args, **kwargs) + except PermissionDenied as e: + return Response({"error": str(e)}, status=status.HTTP_403_FORBIDDEN) + except Exception as e: + return Response( + {"error": f"조회 중 오류가 발생했습니다: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) def put(self, request, *args, **kwargs): - return self.update(request, *args, **kwargs) - - def delete(self, request, *args, **kwargs): - tutor = self.get_object() - tutor.is_active = False - tutor.save() - return Response(status=status.HTTP_204_NO_CONTENT) + return self._update(request, *args, **kwargs) + def patch(self, request, *args, **kwargs): + return self._update(request, *args, partial=True, **kwargs) -class TutorStudentView(generics.ListAPIView): - """ - 특정 튜터의 학생 목록을 조회합니다. - GET: 튜터의 학생 목록 조회 - """ - - serializer_class = CustomUserSerializer - permission_classes = [IsAuthenticated & (IsTutor | IsSuperUser)] - authentication_classes = [JWTAuthentication] - pagination_class = StandardResultsSetPagination - filter_backends = [OrderingFilter] - - ordering_fields = ["email", "first_name", "last_name", "created_at"] - ordering = ["created_at"] # 기본 정렬 순서 - - def get_queryset(self): - tutor_id = self.kwargs.get("tutor_id") + def _update(self, request, *args, partial=False, **kwargs): try: - tutor = CustomUser.objects.get(id=tutor_id, is_staff=True) - except CustomUser.DoesNotExist: - raise NotFound("해당 튜터를 찾을 수 없습니다.") - - if self.request.user.id != tutor_id and not self.request.user.is_superuser: - raise PermissionDenied("이 정보에 접근할 권한이 없습니다.") - - return tutor.students.filter(is_active=True) - - def list(self, request, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) - serializer = self.get_serializer(queryset, many=True) - return Response( - { - "tutor_id": self.kwargs.get("tutor_id"), - "student_count": queryset.count(), - "students": serializer.data, - } - ) + instance = self.get_object() + serializer = self.get_serializer( + instance, data=request.data, partial=partial + ) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + return Response(serializer.data) + except PermissionDenied as e: + return Response({"error": str(e)}, status=status.HTTP_403_FORBIDDEN) + except ValidationError as e: + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception as e: + return Response( + {"error": f"수정 중 오류가 발생했습니다: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + @transaction.atomic + def delete(self, request, *args, **kwargs): + try: + tutor = self.get_object() + tutor.is_active = False + tutor.save() + return Response( + {"message": "강사 계정이 비활성화되었습니다."}, + status=status.HTTP_204_NO_CONTENT, + ) + except PermissionDenied as e: + return Response({"error": str(e)}, status=status.HTTP_403_FORBIDDEN) + except Exception as e: + return Response( + {"error": f"삭제 중 오류가 발생했습니다: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + def check_object_permissions(self, request, obj): + if not request.user.is_superuser and request.user.pk != obj.pk: + raise PermissionDenied("해당 강사의 정보에 접근할 권한이 없습니다.") + super().check_object_permissions(request, obj) diff --git a/courses/admin.py b/courses/admin.py index 7f4bd86..452f0cf 100644 --- a/courses/admin.py +++ b/courses/admin.py @@ -4,6 +4,7 @@ from .models import ( Assignment, Course, + Curriculum, Lecture, MultipleChoiceQuestion, MultipleChoiceQuestionChoice, @@ -11,9 +12,10 @@ ) if settings.DEBUG: - admin.register(Course) - admin.register(Lecture) - admin.register(Topic) - admin.register(Assignment) - admin.register(MultipleChoiceQuestion) - admin.register(MultipleChoiceQuestionChoice) + admin.site.register(Course) + admin.site.register(Lecture) + admin.site.register(Topic) + admin.site.register(Assignment) + admin.site.register(MultipleChoiceQuestion) + admin.site.register(MultipleChoiceQuestionChoice) + admin.site.register(Curriculum) diff --git a/courses/apps.py b/courses/apps.py index 85987be..1a052c5 100644 --- a/courses/apps.py +++ b/courses/apps.py @@ -2,5 +2,9 @@ class CoursesConfig(AppConfig): + """ + Course 앱의 설정 클래스입니다. + """ + default_auto_field = "django.db.models.BigAutoField" name = "courses" diff --git a/courses/migrations/0006_alter_course_curriculum.py b/courses/migrations/0006_alter_course_curriculum.py new file mode 100644 index 0000000..9e8b7fe --- /dev/null +++ b/courses/migrations/0006_alter_course_curriculum.py @@ -0,0 +1,19 @@ +# Generated by Django 5.1.1 on 2024-10-09 09:08 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0005_alter_assignment_options_alter_course_options_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='course', + name='curriculum', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='courses', to='courses.curriculum', verbose_name='커리큘럼'), + ), + ] diff --git a/courses/migrations/0007_alter_course_category.py b/courses/migrations/0007_alter_course_category.py new file mode 100644 index 0000000..447c401 --- /dev/null +++ b/courses/migrations/0007_alter_course_category.py @@ -0,0 +1,18 @@ +# Generated by Django 5.1.1 on 2024-10-09 13:38 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0006_alter_course_curriculum'), + ] + + operations = [ + migrations.AlterField( + model_name='course', + name='category', + field=models.CharField(choices=[('JavaScript', 'JavaScript'), ('Python', 'Python'), ('Django', 'Django'), ('React', 'React'), ('Vue', 'Vue'), ('Node', 'Node'), ('AWS', 'AWS'), ('Docker', 'Docker'), ('DB', 'DB')], default='JavaScript', max_length=255, verbose_name='카테고리'), + ), + ] diff --git a/courses/migrations/0008_rename_course_level_course_skill_level_and_more.py b/courses/migrations/0008_rename_course_level_course_skill_level_and_more.py new file mode 100644 index 0000000..166153d --- /dev/null +++ b/courses/migrations/0008_rename_course_level_course_skill_level_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 5.1.1 on 2024-10-10 02:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0007_alter_course_category'), + ] + + operations = [ + migrations.RenameField( + model_name='course', + old_name='course_level', + new_name='skill_level', + ), + migrations.AddField( + model_name='curriculum', + name='category', + field=models.CharField(choices=[('JavaScript', 'JavaScript'), ('Python', 'Python'), ('Django', 'Django'), ('React', 'React'), ('Vue', 'Vue'), ('Node', 'Node'), ('AWS', 'AWS'), ('Docker', 'Docker'), ('DB', 'DB')], default='JavaScript', max_length=255, verbose_name='카테고리'), + ), + migrations.AddField( + model_name='curriculum', + name='skill_level', + field=models.CharField(choices=[('beginner', '초급'), ('intermediate', '중급'), ('advanced', '고급')], default='beginner', max_length=255, verbose_name='난이도'), + ), + ] diff --git a/courses/migrations/0009_course_author_curriculum_author.py b/courses/migrations/0009_course_author_curriculum_author.py new file mode 100644 index 0000000..ba274bb --- /dev/null +++ b/courses/migrations/0009_course_author_curriculum_author.py @@ -0,0 +1,28 @@ +# Generated by Django 5.1.1 on 2024-10-10 02:31 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0008_rename_course_level_course_skill_level_and_more'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name='course', + name='author', + field=models.ForeignKey(default=1, on_delete=django.db.models.deletion.CASCADE, related_name='courses', to=settings.AUTH_USER_MODEL, verbose_name='작성자'), + preserve_default=False, + ), + migrations.AddField( + model_name='curriculum', + name='author', + field=models.ForeignKey(default=1, on_delete=django.db.models.deletion.CASCADE, related_name='curriculums', to=settings.AUTH_USER_MODEL, verbose_name='작성자'), + preserve_default=False, + ), + ] diff --git a/courses/migrations/0010_remove_topic_description.py b/courses/migrations/0010_remove_topic_description.py new file mode 100644 index 0000000..4fbcdcf --- /dev/null +++ b/courses/migrations/0010_remove_topic_description.py @@ -0,0 +1,17 @@ +# Generated by Django 5.1.1 on 2024-10-13 08:53 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0009_course_author_curriculum_author'), + ] + + operations = [ + migrations.RemoveField( + model_name='topic', + name='description', + ), + ] diff --git a/courses/mixins.py b/courses/mixins.py index 0995e34..f813cdb 100644 --- a/courses/mixins.py +++ b/courses/mixins.py @@ -1,5 +1,7 @@ from django.db import transaction +from materials.models import Image, Video + from .models import ( Assignment, Course, @@ -16,12 +18,14 @@ class CourseMixin: """ @transaction.atomic - def create_course_with_lectures_and_topics(self, course_data, lectures_data): + def create_course_with_lectures_and_topics( + self, course_data, lectures_data, author + ): """ course 및 하위 모델 lecture, topic, assignment, quiz 등을 함께 생성합니다. """ - course = self._create_course(course_data) + course = self._create_course(course_data, author) for lecture_data in lectures_data: lecture = self._create_lecture(lecture_data, course) for topic_data in lecture_data.get("topics", []): @@ -46,19 +50,24 @@ def update_course_with_lectures_and_topics( topic = self._create_topic(topic_data, lecture) self._handle_topic_type(topic, topic_data) - def _create_course(self, course_data): + def _create_course(self, course_data, author): """ course 인스턴스를 생성합니다. """ - return Course.objects.create( + course = Course.objects.create( title=course_data.get("title"), short_description=course_data.get("short_description"), description=course_data.get("description"), category=course_data.get("category"), - course_level=course_data.get("course_level"), + skill_level=course_data.get("skill_level"), price=course_data.get("price"), + author=author, ) + Image.objects.filter(id=course_data.get("thumbnail_id")).update(course=course) + Video.objects.filter(id=course_data.get("video_id")).update(course=course) + + return course def _create_lecture(self, lecture_data, course): """ @@ -76,14 +85,15 @@ def _create_topic(self, topic_data, lecture): topic 인스턴스를 생성합니다. """ - return Topic.objects.create( + topic = Topic.objects.create( lecture=lecture, title=topic_data.get("title"), type=topic_data.get("type"), - description=topic_data.get("description"), order=topic_data.get("order"), is_premium=topic_data.get("is_premium"), ) + Video.objects.filter(id=topic_data.get("video_id")).update(topic=topic) + return topic def _handle_topic_type(self, topic, topic_data): """ diff --git a/courses/models.py b/courses/models.py index 974bd3c..e373f7c 100644 --- a/courses/models.py +++ b/courses/models.py @@ -1,10 +1,50 @@ +from django.conf import settings from django.db import models class Curriculum(models.Model): + """ + 커리큘럼 모델입니다. + """ + + category_choices = [ + ("JavaScript", "JavaScript"), + ("Python", "Python"), + ("Django", "Django"), + ("React", "React"), + ("Vue", "Vue"), + ("Node", "Node"), + ("AWS", "AWS"), + ("Docker", "Docker"), + ("DB", "DB"), + ] + skill_level_choices = [ + ("beginner", "초급"), + ("intermediate", "중급"), + ("advanced", "고급"), + ] + + author = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + related_name="curriculums", + verbose_name="작성자", + ) name = models.CharField(max_length=255, verbose_name="커리큘럼 이름") description = models.TextField(verbose_name="설명") price = models.PositiveIntegerField(verbose_name="가격") + category = models.CharField( + max_length=255, + verbose_name="카테고리", + choices=category_choices, + default="JavaScript", + ) + skill_level = models.CharField( + max_length=255, + verbose_name="난이도", + choices=skill_level_choices, + default="beginner", + ) created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성일") updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") @@ -18,18 +58,22 @@ class Meta: class Course(models.Model): + """ + 코스 모델입니다. + """ + category_choices = [ ("JavaScript", "JavaScript"), ("Python", "Python"), ("Django", "Django"), ("React", "React"), - ("Vue.js", "Vue.js"), - ("Node.js", "Node.js"), + ("Vue", "Vue"), + ("Node", "Node"), ("AWS", "AWS"), ("Docker", "Docker"), ("DB", "DB"), ] - course_level_choices = [ + skill_level_choices = [ ("beginner", "초급"), ("intermediate", "중급"), ("advanced", "고급"), @@ -39,9 +83,16 @@ class Course(models.Model): Curriculum, on_delete=models.SET_NULL, null=True, + blank=True, related_name="courses", verbose_name="커리큘럼", ) + author = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + related_name="courses", + verbose_name="작성자", + ) title = models.CharField(max_length=255, verbose_name="코스 제목") short_description = models.TextField(verbose_name="간단한 설명") description = models.JSONField(verbose_name="설명") @@ -51,24 +102,39 @@ class Course(models.Model): choices=category_choices, default="JavaScript", ) - course_level = models.CharField( + skill_level = models.CharField( max_length=255, verbose_name="난이도", - choices=course_level_choices, + choices=skill_level_choices, default="beginner", ) price = models.PositiveIntegerField(verbose_name="가격") created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성일") updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") + def get_thumbnail(self): + if hasattr(self, "images") and self.images.exists(): + return self.images.first().file.url + return "https://www.gravatar.com/avatar/205e460b479e2e5b48aec077" + def update(self, **kwargs): + """ + 코스 정보를 수정합니다. + 수정 가능한 필드: + - title: 코스 제목 + - short_description: 간단한 설명 + - description: 설명 + - category: 카테고리 + - skill_level: 난이도 + - price: 가격 + """ for key, value in kwargs.items(): if key not in [ "title", "short_description", "description", "category", - "course_level", + "skill_level", "price", ]: continue @@ -85,6 +151,10 @@ class Meta: class Lecture(models.Model): + """ + 강의 모델입니다. + """ + course = models.ForeignKey( Course, on_delete=models.CASCADE, related_name="lectures", verbose_name="코스" ) @@ -94,7 +164,7 @@ class Lecture(models.Model): updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") def __str__(self): - return f"{self.course.title} - {self.title}" + return f"{self.id} - {self.title}" class Meta: ordering = ["order"] @@ -103,6 +173,10 @@ class Meta: class Topic(models.Model): + """ + 주제 모델입니다. + """ + topic_type_choices = [ ("video", "동영상"), ("article", "글"), @@ -120,14 +194,13 @@ class Topic(models.Model): choices=topic_type_choices, default="video", ) - description = models.TextField(verbose_name="설명") order = models.PositiveIntegerField(verbose_name="순서") is_premium = models.BooleanField(verbose_name="프리미엄 여부", default=False) created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성일") updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") def __str__(self): - return f"{self.lecture.title} - {self.title}" + return f"{self.id} - {self.title}" class Meta: ordering = ["order"] @@ -136,6 +209,10 @@ class Meta: class MultipleChoiceQuestion(models.Model): + """ + 객관식 문제 모델입니다. + """ + topic = models.OneToOneField( Topic, on_delete=models.CASCADE, @@ -155,6 +232,10 @@ class Meta: class MultipleChoiceQuestionChoice(models.Model): + """ + 객관식 문제 선택지 모델입니다. + """ + question = models.ForeignKey( MultipleChoiceQuestion, on_delete=models.CASCADE, @@ -171,6 +252,10 @@ def __str__(self): class Assignment(models.Model): + """ + 과제 모델입니다. + """ + topic = models.OneToOneField( Topic, on_delete=models.CASCADE, related_name="assignment", verbose_name="주제" ) diff --git a/courses/serializers.py b/courses/serializers.py index 7f01266..1102366 100644 --- a/courses/serializers.py +++ b/courses/serializers.py @@ -16,11 +16,10 @@ class AssignmentSerializer(serializers.ModelSerializer): Assignment 모델을 위한 Serializer입니다 """ - id = serializers.IntegerField(read_only=True) - class Meta: model = Assignment - fields = ["id", "question", "created_at", "updated_at"] + fields = ["question", "created_at", "updated_at", "id"] + read_only_fields = ["created_at", "updated_at", "id"] class MultipleChoiceQuestionChoiceSerializer(serializers.ModelSerializer): @@ -28,11 +27,10 @@ class MultipleChoiceQuestionChoiceSerializer(serializers.ModelSerializer): MultipleChoiceQuestionChoice 모델을 위한 Serializer입니다 """ - id = serializers.IntegerField(read_only=True) - class Meta: model = MultipleChoiceQuestionChoice - fields = ["id", "choice", "is_correct", "created_at", "updated_at"] + fields = ["id", "choice", "is_correct", "created_at", "updated_at", "id"] + read_only_fields = ["created_at", "updated_at", "id"] class MultipleChoiceQuestionSerializer(serializers.ModelSerializer): @@ -41,7 +39,6 @@ class MultipleChoiceQuestionSerializer(serializers.ModelSerializer): """ multiple_choice_question_choices = MultipleChoiceQuestionChoiceSerializer(many=True) - id = serializers.IntegerField(read_only=True) class Meta: model = MultipleChoiceQuestion @@ -52,6 +49,7 @@ class Meta: "updated_at", "multiple_choice_question_choices", ] + read_only_fields = ["created_at", "updated_at", "id"] class TopicSerializer(serializers.ModelSerializer): @@ -59,9 +57,11 @@ class TopicSerializer(serializers.ModelSerializer): Topic 모델을 위한 Serializer입니다 """ - id = serializers.IntegerField(read_only=True) multiple_choice_question = MultipleChoiceQuestionSerializer(required=False) assignment = AssignmentSerializer(required=False) + video_url = serializers.SerializerMethodField() + video_id = serializers.IntegerField(write_only=True, required=False) + video_duration = serializers.SerializerMethodField() class Meta: model = Topic @@ -69,27 +69,46 @@ class Meta: "id", "title", "type", - "description", "order", "is_premium", "created_at", "updated_at", "multiple_choice_question", "assignment", + "video_url", + "video_id", + "video_duration", + ] + read_only_fields = [ + "created_at", + "updated_at", + "id", + "video_url", + "video_duration", ] + def get_video_url(self, obj): + if getattr(obj, "video", None): + return obj.video.video_url + return None + + def get_video_duration(self, obj): + if getattr(obj, "video", None): + return 0 + return None + class LectureSerializer(serializers.ModelSerializer): """ Lecture 모델을 위한 Serializer입니다 """ - id = serializers.IntegerField(read_only=True) topics = TopicSerializer(many=True) class Meta: model = Lecture fields = ["id", "title", "order", "created_at", "updated_at", "topics"] + read_only_fields = ["created_at", "updated_at", "id"] class CourseDetailSerializer(serializers.ModelSerializer): @@ -97,11 +116,18 @@ class CourseDetailSerializer(serializers.ModelSerializer): Course 모델을 위한 Serializer입니다 """ - id = serializers.IntegerField(read_only=True) + video_id = serializers.IntegerField(write_only=True) + thumbnail_id = serializers.IntegerField(write_only=True) lectures = LectureSerializer( many=True, required=False, ) + video_url = serializers.SerializerMethodField() + thumbnail_url = serializers.SerializerMethodField() + author_image = serializers.SerializerMethodField() + author_name = serializers.SerializerMethodField() + author_id = serializers.SerializerMethodField() + author_introduction = serializers.SerializerMethodField() class Meta: model = Course @@ -114,9 +140,55 @@ class Meta: "created_at", "updated_at", "lectures", - "course_level", + "skill_level", "price", + "thumbnail_id", + "video_id", + "video_url", + "thumbnail_url", + "author_image", + "author_name", + "author_id", + "author_introduction", ] + read_only_fields = [ + "created_at", + "updated_at", + "id", + "video_url", + "thumbnail_url", + "author_image", + "author_name", + "author_id", + "author_introduction", + ] + + def get_author_image(self, obj): + print(obj.author.image.image_url) + if getattr(obj.author, "image", None): + return obj.author.image.image_url + return None + + def get_author_name(self, obj): + return obj.author.nickname + + def get_video_url(self, obj): + if getattr(obj, "video", None): + return obj.video.video_url + return None + + def get_thumbnail_url(self, obj): + if getattr(obj, "thumbnail", None): + return obj.thumbnail.url + return None + + def get_author_id(self, obj): + return obj.author.id + + def get_author_introduction(self, obj): + return ( + obj.author.introduction if obj.author.introduction else "소개가 없습니다." + ) class CourseSummarySerializer(serializers.ModelSerializer): @@ -124,7 +196,10 @@ class CourseSummarySerializer(serializers.ModelSerializer): Course 모델을 위한 Serializer입니다 """ - id = serializers.IntegerField(read_only=True) + lectures_count = serializers.SerializerMethodField() + thumbnail = serializers.SerializerMethodField() + author_image = serializers.SerializerMethodField() + author_name = serializers.SerializerMethodField() class Meta: model = Course @@ -135,8 +210,33 @@ class Meta: "category", "created_at", "updated_at", - "course_level", + "skill_level", + "lectures_count", + "thumbnail", + "author_image", + "author_name", ] + read_only_fields = [ + "created_at", + "updated_at", + "id", + "lectures_count", + "thumbnail", + "author_image", + "author_name", + ] + + def get_lectures_count(self, obj): + return obj.lectures.count() + + def get_thumbnail(self, obj): + return obj.get_thumbnail() + + def get_author_image(self, obj): + return "https://paullab.co.kr/images/weniv-licat.png" + + def get_author_name(self, obj): + return obj.author.nickname class CurriculumReadSerializer(serializers.ModelSerializer): @@ -144,7 +244,6 @@ class CurriculumReadSerializer(serializers.ModelSerializer): Curriculum 모델을 조회하기 위한 Serializer입니다. 직렬화 할 때만 사용합니다. """ - id = serializers.IntegerField(read_only=True) courses = CourseSummarySerializer( many=True, ) @@ -160,6 +259,7 @@ class Meta: "created_at", "updated_at", ] + read_only_fields = ["created_at", "updated_at", "id"] class CurriculumCreateAndUpdateSerializer(serializers.ModelSerializer): @@ -167,7 +267,6 @@ class CurriculumCreateAndUpdateSerializer(serializers.ModelSerializer): Curriculum 모델을 생성 및 수정을 위한 Serializer입니다. 역직렬화 할 때만 사용합니다. """ - id = serializers.IntegerField(required=False) courses_ids = serializers.ListField( child=serializers.IntegerField(), write_only=True ) @@ -175,6 +274,7 @@ class CurriculumCreateAndUpdateSerializer(serializers.ModelSerializer): class Meta: model = Curriculum fields = ["id", "name", "price", "description", "courses_ids"] + read_only_fields = ["created_at", "updated_at", "id"] class CurriculumSummarySerializer(serializers.ModelSerializer): @@ -182,7 +282,9 @@ class CurriculumSummarySerializer(serializers.ModelSerializer): Curriculum 모델을 위한 Serializer입니다. 직렬화 할 때만 사용합니다. """ - id = serializers.IntegerField(read_only=True) + author_image = serializers.SerializerMethodField() + author_name = serializers.SerializerMethodField() + courses_count = serializers.SerializerMethodField() class Meta: model = Curriculum @@ -192,4 +294,30 @@ class Meta: "price", "created_at", "updated_at", + "author_image", + "author_name", + "category", + "skill_level", + "description", + "courses_count", ] + read_only_fields = [ + "created_at", + "updated_at", + "id", + "author_image", + "author_name", + "category", + "skill_level", + "description", + "courses_count", + ] + + def get_author_image(self, obj): + return "https://paullab.co.kr/images/weniv-licat.png" + + def get_author_name(self, obj): + return obj.author.nickname + + def get_courses_count(self, obj): + return obj.courses.count() diff --git a/courses/test/conftest.py b/courses/test/conftest.py index 9864677..728d30a 100644 --- a/courses/test/conftest.py +++ b/courses/test/conftest.py @@ -10,13 +10,15 @@ MultipleChoiceQuestionChoice, Topic, ) +from jwtauth.utils.token_generator import generate_access_token +from materials.models import Image, Video # 테스트에서 사용할 상수를 정의합니다. COURSE_TITLE = "Test Course" COURSE_SHORT_DESCRIPTION = "Test Course" COURSE_DESCRIPTION = {} COURSE_CATEGORY = "JavaScript" -COURSE_COURSE_LEVEL = "beginner" +COURSE_SKILL_LEVEL = "beginner" COURSE_PRICE = 10000 LECTURE1_TITLE = "Test Lecture 1" LECTURE1_ORDER = 1 @@ -45,17 +47,18 @@ @pytest.fixture -def setup_course_data(): +def setup_course_data(create_staff_user): """ 테스트에서 사용할 Course, Lecture, Topic, Assignment, MultipleChoiceQuestion, MultipleChoiceQuestionChoice 인스턴스를 생성합니다. """ course = Course.objects.create( title=COURSE_TITLE, + author=create_staff_user, short_description=COURSE_SHORT_DESCRIPTION, description=COURSE_DESCRIPTION, category=COURSE_CATEGORY, - course_level=COURSE_COURSE_LEVEL, + skill_level=COURSE_SKILL_LEVEL, price=COURSE_PRICE, ) lecture1 = Lecture.objects.create( @@ -70,7 +73,6 @@ def setup_course_data(): title=TOPIC1_TITLE, lecture=lecture1, type=TOPIC1_TYPE, - description=TOPIC1_DESCRIPTION, order=1, is_premium=True, ) @@ -78,7 +80,6 @@ def setup_course_data(): title=TOPIC2_TITLE, lecture=lecture2, type=TOPIC2_TYPE, - description=TOPIC2_DESCRIPTION, order=TOPIC2_ORDER, is_premium=True, ) @@ -132,11 +133,37 @@ def api_client(): @pytest.fixture(autouse=True) def create_user(): - return User.objects.create_user(email=TEST_USER_EMAIL, password=TEST_USER_PASSWORD) + user = User.objects.create_user( + email=TEST_USER_EMAIL, password=TEST_USER_PASSWORD, nickname="testuser" + ) + Image.objects.create(user=user, image_url="test.jpg") + return user @pytest.fixture(autouse=True) def create_staff_user(): - return User.objects.create_user( - email=TEST_STAFF_USER_EMAIL, password=TEST_STAFF_USER_PASSWORD, is_staff=True + user = User.objects.create_user( + email=TEST_STAFF_USER_EMAIL, + password=TEST_STAFF_USER_PASSWORD, + is_staff=True, + nickname="staffuser", + ) + Image.objects.create(user=user, image_url="test.jpg") + return user + + +@pytest.fixture() +def user_token(create_user): + return generate_access_token(create_user) + + +@pytest.fixture() +def staff_user_token(create_staff_user): + return generate_access_token(create_staff_user) + + +@pytest.fixture +def create_video(): + return Video.objects.create( + video_url="https://www.youtube.com/watch?v=123456", ) diff --git a/courses/test/test_mixins.py b/courses/test/test_mixins.py index 0d06488..37d736b 100644 --- a/courses/test/test_mixins.py +++ b/courses/test/test_mixins.py @@ -7,7 +7,7 @@ @pytest.mark.django_db class TestCourseMixin: - def test_create_course_with_lectures_and_topics(self): + def test_create_course_with_lectures_and_topics(self, create_staff_user): # Given course_mixin = CourseMixin() course_data = { @@ -15,7 +15,7 @@ def test_create_course_with_lectures_and_topics(self): "short_description": "course_short_description", "description": "course_description", "category": "JavaScript", - "course_level": "beginner", + "skill_level": "beginner", "price": 10000, } lectures_data = [ @@ -87,7 +87,7 @@ def test_create_course_with_lectures_and_topics(self): # When course = course_mixin.create_course_with_lectures_and_topics( - course_data, lectures_data + course_data, lectures_data, create_staff_user ) # Then @@ -104,7 +104,7 @@ def test_create_course_with_lectures_and_topics(self): assert course.short_description == course_data["short_description"] assert course.description == course_data["description"] assert course.category == course_data["category"] - assert course.course_level == course_data["course_level"] + assert course.skill_level == course_data["skill_level"] assert course.price == course_data["price"] lectures = course.lectures.all() @@ -128,7 +128,7 @@ def test_create_course_with_lectures_and_topics(self): == 4 ) - def test_create_course(self): + def test_create_course(self, create_staff_user): # Given course_mixin = CourseMixin() course_data = { @@ -136,12 +136,12 @@ def test_create_course(self): "short_description": "course_short_description", "description": "course_description", "category": "JavaScript", - "course_level": "beginner", + "skill_level": "beginner", "price": 10000, } # When - course = course_mixin._create_course(course_data) + course = course_mixin._create_course(course_data, create_staff_user) # Then assert course is not None @@ -149,9 +149,9 @@ def test_create_course(self): assert course.short_description == course_data["short_description"] assert course.description == course_data["description"] assert course.category == course_data["category"] - assert course.course_level == course_data["course_level"] + assert course.skill_level == course_data["skill_level"] - def test_create_lecture(self): + def test_create_lecture(self, create_staff_user): # Given course_mixin = CourseMixin() lecture_data = { @@ -161,10 +161,11 @@ def test_create_lecture(self): course = Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ) # When @@ -176,23 +177,23 @@ def test_create_lecture(self): assert lecture.order == lecture_data["order"] assert lecture.course == course - def test_create_topic(self): + def test_create_topic(self, create_staff_user): # Given course_mixin = CourseMixin() topic_data = { "title": "topic_title", "type": "assignment", - "description": "topic_description", "order": 1, "is_premium": True, } lecture = Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ).lectures.create(title="lecture_title", order=1) # When @@ -202,12 +203,11 @@ def test_create_topic(self): assert topic is not None assert topic.title == topic_data["title"] assert topic.type == topic_data["type"] - assert topic.description == topic_data["description"] assert topic.order == topic_data["order"] assert topic.is_premium == topic_data["is_premium"] assert topic.lecture == lecture - def test_handle_topic_type_assignment(self): + def test_handle_topic_type_assignment(self, create_staff_user): # Given course_mixin = CourseMixin() topic_data = { @@ -223,10 +223,11 @@ def test_handle_topic_type_assignment(self): lecture = Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ).lectures.create(title="lecture_title", order=1) # When @@ -237,7 +238,7 @@ def test_handle_topic_type_assignment(self): assert topic.assignment is not None assert topic.assignment.question == topic_data["assignment"]["question"] - def test_handle_topic_type_quiz(self): + def test_handle_topic_type_quiz(self, create_staff_user): # Given course_mixin = CourseMixin() topic_data = { @@ -259,10 +260,11 @@ def test_handle_topic_type_quiz(self): lecture = Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ).lectures.create(title="lecture_title", order=1) # When @@ -279,7 +281,7 @@ def test_handle_topic_type_quiz(self): topic.multiple_choice_question.multiple_choice_question_choices.count() == 4 ) - def test_create_assignment(self): + def test_create_assignment(self, create_staff_user): # Given course_mixin = CourseMixin() assignment_data = { @@ -289,16 +291,16 @@ def test_create_assignment(self): Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ) .lectures.create(title="lecture_title", order=1) .topics.create( title="topic_title", type="assignment", - description="topic_description", order=1, is_premium=True, ) @@ -311,7 +313,7 @@ def test_create_assignment(self): assert topic.assignment is not None assert topic.assignment.question == assignment_data["question"] - def test_create_quiz(self): + def test_create_quiz(self, create_staff_user): # Given course_mixin = CourseMixin() multiple_choice_question_data = { @@ -327,16 +329,16 @@ def test_create_quiz(self): Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ) .lectures.create(title="lecture_title", order=1) .topics.create( title="topic_title", type="quiz", - description="topic_description", order=1, is_premium=True, ) @@ -361,23 +363,23 @@ def test_create_quiz(self): == 1 ) - def test_create_multiple_choice_question_choice(self): + def test_create_multiple_choice_question_choice(self, create_staff_user): # Given course_mixin = CourseMixin() course = Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ) lecture = Lecture.objects.create(title="lecture_title", course=course, order=1) topic = Topic.objects.create( title="topic_title", lecture=lecture, type="quiz", - description="topic_description", order=1, is_premium=True, ) @@ -403,23 +405,23 @@ def test_create_multiple_choice_question_choice(self): == 1 ) - def test_create_multiple_choice_question(self): + def test_create_multiple_choice_question(self, create_staff_user): # Given course_mixin = CourseMixin() course = Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ) lecture = Lecture.objects.create(title="lecture_title", course=course, order=1) topic = Topic.objects.create( title="topic_title", lecture=lecture, type="quiz", - description="topic_description", order=1, is_premium=True, ) @@ -452,23 +454,24 @@ def test_create_multiple_choice_question(self): == 1 ) - def test_update_course(self): + def test_update_course(self, create_staff_user): # Given course_mixin = CourseMixin() course = Course.objects.create( title="course_title", category="JavaScript", - course_level="beginner", + skill_level="beginner", short_description="course_short_description", description="course_description", price=10000, + author=create_staff_user, ) course_data = { "title": "updated_course_title", "short_description": "updated_course_short_description", "description": "updated_course_description", "category": "Python", - "course_level": "intermediate", + "skill_level": "intermediate", "price": 20000, } lectures_data = [ @@ -517,7 +520,7 @@ def test_update_course(self): assert course.short_description == course_data["short_description"] assert course.description == course_data["description"] assert course.category == course_data["category"] - assert course.course_level == course_data["course_level"] + assert course.skill_level == course_data["skill_level"] assert course.price == course_data["price"] assert course.lectures.count() == 1 assert course.lectures.first().topics.count() == 2 diff --git a/courses/test/test_models.py b/courses/test/test_models.py index 60e2268..1a833da 100644 --- a/courses/test/test_models.py +++ b/courses/test/test_models.py @@ -5,10 +5,13 @@ @pytest.mark.django_db -def test_course_생성(): +def test_course_생성(create_staff_user): # Given curriculum = Curriculum.objects.create( - name="Test Curriculum", description="Test Description", price=1000 + name="Test Curriculum", + description="Test Description", + price=1000, + author=create_staff_user, ) # When @@ -18,8 +21,9 @@ def test_course_생성(): short_description="Short Description", description={"content": "Detailed Description"}, category="Python", - course_level="beginner", + skill_level="beginner", price=500, + author=create_staff_user, ) # Then @@ -27,7 +31,7 @@ def test_course_생성(): assert course.short_description == "Short Description" assert course.description == {"content": "Detailed Description"} assert course.category == "Python" - assert course.course_level == "beginner" + assert course.skill_level == "beginner" assert course.price == 500 assert course.curriculum == curriculum assert course.created_at <= timezone.now() @@ -35,10 +39,13 @@ def test_course_생성(): @pytest.mark.django_db -def test_course_업데이트(): +def test_course_업데이트(create_staff_user): # Given curriculum = Curriculum.objects.create( - name="Test Curriculum", description="Test Description", price=1000 + name="Test Curriculum", + description="Test Description", + price=1000, + author=create_staff_user, ) course = Course.objects.create( curriculum=curriculum, @@ -46,8 +53,9 @@ def test_course_업데이트(): short_description="Short Description", description={"content": "Detailed Description"}, category="Python", - course_level="beginner", + skill_level="beginner", price=500, + author=create_staff_user, ) # When @@ -56,7 +64,7 @@ def test_course_업데이트(): short_description="Updated Short Description", description={"content": "Updated Detailed Description"}, category="Django", - course_level="intermediate", + skill_level="intermediate", price=700, ) @@ -66,15 +74,18 @@ def test_course_업데이트(): assert updated_course.short_description == "Updated Short Description" assert updated_course.description == {"content": "Updated Detailed Description"} assert updated_course.category == "Django" - assert updated_course.course_level == "intermediate" + assert updated_course.skill_level == "intermediate" assert updated_course.price == 700 @pytest.mark.django_db -def test_course_업데이트_course의_없는_필드는_무시된다(): +def test_course_업데이트_course의_없는_필드는_무시된다(create_staff_user): # Given curriculum = Curriculum.objects.create( - name="Test Curriculum", description="Test Description", price=1000 + name="Test Curriculum", + description="Test Description", + price=1000, + author=create_staff_user, ) course = Course.objects.create( curriculum=curriculum, @@ -82,8 +93,9 @@ def test_course_업데이트_course의_없는_필드는_무시된다(): short_description="Short Description", description={"content": "Detailed Description"}, category="Python", - course_level="beginner", + skill_level="beginner", price=500, + author=create_staff_user, ) # When @@ -95,15 +107,18 @@ def test_course_업데이트_course의_없는_필드는_무시된다(): assert updated_course.short_description == "Short Description" assert updated_course.description == {"content": "Detailed Description"} assert updated_course.category == "Python" - assert updated_course.course_level == "beginner" + assert updated_course.skill_level == "beginner" assert updated_course.price == 500 @pytest.mark.django_db -def test_course_업데이트_특정_필드만_업데이트(): +def test_course_업데이트_특정_필드만_업데이트(create_staff_user): # Given curriculum = Curriculum.objects.create( - name="Test Curriculum", description="Test Description", price=1000 + name="Test Curriculum", + description="Test Description", + price=1000, + author=create_staff_user, ) course = Course.objects.create( curriculum=curriculum, @@ -111,8 +126,9 @@ def test_course_업데이트_특정_필드만_업데이트(): short_description="Short Description", description={"content": "Detailed Description"}, category="Python", - course_level="beginner", + skill_level="beginner", price=500, + author=create_staff_user, ) # When @@ -124,6 +140,6 @@ def test_course_업데이트_특정_필드만_업데이트(): assert updated_course.short_description == "Short Description" assert updated_course.description == {"content": "Detailed Description"} assert updated_course.category == "Python" - assert updated_course.course_level == "beginner" + assert updated_course.skill_level == "beginner" assert updated_course.price == 500 assert updated_course.lectures.count() == 0 diff --git a/courses/test/test_serializers.py b/courses/test/test_serializers.py index e6950ea..c723ee4 100644 --- a/courses/test/test_serializers.py +++ b/courses/test/test_serializers.py @@ -34,9 +34,8 @@ def test_course_직렬화(self, setup_course_data): assert data["id"] == self.course.id assert data["title"] == conftest.COURSE_TITLE assert data["short_description"] == conftest.COURSE_SHORT_DESCRIPTION - assert data["description"] == conftest.COURSE_DESCRIPTION assert data["category"] == conftest.COURSE_CATEGORY - assert data["course_level"] == conftest.COURSE_COURSE_LEVEL + assert data["skill_level"] == conftest.COURSE_SKILL_LEVEL assert data["price"] == conftest.COURSE_PRICE assert len(data["lectures"]) == 2 assert data["lectures"][0]["title"] == conftest.LECTURE1_TITLE @@ -46,10 +45,6 @@ def test_course_직렬화(self, setup_course_data): assert len(data["lectures"][0]["topics"]) == 1 assert data["lectures"][0]["topics"][0]["title"] == conftest.TOPIC1_TITLE assert data["lectures"][0]["topics"][0]["type"] == conftest.TOPIC1_TYPE - assert ( - data["lectures"][0]["topics"][0]["description"] - == conftest.TOPIC1_DESCRIPTION - ) assert data["lectures"][0]["topics"][0]["order"] == conftest.TOPIC1_ORDER assert data["lectures"][0]["topics"][0]["is_premium"] is True assert ( @@ -58,10 +53,6 @@ def test_course_직렬화(self, setup_course_data): ) assert data["lectures"][1]["topics"][0]["title"] == conftest.TOPIC2_TITLE assert data["lectures"][1]["topics"][0]["type"] == conftest.TOPIC2_TYPE - assert ( - data["lectures"][1]["topics"][0]["description"] - == conftest.TOPIC2_DESCRIPTION - ) assert data["lectures"][1]["topics"][0]["order"] == conftest.TOPIC2_ORDER assert data["lectures"][1]["topics"][0]["is_premium"] is True assert ( @@ -132,8 +123,10 @@ def test_course_역직렬화(self): "short_description": "Test Course", "description": {}, "category": "JavaScript", - "course_level": "beginner", + "skill_level": "beginner", "price": 10000, + "thumbnail_id": 1, + "video_id": 3, "lectures": [ { "title": "Test Lecture", @@ -146,6 +139,7 @@ def test_course_역직렬화(self): "order": 1, "is_premium": True, "assignment": {"question": "Test Assignment"}, + "video_id": 1, } ], }, @@ -156,7 +150,6 @@ def test_course_역직렬화(self): { "title": "Test Topic 2", "type": "assignment", - "description": "Test Description", "order": 1, "is_premium": True, "multiple_choice_question": { @@ -168,6 +161,7 @@ def test_course_역직렬화(self): {"choice": "Choice 4", "is_correct": False}, ], }, + "video_id": 2, } ], }, @@ -179,12 +173,22 @@ def test_course_역직렬화(self): serializer.is_valid(raise_exception=True) # Then - assert serializer.validated_data == data + assert serializer.validated_data["title"] == "Test Course" + assert serializer.validated_data["short_description"] == "Test Course" + assert serializer.validated_data["category"] == "JavaScript" + assert serializer.validated_data["skill_level"] == "beginner" + assert serializer.validated_data["price"] == 10000 + assert serializer.validated_data["lectures"][0]["title"] == "Test Lecture" + assert serializer.validated_data["lectures"][0]["order"] == 1 + assert ( + serializer.validated_data["lectures"][0]["topics"][0]["title"] + == "Test Topic" + ) @pytest.mark.django_db class TestCourseSummarySerializer: - def test_course_직렬화(self, setup_course_data): + def test_course_summary_직렬화(self, setup_course_data): # Given course = setup_course_data["course"] @@ -197,15 +201,16 @@ def test_course_직렬화(self, setup_course_data): assert data["title"] == conftest.COURSE_TITLE assert data["short_description"] == conftest.COURSE_SHORT_DESCRIPTION assert data["category"] == conftest.COURSE_CATEGORY - assert data["course_level"] == conftest.COURSE_COURSE_LEVEL + assert data["skill_level"] == conftest.COURSE_SKILL_LEVEL + assert data["lectures_count"] == 2 - def test_course_역직렬화(self): + def test_course_summary_역직렬화(self): # Given data = { "title": "Test Course", "short_description": "Test Course", "category": "JavaScript", - "course_level": "beginner", + "skill_level": "beginner", } # When @@ -234,7 +239,6 @@ def test_lecture_직렬화(self, setup_course_data): assert len(data["topics"]) == 1 assert data["topics"][0]["title"] == conftest.TOPIC2_TITLE assert data["topics"][0]["type"] == conftest.TOPIC2_TYPE - assert data["topics"][0]["description"] == conftest.TOPIC2_DESCRIPTION assert data["topics"][0]["order"] == conftest.TOPIC2_ORDER assert data["topics"][0]["is_premium"] is True assert ( @@ -298,39 +302,6 @@ def test_lecture_직렬화(self, setup_course_data): is False ) - def test_lecture_역직렬화(self): - # Given - data = { - "title": "Test Lecture", - "order": 1, - "topics": [ - { - "title": "Test Topic", - "type": "quiz", - "description": "Test Description", - "order": 1, - "is_premium": True, - "assignment": {"question": "Test Assignment"}, - "multiple_choice_question": { - "question": "Test Multiple Choice Question", - "multiple_choice_question_choices": [ - {"choice": "Choice 1", "is_correct": True}, - {"choice": "Choice 2", "is_correct": False}, - {"choice": "Choice 3", "is_correct": False}, - {"choice": "Choice 4", "is_correct": False}, - ], - }, - } - ], - } - - # When - serializer = LectureSerializer(data=data) - serializer.is_valid(raise_exception=True) - - # Then - assert serializer.validated_data == data - @pytest.mark.django_db class TestTopicSerializer: @@ -347,7 +318,6 @@ def test_topic_직렬화(self, setup_course_data): assert data["id"] == self.topic.id assert data["title"] == conftest.TOPIC2_TITLE assert data["type"] == conftest.TOPIC2_TYPE - assert data["description"] == conftest.TOPIC2_DESCRIPTION assert data["order"] == conftest.TOPIC2_ORDER assert data["is_premium"] is True assert data["multiple_choice_question"]["question"] == conftest.MCQ_QUESTION @@ -404,33 +374,6 @@ def test_topic_직렬화(self, setup_course_data): is False ) - def test_topic_역직렬화(self): - # Given - data = { - "title": "Test Topic", - "type": "quiz", - "description": "Test Description", - "order": 1, - "is_premium": True, - "assignment": {"question": "Test Assignment"}, - "multiple_choice_question": { - "question": "Test Multiple Choice Question", - "multiple_choice_question_choices": [ - {"choice": "Choice 1", "is_correct": True}, - {"choice": "Choice 2", "is_correct": False}, - {"choice": "Choice 3", "is_correct": False}, - {"choice": "Choice 4", "is_correct": False}, - ], - }, - } - - # When - serializer = TopicSerializer(data=data) - serializer.is_valid(raise_exception=True) - - # Then - assert serializer.validated_data == data - @pytest.mark.django_db class TestAssignmentSerializer: @@ -600,10 +543,13 @@ def test_curriculum_create_and_update_역직렬화(self): @pytest.mark.django_db class TestCurriculumReadSerializer: - def test_curriculum_직렬화(self, setup_course_data): + def test_curriculum_직렬화(self, setup_course_data, create_staff_user): # Given curriculum = Curriculum.objects.create( - name="Test Curriculum", description="Test Description", price=1000 + name="Test Curriculum", + description="Test Description", + price=1000, + author=create_staff_user, ) course = setup_course_data["course"] curriculum.courses.add(course) @@ -618,13 +564,17 @@ def test_curriculum_직렬화(self, setup_course_data): assert serializer.data["price"] == 1000 assert serializer.data["courses"] is not None + @pytest.mark.django_db class TestCurriculumSummarySerializer: - def test_curriculum_직렬화(self): + def test_curriculum_summary_직렬화(self, create_staff_user): # Given curriculum = Curriculum.objects.create( - name="Test Curriculum", description="Test Description", price=1000 + name="Test Curriculum", + description="Test Description", + price=1000, + author=create_staff_user, ) # When @@ -635,4 +585,4 @@ def test_curriculum_직렬화(self): assert serializer.data["name"] == "Test Curriculum" assert serializer.data["price"] == 1000 assert serializer.data["created_at"] is not None - assert serializer.data["updated_at"] is not None \ No newline at end of file + assert serializer.data["updated_at"] is not None diff --git a/courses/test/test_views.py b/courses/test/test_views.py index fe0e3a5..5c2681d 100644 --- a/courses/test/test_views.py +++ b/courses/test/test_views.py @@ -21,7 +21,7 @@ def test_course_조회(self, api_client, setup_course_data): assert response.data["title"] == conftest.COURSE_TITLE assert response.data["short_description"] == conftest.COURSE_SHORT_DESCRIPTION assert response.data["category"] == conftest.COURSE_CATEGORY - assert response.data["course_level"] == conftest.COURSE_COURSE_LEVEL + assert response.data["skill_level"] == conftest.COURSE_SKILL_LEVEL assert response.data["created_at"] is not None assert response.data["updated_at"] is not None assert len(response.data["lectures"]) == 2 @@ -44,21 +44,20 @@ def test_course_조회(self, api_client, setup_course_data): is not None ) - def test_course_수정(self, api_client, setup_course_data): + def test_course_수정(self, api_client, setup_course_data, staff_user_token): # Given course = setup_course_data["course"] url = reverse("courses:course-detail", args=[course.id]) - api_client.login( - username=conftest.TEST_STAFF_USER_EMAIL, - password=conftest.TEST_STAFF_USER_PASSWORD, - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {staff_user_token}") data = { "title": "Updated Test Course", "short_description": "Updated Test Course", "description": {}, "category": "Python", - "course_level": "intermediate", + "skill_level": "intermediate", "price": 20000, + "video_id": 1, + "thumbnail_id": 1, "lectures": [ { "title": "Updated Test Lecture", @@ -66,11 +65,11 @@ def test_course_수정(self, api_client, setup_course_data): "topics": [ { "title": "Updated Test Topic", - "type": "assignment", - "description": "Updated Test Description", + "type": "video", "order": 1, "is_premium": True, "assignment": {"question": "Updated Test Assignment"}, + "video_id": 1, } ], }, @@ -93,6 +92,7 @@ def test_course_수정(self, api_client, setup_course_data): {"choice": "Updated Choice 4", "is_correct": False}, ], }, + "video_id": 2, } ], }, @@ -107,7 +107,7 @@ def test_course_수정(self, api_client, setup_course_data): assert response.data["title"] == "Updated Test Course" assert response.data["short_description"] == "Updated Test Course" assert response.data["category"] == "Python" - assert response.data["course_level"] == "intermediate" + assert response.data["skill_level"] == "intermediate" assert response.data["created_at"] is not None assert response.data["updated_at"] is not None assert len(response.data["lectures"]) == 2 @@ -115,10 +115,6 @@ def test_course_수정(self, api_client, setup_course_data): assert response.data["lectures"][1]["title"] == "Updated Test Lecture 2" assert len(response.data["lectures"][0]["topics"]) == 1 assert len(response.data["lectures"][1]["topics"]) == 1 - assert ( - response.data["lectures"][0]["topics"][0]["assignment"]["question"] - == "Updated Test Assignment" - ) assert ( response.data["lectures"][1]["topics"][0]["multiple_choice_question"][ "question" @@ -132,15 +128,16 @@ def test_course_수정(self, api_client, setup_course_data): == "Updated Choice 1" ) - def test_course_수정_실패_로그인하지않은경우(self, api_client): + def test_course_수정_실패_로그인하지않은경우(self, api_client, create_staff_user): # Given course = Course.objects.create( title="Test Course", short_description="Test Course", description={}, category="JavaScript", - course_level="beginner", + skill_level="beginner", price=10000, + author=create_staff_user, ) url = reverse("courses:course-detail", args=[course.id]) data = { @@ -148,7 +145,7 @@ def test_course_수정_실패_로그인하지않은경우(self, api_client): "short_description": "Updated Test Course", "description": {}, "category": "Python", - "course_level": "intermediate", + "skill_level": "intermediate", "price": 20000, } @@ -161,26 +158,27 @@ def test_course_수정_실패_로그인하지않은경우(self, api_client): "detail": "자격 인증데이터(authentication credentials)가 제공되지 않았습니다." } - def test_course_수정_실패_일반유저인_경우(self, api_client): + def test_course_수정_실패_일반유저인_경우( + self, api_client, user_token, create_staff_user + ): # Given course = Course.objects.create( title="Test Course", short_description="Test Course", description={}, category="JavaScript", - course_level="beginner", + skill_level="beginner", price=10000, + author=create_staff_user, ) url = reverse("courses:course-detail", args=[course.id]) - api_client.login( - username=conftest.TEST_USER_EMAIL, password=conftest.TEST_USER_PASSWORD - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {user_token}") data = { "title": "Updated Test Course", "short_description": "Updated Test Course", "description": {}, "category": "Python", - "course_level": "intermediate", + "skill_level": "intermediate", "price": 20000, } @@ -193,21 +191,19 @@ def test_course_수정_실패_일반유저인_경우(self, api_client): "detail": "이 작업을 수행할 권한(permission)이 없습니다." } - def test_course_삭제(self, api_client): + def test_course_삭제(self, api_client, staff_user_token, create_staff_user): # Given course = Course.objects.create( title="Test Course", short_description="Test Course", description={}, category="JavaScript", - course_level="beginner", + skill_level="beginner", price=10000, + author=create_staff_user, ) url = reverse("courses:course-detail", args=[course.id]) - api_client.login( - username=conftest.TEST_STAFF_USER_EMAIL, - password=conftest.TEST_STAFF_USER_PASSWORD, - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {staff_user_token}") # When response = api_client.delete(url) @@ -216,15 +212,16 @@ def test_course_삭제(self, api_client): assert response.status_code == status.HTTP_204_NO_CONTENT assert Course.objects.count() == 0 - def test_course_삭제_실패_로그인하지않은경우(self, api_client): + def test_course_삭제_실패_로그인하지않은경우(self, api_client, create_staff_user): # Given course = Course.objects.create( title="Test Course", short_description="Test Course", description={}, category="JavaScript", - course_level="beginner", + skill_level="beginner", price=10000, + author=create_staff_user, ) url = reverse("courses:course-detail", args=[course.id]) @@ -237,20 +234,21 @@ def test_course_삭제_실패_로그인하지않은경우(self, api_client): "detail": "자격 인증데이터(authentication credentials)가 제공되지 않았습니다." } - def test_course_삭제_실패_일반유저인_경우(self, api_client, create_user): + def test_course_삭제_실패_일반유저인_경우( + self, api_client, user_token, create_staff_user + ): # Given course = Course.objects.create( title="Test Course", short_description="Test Course", description={}, category="JavaScript", - course_level="beginner", + skill_level="beginner", price=10000, + author=create_staff_user, ) url = reverse("courses:course-detail", args=[course.id]) - api_client.login( - username=conftest.TEST_USER_EMAIL, password=conftest.TEST_USER_PASSWORD - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {user_token}") # When response = api_client.delete(url) @@ -268,8 +266,10 @@ def get_course_data(): "short_description": "Test Course", "description": {}, "category": "JavaScript", - "course_level": "beginner", + "skill_level": "beginner", "price": 10000, + "video_id": 1, + "thumbnail_id": 1, "lectures": [ { "title": "Test Lecture", @@ -277,11 +277,11 @@ def get_course_data(): "topics": [ { "title": "Test Topic", - "type": "assignment", + "type": "video", "description": "Test Description", "order": 1, "is_premium": True, - "assignment": {"question": "Test Assignment"}, + "video_id": 1, } ], }, @@ -314,13 +314,10 @@ def get_course_data(): @pytest.mark.django_db class TestCourseList: - def test_course_생성_요청(self, api_client, create_staff_user): + def test_course_생성_요청(self, api_client, staff_user_token): # Given url = reverse("courses:course-list") - api_client.login( - username=conftest.TEST_STAFF_USER_EMAIL, - password=conftest.TEST_STAFF_USER_PASSWORD, - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {staff_user_token}") data = get_course_data() # When @@ -329,12 +326,10 @@ def test_course_생성_요청(self, api_client, create_staff_user): # Then assert response.status_code == status.HTTP_201_CREATED - def test_course_생성_요청_실패_일반유저인_경우(self, api_client, create_user): + def test_course_생성_요청_실패_일반유저인_경우(self, api_client, user_token): # Given url = reverse("courses:course-list") - api_client.login( - username=conftest.TEST_USER_EMAIL, password=conftest.TEST_USER_PASSWORD - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {user_token}") data = get_course_data() # When @@ -368,8 +363,9 @@ def test_course_목록_조회(self, api_client, create_user): short_description="Test Course", description={}, category="JavaScript", - course_level="beginner", + skill_level="beginner", price=10000, + author=create_user, ) url = reverse("courses:course-list") api_client.login( @@ -381,19 +377,18 @@ def test_course_목록_조회(self, api_client, create_user): # Then assert response.status_code == status.HTTP_200_OK - assert len(response.data) == 5 + assert response.data["count"] == 5 @pytest.mark.django_db class TestCurriculumList: - def test_curriculum_생성_요청(self, api_client, setup_course_data): + def test_curriculum_생성_요청( + self, api_client, setup_course_data, staff_user_token + ): # Given url = reverse("courses:curriculum-list") - api_client.login( - username=conftest.TEST_STAFF_USER_EMAIL, - password=conftest.TEST_STAFF_USER_PASSWORD, - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {staff_user_token}") data = { "name": "Test Curriculum", "description": "Test Description", @@ -408,13 +403,11 @@ def test_curriculum_생성_요청(self, api_client, setup_course_data): assert response.status_code == status.HTTP_201_CREATED def test_curriculum_생성_요청_실패_일반유저인_경우( - self, api_client, setup_course_data + self, api_client, setup_course_data, user_token ): # Given url = reverse("courses:curriculum-list") - api_client.login( - username=conftest.TEST_USER_EMAIL, password=conftest.TEST_USER_PASSWORD - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {user_token}") data = { "name": "Test Curriculum", "description": "Test Description", @@ -452,13 +445,16 @@ def test_curriculum_생성_요청_실패_로그인하지않은경우( "detail": "자격 인증데이터(authentication credentials)가 제공되지 않았습니다." } - def test_curriculum_목록_조회(self, api_client, setup_course_data): + def test_curriculum_목록_조회( + self, api_client, setup_course_data, create_staff_user + ): # Given for i in range(5): Curriculum.objects.create( name=f"Test Curriculum {i}", description="Test Description", price=1000, + author=create_staff_user, ) url = reverse("courses:curriculum-list") api_client.login( @@ -470,18 +466,19 @@ def test_curriculum_목록_조회(self, api_client, setup_course_data): # Then assert response.status_code == status.HTTP_200_OK - assert len(response.data) == 5 + assert response.data["count"] == 5 @pytest.mark.django_db class TestCurriculumDetail: - def test_curriculum_조회(self, api_client, setup_course_data): + def test_curriculum_조회(self, api_client, setup_course_data, create_staff_user): # Given curriculum = Curriculum.objects.create( name="Test Curriculum", description="Test Description", price=1000, + author=create_staff_user, ) url = reverse("courses:curriculum-detail", args=[curriculum.id]) course = setup_course_data["course"] @@ -503,18 +500,18 @@ def test_curriculum_조회(self, api_client, setup_course_data): assert response.data["updated_at"] is not None assert len(response.data["courses"]) == 1 - def test_curriculum_수정(self, api_client, setup_course_data): + def test_curriculum_수정( + self, api_client, setup_course_data, staff_user_token, create_staff_user + ): # Given curriculum = Curriculum.objects.create( name="Test Curriculum", description="Test Description", price=1000, + author=create_staff_user, ) url = reverse("courses:curriculum-detail", args=[curriculum.id]) - api_client.login( - username=conftest.TEST_STAFF_USER_EMAIL, - password=conftest.TEST_STAFF_USER_PASSWORD, - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {staff_user_token}") data = { "name": "Updated Test Curriculum", "description": "Updated Test Description", @@ -531,17 +528,18 @@ def test_curriculum_수정(self, api_client, setup_course_data): assert response.data["description"] == "Updated Test Description" assert response.data["price"] == 2000 - def test_curriculum_수정_실패_일반유저인_경우(self, api_client, setup_course_data): + def test_curriculum_수정_실패_일반유저인_경우( + self, api_client, setup_course_data, user_token, create_staff_user + ): # Given curriculum = Curriculum.objects.create( name="Test Curriculum", description="Test Description", price=1000, + author=create_staff_user, ) url = reverse("courses:curriculum-detail", args=[curriculum.id]) - api_client.login( - username=conftest.TEST_USER_EMAIL, password=conftest.TEST_USER_PASSWORD - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {user_token}") data = { "name": "Updated Test Curriculum", "description": "Updated Test Description", @@ -559,13 +557,14 @@ def test_curriculum_수정_실패_일반유저인_경우(self, api_client, setup } def test_curriculum_수정_실패_로그인하지않은경우( - self, api_client, setup_course_data + self, api_client, setup_course_data, create_staff_user ): # Given curriculum = Curriculum.objects.create( name="Test Curriculum", description="Test Description", price=1000, + author=create_staff_user, ) url = reverse("courses:curriculum-detail", args=[curriculum.id]) data = { @@ -584,18 +583,16 @@ def test_curriculum_수정_실패_로그인하지않은경우( "detail": "자격 인증데이터(authentication credentials)가 제공되지 않았습니다." } - def test_curriculum_삭제(self, api_client): + def test_curriculum_삭제(self, api_client, staff_user_token, create_staff_user): # Given curriculum = Curriculum.objects.create( name="Test Curriculum", description="Test Description", price=1000, + author=create_staff_user, ) url = reverse("courses:curriculum-detail", args=[curriculum.id]) - api_client.login( - username=conftest.TEST_STAFF_USER_EMAIL, - password=conftest.TEST_STAFF_USER_PASSWORD, - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {staff_user_token}") # When response = api_client.delete(url) @@ -604,17 +601,18 @@ def test_curriculum_삭제(self, api_client): assert response.status_code == status.HTTP_204_NO_CONTENT assert Curriculum.objects.count() == 0 - def test_curriculum_삭제_실패_일반유저인_경우(self, api_client): + def test_curriculum_삭제_실패_일반유저인_경우( + self, api_client, user_token, create_staff_user + ): # Given curriculum = Curriculum.objects.create( name="Test Curriculum", description="Test Description", price=1000, + author=create_staff_user, ) url = reverse("courses:curriculum-detail", args=[curriculum.id]) - api_client.login( - username=conftest.TEST_USER_EMAIL, password=conftest.TEST_USER_PASSWORD - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {user_token}") # When response = api_client.delete(url) @@ -625,12 +623,15 @@ def test_curriculum_삭제_실패_일반유저인_경우(self, api_client): "detail": "이 작업을 수행할 권한(permission)이 없습니다." } - def test_curriculum_삭제_실패_로그인하지않은경우(self, api_client): + def test_curriculum_삭제_실패_로그인하지않은경우( + self, api_client, create_staff_user + ): # Given curriculum = Curriculum.objects.create( name="Test Curriculum", description="Test Description", price=1000, + author=create_staff_user, ) url = reverse("courses:curriculum-detail", args=[curriculum.id]) @@ -643,13 +644,14 @@ def test_curriculum_삭제_실패_로그인하지않은경우(self, api_client): "detail": "자격 인증데이터(authentication credentials)가 제공되지 않았습니다." } - def test_curriculum_수정_실패_존재하지않는_curriculum인_경우(self, api_client): + def test_curriculum_수정_실패_존재하지않는_curriculum인_경우( + self, + api_client, + staff_user_token, + ): # Given url = reverse("courses:curriculum-detail", args=[1]) - api_client.login( - username=conftest.TEST_STAFF_USER_EMAIL, - password=conftest.TEST_STAFF_USER_PASSWORD, - ) + api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {staff_user_token}") data = { "name": "Updated Test Curriculum", "description": "Updated Test Description", diff --git a/courses/views.py b/courses/views.py index 14d818e..cf95ae1 100644 --- a/courses/views.py +++ b/courses/views.py @@ -1,6 +1,8 @@ from django.db import transaction +from django_filters.rest_framework import DjangoFilterBackend from drf_spectacular.utils import extend_schema, extend_schema_view -from rest_framework import generics +from rest_framework import filters, generics +from rest_framework.pagination import PageNumberPagination from rest_framework.response import Response from .mixins import CourseMixin @@ -15,6 +17,12 @@ ) +class CourseResultsSetPagination(PageNumberPagination): + page_size = 9 + page_size_query_param = None + max_page_size = 9 + + @extend_schema_view( get=extend_schema( summary="Course를 조회하는 API", @@ -49,6 +57,7 @@ class CourseDetailRetrieveUpdateDestroyView( queryset = Course.objects.prefetch_related( "lectures__topics__multiple_choice_question__multiple_choice_question_choices", "lectures__topics__assignment", + "author", ) serializer_class = CourseDetailSerializer permission_classes = [IsStaffOrReadOnly] @@ -63,6 +72,7 @@ def get_permissions(self): return [] return super().get_permissions() + @transaction.atomic def update(self, request, *args, **kwargs): """ course를 수정합니다. @@ -71,22 +81,16 @@ def update(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) course = self.get_object() - self.perform_update(serializer, course) + self.update_course_with_lectures_and_topics( + course, + serializer.validated_data, + serializer.validated_data.get("lectures", []), + ) if getattr(course, "_prefetched_objects_cache", None): course._prefetched_objects_cache = {} serializer = self.get_serializer(course) return Response(serializer.data) - @transaction.atomic - def perform_update(self, serializer, course): - """ - course 및 하위 모델 lecture, topic, assignment, quiz 등을 함께 수정합니다. - """ - - self.update_course_with_lectures_and_topics( - course, serializer.data, serializer.data.get("lectures", []) - ) - @extend_schema_view( get=extend_schema( @@ -108,6 +112,15 @@ class CourseListCreateView(CourseMixin, generics.ListCreateAPIView): queryset = Course.objects.all() permission_classes = [IsStaffOrReadOnly] + pagination_class = CourseResultsSetPagination + filter_backends = [ + DjangoFilterBackend, + filters.SearchFilter, + filters.OrderingFilter, + ] + search_fields = ["title", "short_description", "description"] + filterset_fields = ["category", "skill_level"] + ordering_fields = ["created_at", "price"] def get_serializer_class(self): """ @@ -122,14 +135,23 @@ def get_serializer_class(self): return CourseSummarySerializer @transaction.atomic - def perform_create(self, serializer): + def create(self, request, *args, **kwargs): """ - course 및 하위 모델 lecture, topic, assignment, quiz 등을 함께 생성합니다. + course를 생성합니다. """ - self.create_course_with_lectures_and_topics( - serializer.data, serializer.data.get("lectures", []) + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + author = self.request.user if self.request.user.is_staff else None + + course = self.create_course_with_lectures_and_topics( + serializer.validated_data, + serializer.validated_data.get("lectures", []), + author, ) + serializer = self.get_serializer(course) + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=201, headers=headers) @extend_schema_view( @@ -153,6 +175,14 @@ class CurriculumListCreateView(generics.ListCreateAPIView): queryset = Curriculum.objects.all() serializer_class = CurriculumSummarySerializer permission_classes = [IsStaffOrReadOnly] + filter_backends = [ + DjangoFilterBackend, + filters.SearchFilter, + filters.OrderingFilter, + ] + search_fields = ["title", "description"] + filterset_fields = ["category", "skill_level"] + ordering_fields = ["created_at", "price"] def get_serializer_class(self): """ @@ -171,10 +201,12 @@ def perform_create(self, serializer): """ curriculum을 생성합니다. """ + author = self.request.user if self.request.user.is_staff else None curriculum = Curriculum.objects.create( name=serializer.data.get("name"), description=serializer.data.get("description"), price=serializer.data.get("price"), + author=author, ) courses_ids = serializer.data.get("courses_ids", []) Course.objects.filter(id__in=courses_ids).update(curriculum=curriculum) diff --git a/jwtauth/authentication.py b/jwtauth/authentication.py index 6a52414..f6a6f1e 100644 --- a/jwtauth/authentication.py +++ b/jwtauth/authentication.py @@ -1,19 +1,16 @@ from rest_framework.authentication import BaseAuthentication from rest_framework.exceptions import AuthenticationFailed from django.contrib.auth import get_user_model -import jwt +import jwt, logging +from django.core.cache import cache from django.conf import settings +logger = logging.getLogger(__name__) User = get_user_model() class JWTAuthentication(BaseAuthentication): - """ - 해당 클래스는 JWT 토큰을 사용하여 사용자를 인증하는 데 사용됩니다. - - 토큰이 유효하지 않으면 해당하는 메시지를 반환합니다. - """ - def authenticate(self, request): auth_header = request.headers.get("Authorization") if not auth_header: @@ -25,19 +22,38 @@ def authenticate(self, request): access_token, settings.SECRET_KEY, algorithms=["HS256"] ) - user_id = payload.get("user_id") - user = User.objects.get(id=user_id) + user_id = payload["user_id"] + + cache_key = f"user_{user_id}" + user_data = cache.get(cache_key) + + if user_data is None: + user = User.objects.get(id=user_id) + user_data = { + "id": user.id, + "email": user.email, + "is_staff": user.is_staff, + "is_superuser": user.is_superuser, + } + cache.set(cache_key, user_data, timeout=18000) + + user = User( + id=user_data["id"], + email=user_data["email"], + is_staff=user_data["is_staff"], + is_superuser=user_data["is_superuser"], + ) return (user, None) except jwt.ExpiredSignatureError: raise AuthenticationFailed("토큰이 만료되었습니다!") except IndexError: - raise AuthenticationFailed("토큰이 유효하지 않습니다!") + raise AuthenticationFailed("토큰이 없습니다!") except jwt.DecodeError: - raise AuthenticationFailed("토큰 디코딩 오류!") + raise AuthenticationFailed("토큰이 유효하지 않습니다!") + except User.DoesNotExist: + raise AuthenticationFailed("유효하지 않은 사용자입니다!") except Exception as e: - raise AuthenticationFailed(f"인증 오류: {str(e)}") - - def authenticate_header(self, request): - return "Bearer" + logger.error(f"인증 오류: {str(e)}") + raise AuthenticationFailed("인증이 유효하지 않습니다!") diff --git a/jwtauth/test/test_authentication.py b/jwtauth/test/test_authentication.py index d112fca..8fc4f9b 100644 --- a/jwtauth/test/test_authentication.py +++ b/jwtauth/test/test_authentication.py @@ -1,5 +1,4 @@ import pytest -from django.contrib.auth import get_user_model from rest_framework.test import APIClient from rest_framework import status from django.utils import timezone @@ -9,8 +8,7 @@ from django.conf import settings from jwtauth.models import BlacklistedToken from jwtauth.utils.token_generator import generate_access_token, generate_refresh_token - -User = get_user_model() +from accounts.models import CustomUser as User @pytest.fixture @@ -18,7 +16,9 @@ def api_client(): """ API 클라이언트를 생성하여 반환합니다. """ - return APIClient() + client = APIClient() + client.default_format = "json" + return client @pytest.fixture @@ -26,7 +26,9 @@ def user(db): """ 테스트용 유저를 생성하여 반환합니다. """ - return User.objects.create_user(email="test@example.com", password="testpass123") + return User.objects.create_user( + email="test@example.com", password="testpass123", nickname="testuser" + ) @pytest.fixture @@ -51,12 +53,12 @@ def test_로그인_성공(api_client, user): # Given: 유효한 사용자 정보가 있음 # When: 로그인 API에 POST 요청을 보냄 response = api_client.post( - "/api/login/", {"email": "test@example.com", "password": "testpass123"} + reverse("login"), {"email": "test@example.com", "password": "testpass123"} ) # Then: 응답 상태 코드가 200이고, 액세스 토큰과 리프레시 토큰이 포함되어 있음 assert response.status_code == status.HTTP_200_OK assert "access_token" in response.data - assert "refresh_token" in response.data + assert "refresh_token" in response.cookies @pytest.mark.django_db @@ -65,7 +67,7 @@ def test_로그인_실패(api_client): # Given: 잘못된 사용자 정보가 있음 # When: 로그인 API에 잘못된 정보로 POST 요청을 보냄 response = api_client.post( - "/api/login/", {"email": "wrong@example.com", "password": "wrongpass"} + reverse("login"), {"email": "wrong@example.com", "password": "wrongpass"} ) # Then: 응답 상태 코드가 401 (Unauthorized)임 assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -77,7 +79,7 @@ def test_로그아웃_성공(api_client, user, refresh_token): # Given: 인증된 사용자와 유효한 리프레시 토큰이 있음 api_client.force_authenticate(user=user) # When: 로그아웃 API에 리프레시 토큰과 함께 POST 요청을 보냄 - response = api_client.post("/api/logout/", {"refresh_token": refresh_token}) + response = api_client.post(reverse("logout"), {"refresh_token": refresh_token}) # Then: 응답 상태 코드가 200이고, 리프레시 토큰이 블랙리스트에 추가됨 assert response.status_code == status.HTTP_200_OK assert BlacklistedToken.objects.filter(token=refresh_token).exists() @@ -89,7 +91,7 @@ def test_로그아웃_실패_토큰없음(api_client, user): # Given: 인증된 사용자가 있지만 리프레시 토큰이 없음 api_client.force_authenticate(user=user) # When: 로그아웃 API에 리프레시 토큰 없이 POST 요청을 보냄 - response = api_client.post("/api/logout/", {}) + response = api_client.post(reverse("logout"), {}) # Then: 응답 상태 코드가 400 (Bad Request)임 assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -99,7 +101,7 @@ def test_리프레시_토큰_갱신_성공(api_client, user, refresh_token): """리프레시 토큰 갱신 API를 테스트합니다.""" # Given: 유효한 리프레시 토큰이 있음 # When: 리프레시 API에 리프레시 토큰과 함께 POST 요청을 보냄 - response = api_client.post("/api/refresh/", {"refresh_token": refresh_token}) + response = api_client.post(reverse("refresh"), {"refresh_token": refresh_token}) # Then: 응답 상태 코드가 200이고, 새로운 액세스 토큰과 리프레시 토큰이 반환되며, 기존 리프레시 토큰이 블랙리스트에 추가됨 assert response.status_code == status.HTTP_200_OK assert "access_token" in response.data @@ -115,7 +117,7 @@ def test_리프레시_토큰_갱신_실패_블랙리스트(api_client, user, ref token=refresh_token, user=user, token_type="refresh" ) # When: 리프레시 API에 블랙리스트에 등록된 리프레시 토큰과 함께 POST 요청을 보냄 - response = api_client.post("/api/refresh/", {"refresh_token": refresh_token}) + response = api_client.post(reverse("refresh"), {"refresh_token": refresh_token}) # Then: 응답 상태 코드가 400 (Bad Request)임 assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -140,8 +142,8 @@ def test_JWT_인증_실패_만료된_토큰(api_client, user): url = reverse("refresh") # When: 리프레시 API에 만료된 토큰과 함께 POST 요청을 보냄 response = api_client.post(url) - # Then: 응답 상태 코드가 400 (Bad Request)임 - assert response.status_code == status.HTTP_400_BAD_REQUEST + # Then: 응답 상태 코드가 403 (Forbidden)임 + assert response.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.django_db @@ -152,8 +154,8 @@ def test_JWT_인증_실패_유효하지_않은_토큰(api_client): url = reverse("refresh") # When: 리프레시 API에 유효하지 않은 토큰과 함께 POST 요청을 보냄 response = api_client.post(url) - # Then: 응답 상태 코드가 400 (Bad Request)임 - assert response.status_code == status.HTTP_400_BAD_REQUEST + # Then: 응답 상태 코드가 403 (Forbidden)임 + assert response.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.django_db diff --git a/jwtauth/urls.py b/jwtauth/urls.py index c00ee26..06750fe 100644 --- a/jwtauth/urls.py +++ b/jwtauth/urls.py @@ -1,8 +1,10 @@ from django.urls import path -from .views import LoginView, LogoutView, RefreshTokenView +from .views import LoginView, LogoutView, RefreshTokenView, GoogleLogin + urlpatterns = [ path("login/", LoginView.as_view(), name="login"), path("logout/", LogoutView.as_view(), name="logout"), path("refresh/", RefreshTokenView.as_view(), name="refresh"), + path("social-login/google/", GoogleLogin.as_view(), name="google_login"), ] diff --git a/jwtauth/utils/token_generator.py b/jwtauth/utils/token_generator.py index e79db64..5d22f5b 100644 --- a/jwtauth/utils/token_generator.py +++ b/jwtauth/utils/token_generator.py @@ -1,5 +1,6 @@ -import jwt from datetime import timedelta + +import jwt from django.conf import settings from django.utils import timezone @@ -8,10 +9,15 @@ def generate_access_token(user): """ 사용자 정보를 받아서 access token을 생성합니다. """ + payload = { "user_id": user.id, "is_staff": user.is_staff, "is_superuser": user.is_superuser, + "iat": timezone.now(), + "nickname": user.nickname, + "email": user.email, + "image": user.get_image_url(), "exp": timezone.now() + timedelta(minutes=30), } return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256") diff --git a/jwtauth/views.py b/jwtauth/views.py index d6598d5..cf50ec2 100644 --- a/jwtauth/views.py +++ b/jwtauth/views.py @@ -2,13 +2,24 @@ from rest_framework.response import Response from rest_framework.permissions import IsAuthenticated, AllowAny from rest_framework import status +from dj_rest_auth.registration.views import SocialLoginView +from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter +from allauth.socialaccount.providers.oauth2.client import OAuth2Client from django.contrib.auth import authenticate, get_user_model from django.conf import settings -from .serializers import LoginSerializer, LogoutSerializer, RefreshTokenSerializer -from .utils.token_generator import generate_access_token, generate_refresh_token +from .serializers import ( + LoginSerializer, + LogoutSerializer, + RefreshTokenSerializer, +) +from .utils.token_generator import ( + generate_access_token, + generate_refresh_token, +) from .models import BlacklistedToken import jwt, logging + logger = logging.getLogger(__name__) User = get_user_model() @@ -36,10 +47,15 @@ def post(self, request): access_token = generate_access_token(user) refresh_token = generate_refresh_token(user) - return Response( - {"access_token": access_token, "refresh_token": refresh_token} + response = Response({"access_token": access_token}) + response.set_cookie( + key="refresh_token", + value=refresh_token, + httponly=True, + secure=not settings.DEBUG, + samesite="None", ) - + return response else: return Response( {"error": "회원 가입하세요"}, status=status.HTTP_401_UNAUTHORIZED @@ -64,7 +80,9 @@ def post(self, request): refresh_token = serializer.validated_data["refresh_token"] try: - BlacklistedToken.objects.create(token=refresh_token, user=request.user) + BlacklistedToken.objects.create( + token=refresh_token, user=request.user, token_type="refresh" + ) return Response( {"success": "로그아웃 완료."}, status=status.HTTP_200_OK, @@ -136,3 +154,26 @@ def post(self, request): ) else: return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class GoogleLogin(SocialLoginView): + adapter_class = GoogleOAuth2Adapter + callback_url = settings.GOOGLE_CALLBACK_URL + client_class = OAuth2Client + + def get_response(self): + response = super().get_response() + user = self.user + access_token = generate_access_token(user) + refresh_token = generate_refresh_token(user) + + response.set_cookie( + key="refresh_token", + value=refresh_token, + httponly=True, + secure=not settings.DEBUG, + samesite="None", + ) + response.data = {"access_token": access_token} + + return response diff --git a/materials/admin.py b/materials/admin.py index 8c38f3f..f2b6b26 100644 --- a/materials/admin.py +++ b/materials/admin.py @@ -1,3 +1,6 @@ from django.contrib import admin -# Register your models here. +from .models import Image, Video + +admin.site.register(Image) +admin.site.register(Video) diff --git a/materials/migrations/0003_image_image_image_user_alter_image_course_and_more.py b/materials/migrations/0003_image_image_image_user_alter_image_course_and_more.py new file mode 100644 index 0000000..08324e9 --- /dev/null +++ b/materials/migrations/0003_image_image_image_user_alter_image_course_and_more.py @@ -0,0 +1,48 @@ +# Generated by Django 5.1.1 on 2024-10-13 06:00 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0007_alter_course_category'), + ('materials', '0002_image_video_delete_blacklistedtoken'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name='image', + name='image', + field=models.ImageField(blank=True, null=True, upload_to='images/'), + ), + migrations.AddField( + model_name='image', + name='user', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='image', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='image', + name='course', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='image', to='courses.course'), + ), + migrations.AlterField( + model_name='video', + name='topic', + field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='video', to='courses.topic'), + ), + migrations.CreateModel( + name='VideoEventData', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('event_type', models.CharField(choices=[('pause', 'Paused'), ('ended', 'Ended'), ('leave', 'Left Page')], max_length=20, verbose_name='이벤트 유형')), + ('duration', models.FloatField(verbose_name='비디오 전체 길이')), + ('current_time', models.FloatField(verbose_name='현재 재생 위치')), + ('timestamp', models.DateTimeField(auto_now_add=True, verbose_name='이벤트 발생 시간')), + ('video', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='video_event_datas', to='materials.video', verbose_name='해당 비디오')), + ], + ), + ] diff --git a/materials/migrations/0004_remove_image_file_remove_image_image_and_more.py b/materials/migrations/0004_remove_image_file_remove_image_image_and_more.py new file mode 100644 index 0000000..269a662 --- /dev/null +++ b/materials/migrations/0004_remove_image_file_remove_image_image_and_more.py @@ -0,0 +1,65 @@ +# Generated by Django 5.1.1 on 2024-10-13 09:16 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0010_remove_topic_description'), + ('materials', '0003_image_image_image_user_alter_image_course_and_more'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.RemoveField( + model_name='image', + name='file', + ), + migrations.RemoveField( + model_name='image', + name='image', + ), + migrations.RemoveField( + model_name='image', + name='title', + ), + migrations.RemoveField( + model_name='video', + name='file', + ), + migrations.RemoveField( + model_name='video', + name='title', + ), + migrations.AddField( + model_name='image', + name='image_url', + field=models.URLField(default='', verbose_name='이미지 파일'), + preserve_default=False, + ), + migrations.AddField( + model_name='video', + name='course', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='video', to='courses.course'), + ), + migrations.AddField( + model_name='video', + name='video_url', + field=models.URLField(default='', verbose_name='비디오 파일'), + preserve_default=False, + ), + migrations.AddField( + model_name='videoeventdata', + name='user', + field=models.ForeignKey(default='', on_delete=django.db.models.deletion.CASCADE, related_name='video_event_datas', to=settings.AUTH_USER_MODEL, verbose_name='시청 기록의 해당 사용자'), + preserve_default=False, + ), + migrations.AlterField( + model_name='videoeventdata', + name='video', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='video_event_datas', to='materials.video', verbose_name='시청 기록의 해당 비디오'), + ), + ] diff --git a/materials/migrations/0005_image_author.py b/materials/migrations/0005_image_author.py new file mode 100644 index 0000000..20b01d1 --- /dev/null +++ b/materials/migrations/0005_image_author.py @@ -0,0 +1,22 @@ +# Generated by Django 5.1.1 on 2024-10-13 11:34 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('materials', '0004_remove_image_file_remove_image_image_and_more'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name='image', + name='author', + field=models.ForeignKey(default=1, on_delete=django.db.models.deletion.CASCADE, related_name='images', to=settings.AUTH_USER_MODEL, verbose_name='이미지를 등록한 사용자'), + preserve_default=False, + ), + ] diff --git a/materials/migrations/0006_alter_video_topic.py b/materials/migrations/0006_alter_video_topic.py new file mode 100644 index 0000000..1730b34 --- /dev/null +++ b/materials/migrations/0006_alter_video_topic.py @@ -0,0 +1,20 @@ +# Generated by Django 5.1.1 on 2024-10-13 11:43 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0010_remove_topic_description'), + ('materials', '0005_image_author'), + ] + + operations = [ + migrations.AlterField( + model_name='video', + name='topic', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='video', to='courses.topic'), + ), + ] diff --git a/materials/migrations/0007_alter_image_author.py b/materials/migrations/0007_alter_image_author.py new file mode 100644 index 0000000..69afed1 --- /dev/null +++ b/materials/migrations/0007_alter_image_author.py @@ -0,0 +1,21 @@ +# Generated by Django 5.1.1 on 2024-10-13 12:05 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('materials', '0006_alter_video_topic'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AlterField( + model_name='image', + name='author', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='images', to=settings.AUTH_USER_MODEL, verbose_name='이미지를 등록한 사용자'), + ), + ] diff --git a/materials/models.py b/materials/models.py index 2986254..9ebc840 100644 --- a/materials/models.py +++ b/materials/models.py @@ -1,26 +1,128 @@ -from courses.models import Course, Topic +import uuid + from django.db import models +from accounts.models import CustomUser +from courses.models import Course, Topic + + +def upload_to(instance, filename): + """ + ImageField를 통해 파일이 업로드될 때 해당 파일의 저장 경로를 동적으로 생성합니다. + - 모델 인스턴스가 save() 호출될 때, 파일이 저장되기 전 upload_to에 정의된 경로를 생성하기 위해 호출됩니다. + - ImageField의 upload_to 인자로 전달됩니다. + - 생성된 경로를 반환하며, 이 경로는 Django가 해당 파일을 저장할 때 사용됩니다. + - (장점) 사용자 접근성을 높이면서 중복 파일 이름 문제를 해결합니다. + """ + ext = filename.split(".")[-1] + return f"images/{uuid.uuid4()}.{ext}" + class Image(models.Model): + """ + 이미지 객체를 위해 작성된 모델입니다. + """ + course = models.OneToOneField( - Course, on_delete=models.CASCADE, related_name="images" + Course, + on_delete=models.CASCADE, + related_name="image", + null=True, + blank=True, ) - title = models.CharField(max_length=255, verbose_name="이미지 제목") - file = models.ImageField(upload_to="images/", verbose_name="이미지 파일") + user = models.OneToOneField( + CustomUser, + on_delete=models.CASCADE, + related_name="image", + null=True, + blank=True, + ) + author = models.ForeignKey( + CustomUser, + on_delete=models.CASCADE, + related_name="images", + verbose_name="이미지를 등록한 사용자", + null=True, + blank=True, + ) + image_url = models.URLField(verbose_name="이미지 파일") created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) def __str__(self): - return f"{self.topic.title} - {self.title}" + if self.user: + return f"{self.user}'s Image" + elif self.course: + return f"Course Image for {self.course}" + return "Image" + + def save(self, *args, **kwargs): + + if self.user and not self.image_url: + self.image_url = "images/default_user_image.jpg" + if self.course and not self.image_url: + self.image_url = "images/default_course_image.jpg" + super().save(*args, **kwargs) class Video(models.Model): - topic = models.OneToOneField(Topic, on_delete=models.CASCADE, related_name="videos") - title = models.CharField(max_length=255, verbose_name="비디오 제목") - file = models.FileField(upload_to="videos/", verbose_name="비디오 파일") + topic = models.OneToOneField( + Topic, on_delete=models.CASCADE, related_name="video", null=True, blank=True + ) + course = models.OneToOneField( + Course, on_delete=models.CASCADE, related_name="video", null=True, blank=True + ) + video_url = models.URLField(verbose_name="비디오 파일") created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) def __str__(self): - return f"{self.topic.title} - {self.title}" + return f"{self.id}" + + +class VideoEventData(models.Model): + EVENT_CHOICES = [ + ("pause", "Paused"), + ("ended", "Ended"), + ("leave", "Left Page"), + ] + + user = models.ForeignKey( + CustomUser, + on_delete=models.CASCADE, + related_name="video_event_datas", + verbose_name="시청 기록의 해당 사용자", + ) + + video = models.ForeignKey( + Video, + on_delete=models.CASCADE, + related_name="video_event_datas", + verbose_name="시청 기록의 해당 비디오", + ) + + event_type = models.CharField( + max_length=20, choices=EVENT_CHOICES, verbose_name="이벤트 유형" + ) + duration = models.FloatField(verbose_name="비디오 전체 길이") # 초 단위로 저장 + current_time = models.FloatField(verbose_name="현재 재생 위치") # 초 단위로 저장 + timestamp = models.DateTimeField(auto_now_add=True, verbose_name="이벤트 발생 시간") + + def get_duration_in_minutes(self): + """ + 분과 초로 변환된 영상 재생 시간을 반환합니다. + """ + minutes = int(self.duration // 60) + seconds = int(self.duration % 60) + return f"{minutes}분 {seconds}초" + + def get_current_time_in_minutes(self): + """ + 분과 초로 변환된 현재 재생 시간을 반환합니다. + """ + minutes = int(self.current_time // 60) + seconds = int(self.current_time % 60) + return f"{minutes}분 {seconds}초" + + def __str__(self): + return f"{self.event_type} at {self.get_current_time_in_minutes()}/{self.get_duration_in_minutes()}" diff --git a/materials/serializers.py b/materials/serializers.py new file mode 100644 index 0000000..bf67290 --- /dev/null +++ b/materials/serializers.py @@ -0,0 +1,175 @@ +import cv2 +from PIL import Image as PILImage +from rest_framework import serializers + +from .models import Image, Video, VideoEventData + + +class ImageSerializer(serializers.ModelSerializer): + """ + 이미지 생성(업로드)을 위한 시리얼라이저입니다. + - 형식, 손상 여부에 대해 유효성 검사를 합니다. + """ + + file = serializers.ImageField(write_only=True) + + class Meta: + model = Image + fields = [ + "id", + "image_url", + "created_at", + "updated_at", + "file", + ] + read_only_fields = [ + "id", + "created_at", + "updated_at", + "image_url", + ] + + def validate_file(self, value): + allowed_image_extensions = (".png", ".jpg", ".jpeg") + + if not value.name.endswith(allowed_image_extensions): + raise serializers.ValidationError( + "지원하지 않는 파일 형식입니다. PNG, JPG, JPEG만 가능합니다." + ) + try: + img = PILImage.open(value) + img.verify() + except Exception: + raise serializers.ValidationError("유효한 이미지 파일이 아닙니다.") + return value + + +class VideoSerializer(serializers.ModelSerializer): + + file = serializers.FileField(write_only=True) + + class Meta: + model = Video + fields = [ + "id", + "video_url", + "created_at", + "updated_at", + "file", + ] + read_only_fields = ["id", "created_at", "updated_at", "video_url"] + + def validate_file(self, value): + # 영상 형식과 크기 유효성 검사 + allowed_extensions = ["mp4", "avi", "mov", "wmv"] + max_size = 100 * 1024 * 1024 # 100MB + + if not value.name.split(".")[-1] in allowed_extensions: + raise serializers.ValidationError( + f"허용되지 않는 파일 형식입니다. 다음 형식만 가능합니다: {', '.join(allowed_extensions)}." + ) + if value.size > max_size: + raise serializers.ValidationError( + "파일 크기가 너무 큽니다. 최대 크기는 100MB입니다." + ) + # 영상 손상 여부 검사 + # try: + # cap = cv2.VideoCapture(value) + # if not cap.isOpened(): + # raise serializers.ValidationError( + # "비디오 파일을 열 수 없습니다. 파일이 손상되었을 수 있습니다." + # ) + + # ret, frame = cap.read() + # if not ret: + # raise serializers.ValidationError( + # "비디오 파일을 읽을 수 없습니다. 파일이 손상되었을 수 있습니다." + # ) + + # except Exception as e: + # raise serializers.ValidationError( + # f"비디오 파일 검사 중 오류가 발생했습니다: {str(e)}" + # ) + # finally: + # cap.release() + + return value + + +class VideoEventSerializer(serializers.ModelSerializer): + + video_url = serializers.URLField( + write_only=True + ) # 클라이언트가 보내는 video_url을 받기 위한 필드 + + class Meta: + model = VideoEventData + fields = ["id", "video_url", "duration", "current_time", "event_type"] + read_only_fields = ["id", "timestamp"] + + def validate_duration(self, value): + if value < 0: + raise serializers.ValidationError("올바른 영상 재생시간이 아닙니다.") + return value + + def validate_current_time(self, value): # 필드 이름 수정 + if value < 0: + raise serializers.ValidationError("올바른 영상 현재 재생시간이 아닙니다.") + return value + + def validate_event_type(self, value): + """ + event_type 값이 선택된 EVENT_CHOICES 중 하나인지 확인합니다. + """ + valid_choices = dict(VideoEventData.EVENT_CHOICES).keys() + if value not in valid_choices: + raise serializers.ValidationError(f"{value}는 유효한 이벤트가 아닙니다.") + return value + + def validate(self, data): + if data["duration"] < data["current_time"]: + raise serializers.ValidationError("올바른 영상 재생시간 관계가 아닙니다.") + return data + + def create(self, validated_data): + """ + VideoEventData를 생성합니다. + - VideoEventData를 받기 전에, 먼저 S3에 존재하는 Video 인스턴스인지 확인하고 + - 그러하다면 해당 Video 인스턴스와 관계된 VideoEventData를 생성합니다. + + """ + video_url = validated_data.pop("video_url", None) + + try: + video_instance = Video.objects.get(file=video_url) + except Video.DoesNotExist: + raise serializers.ValidationError( + "해당 URL과 일치하는 영상 파일이 없습니다." + ) + video_event_data = VideoEventData.objects.create( + video=video_instance, **validated_data + ) + + return video_event_data + + +class UserViewEventListSerializer(serializers.ModelSerializer): + duration_in_minutes = serializers.SerializerMethodField() + current_time_in_minutes = serializers.SerializerMethodField() + + class Meta: + model = VideoEventData + fields = [ + "event_type", + "duration", + "current_time", + "timestamp", + "duration_in_minutes", # 분과 초로 변환된 전체 재생시간 + "current_time_in_minutes", # 분과 초로 변환된 현재 시간 + ] + + def get_duration_in_minutes(self, obj): + return obj.get_duration_in_minutes() + + def get_current_time_in_minutes(self, obj): + return obj.get_current_time_in_minutes() diff --git a/materials/urls.py b/materials/urls.py index 68dea8b..486731c 100644 --- a/materials/urls.py +++ b/materials/urls.py @@ -1,6 +1,35 @@ -from django.contrib import admin from django.urls import path +from . import views +app_name = "materials" -urlpatterns = [] \ No newline at end of file +urlpatterns = [ + # 이미지 관련 URL + path("images/upload/", views.ImageCreateView.as_view(), name="image-upload"), + path("images/", views.ImageListCreateView.as_view(), name="image-list-create"), + path( + "images//", + views.ImageRetrieveUpdateDestroyView.as_view(), + name="image-detail", + ), + # 비디오 관련 URL + path("videos/upload/", views.VideoCreateView.as_view(), name="video-upload"), + path("videos/", views.VideoListCreateView.as_view(), name="video-list"), + path( + "videos//", + views.VideoRetrieveUpdateDestroyView.as_view(), + name="video-detail", + ), + # 사용자 비디오 시청 기록 관련 URL + path( + "video-event-data/", + views.VideoEventCreateView.as_view(), + name="video-event-data", + ), + path( + "users//videos//watch-history/", + views.UserVideoEventListView.as_view(), + name="video-event-list", + ), +] diff --git a/materials/views.py b/materials/views.py index 91ea44a..bae2c30 100644 --- a/materials/views.py +++ b/materials/views.py @@ -1,3 +1,252 @@ -from django.shortcuts import render +import io -# Create your views here. +import boto3 +from botocore.exceptions import ClientError +from django.conf import settings +from django.contrib.auth import get_user_model +from django.shortcuts import get_object_or_404 +from PIL import Image as PILImage +from PIL import ImageFilter +from rest_framework import generics, status +from rest_framework.exceptions import PermissionDenied +from rest_framework.parsers import FormParser, MultiPartParser +from rest_framework.response import Response + +from courses.models import Course + +from .models import Image, Video, VideoEventData +from .serializers import ( + ImageSerializer, + UserViewEventListSerializer, + VideoEventSerializer, + VideoSerializer, +) + +User = get_user_model() + + +# 리팩토링할 때 중복 함수 이곳에 작성 +def optimize_image(image_file): + """ + 이미지를 최적화하는 메서드입니다. + - 포맷 변환 + - 리사이징 + - 필터링 + """ + # Pillow를 사용하여 이미지 열기 + img = PILImage.open(image_file) + + # 포맷 변환 + img = img.convert("RGB") + + # 리사이징 + img.thumbnail((800, 600)) + + # 이미지 필터링: 샤프닝 필터 적용 + img = img.filter(ImageFilter.SHARPEN) + + return img + + +class ImageCreateView(generics.CreateAPIView): + # POST 요청: Image 객체를 사용해서 S3에 이미지 파일을 업로드합니다. + + queryset = Image.objects.all() + serializer_class = ImageSerializer + permission_classes = [] + parser_classes = (MultiPartParser, FormParser) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + image_file = request.FILES.get("file") + + if not image_file: + return Response( + {"error": "이미지 파일이 필요합니다."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + optimized_image = optimize_image(image_file) + + try: + # 최적화된 이미지를 임시로 메모리에 저장 + image_io = io.BytesIO() + optimized_image.save(image_io, format="JPEG", quality=85) + image_io.seek(0) + + # S3에 파일 업로드 + s3_client = boto3.client( + "s3", + aws_access_key_id=settings.AWS_ACCESS_KEY_ID, + aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, + region_name=settings.AWS_S3_REGION_NAME, + ) + + # 파일 이름 생성: 사용자 ID와 코스 ID를 포함 + user = request.user + if request.user: + file_name = f"images/user_{user.id}/{image_file.name}" + else: + return Response( + {"error": "유효한 사용자 또는 코스가 필요합니다."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + s3_client.upload_fileobj( + image_io, + settings.AWS_STORAGE_BUCKET_NAME, + file_name, + ExtraArgs={"ContentType": "image/jpeg"}, + ) + + # 업로드된 파일의 URL 생성 + file_url = f"https://{settings.AWS_S3_CUSTOM_DOMAIN}/{file_name}" + + # 시리얼라이저에 전달 후 저장 + image = Image.objects.create(image_url=file_url, author=user) + + return Response( + self.get_serializer(image).data, status=status.HTTP_201_CREATED + ) + except ClientError as e: + return Response( + {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + + +class ImageListCreateView(generics.ListCreateAPIView): + # GET 요청: 이미지 파일 목록을 가져옵니다. + + queryset = Image.objects.all() + serializer_class = ImageSerializer + permission_classes = [] + + def perform_create(self, serializer): + serializer.save(course_id=self.request.data.get("course_id")) + + +class ImageRetrieveUpdateDestroyView(generics.RetrieveUpdateDestroyAPIView): + # GET 요청: 특정 이미지 파일을 조회합니다. + # PUT 요청: 특정 이미지 파일을 변경합니다. + # DELETE 요청: 특정 이미지 파일을 삭제합니다. + + queryset = Image.objects.all() + serializer_class = ImageSerializer + permission_classes = [] + + def check_object_permissions(self, request, obj): + if not request.user.is_staff and obj.course.tutor != request.user: + raise PermissionDenied("접근 권한이 없습니다.") + return super().check_object_permissions(request, obj) + + +class VideoCreateView(generics.CreateAPIView): + # POST 요청: 영상 파일을 업로드합니다. + + queryset = Video.objects.all() + serializer_class = VideoSerializer + permission_classes = [] + parser_classes = (MultiPartParser, FormParser) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + if serializer.is_valid(): + file = request.FILES.get("file") + if not file: + return Response( + {"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST + ) + + max_image_size = 5 * 1024 * 1024 # 5MB + + if file.size > max_image_size: + raise serializer.ValidationError( + "파일 크기는 5MB를 초과할 수 없습니다." + ) + + # S3 클라이언트 설정 + s3_client = boto3.client( + "s3", + aws_access_key_id=settings.AWS_ACCESS_KEY_ID, + aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, + region_name=settings.AWS_S3_REGION_NAME, + ) + + try: + # S3에 파일 업로드 + file_name = f"videos/{file.name}" + s3_client.upload_fileobj( + file, + settings.AWS_STORAGE_BUCKET_NAME, + file_name, + ExtraArgs={"ContentType": file.content_type}, + ) + + # 업로드된 파일의 URL 생성 + file_url = f"https://{settings.AWS_S3_CUSTOM_DOMAIN}/{file_name}" + + # 비디오 객체 생성 및 저장 + video = Video.objects.create(video_url=file_url) + + return Response( + self.get_serializer(video).data, status=status.HTTP_201_CREATED + ) + except ClientError as e: + return Response( + {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class VideoListCreateView(generics.ListCreateAPIView): + # GET 요청: 영상 파일 목록을 조회합니다. + + queryset = Video.objects.all() + serializer_class = VideoSerializer + permission_classes = [] + + def perform_create(self, serializer): + serializer.save(topic_id=self.request.data.get("topic_id")) + + +class VideoRetrieveUpdateDestroyView(generics.RetrieveUpdateDestroyAPIView): + # GET 요청: 특정 영상 파일을 조회합니다. + # PUT 요청: 특정 영상 파일을 변경합니다. + # DELETE 요청: 특정 영상 파일을 삭제합니다. + + queryset = Video.objects.all() + serializer_class = VideoSerializer + permission_classes = [] + + def check_object_permissions(self, request, obj): + if request.method in ["PUT", "PATCH", "DELETE"]: + if not request.user.is_staff and obj.topic.course.tutor != request.user: + raise PermissionDenied("접근 권한이 없습니다.") + return super().check_object_permissions(request, obj) + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + return Response(serializer.data) + + +class VideoEventCreateView(generics.CreateAPIView): + """ + POST 요청을 받아 영상 파일에 대한 이벤트 정보를 저장합니다. + """ + + queryset = VideoEventData.objects.all() + serializer_class = VideoEventSerializer + + +class UserVideoEventListView(generics.ListAPIView): + serializer_class = UserViewEventListSerializer + + def get_queryset(self): + user_id = self.kwargs.get("user_id") + video_id = self.kwargs.get("video_id") + + # 특정 사용자와 특정 비디오에 대한 이벤트 데이터를 필터링 + return VideoEventData.objects.filter(user_id=user_id, video_id=video_id) diff --git a/payments/admin.py b/payments/admin.py index 8c38f3f..d9b63b3 100644 --- a/payments/admin.py +++ b/payments/admin.py @@ -1,3 +1,36 @@ from django.contrib import admin +from .models import Cart, Order, Payment, UserBillingAddress -# Register your models here. + +@admin.register(Cart) +class CartAdmin(admin.ModelAdmin): + list_display = ("user", "get_total_items", "get_total_price", "created_at") + search_fields = ("user__email", "user__nickname") + + +@admin.register(Order) +class OrderAdmin(admin.ModelAdmin): + list_display = ( + "id", + "user", + "order_status", + "get_total_items", + "get_total_price", + "created_at", + ) + list_filter = ("order_status",) + search_fields = ("user__email", "user__nickname") + + +@admin.register(Payment) +class PaymentAdmin(admin.ModelAdmin): + list_display = ("id", "user", "order", "payment_status", "amount", "paid_at") + list_filter = ("payment_status",) + search_fields = ("user__email", "user__nickname", "transaction_id") + + +@admin.register(UserBillingAddress) +class UserBillingAddressAdmin(admin.ModelAdmin): + list_display = ("user", "country", "main_address", "is_default") + list_filter = ("country", "is_default") + search_fields = ("user__email", "user__nickname", "main_address") diff --git a/payments/migrations/0004_alter_cart_options_alter_cartitem_options_and_more.py b/payments/migrations/0004_alter_cart_options_alter_cartitem_options_and_more.py new file mode 100644 index 0000000..a41e088 --- /dev/null +++ b/payments/migrations/0004_alter_cart_options_alter_cartitem_options_and_more.py @@ -0,0 +1,171 @@ +# Generated by Django 5.1.1 on 2024-10-03 08:27 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0005_alter_assignment_options_alter_course_options_and_more'), + ('payments', '0003_alter_cart_user'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AlterModelOptions( + name='cart', + options={'verbose_name': '장바구니'}, + ), + migrations.AlterModelOptions( + name='cartitem', + options={'ordering': ['-created_at'], 'verbose_name': '장바구니 상품', 'verbose_name_plural': '장바구니 상품들'}, + ), + migrations.AddField( + model_name='cartitem', + name='course', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='courses.course', verbose_name='코스'), + ), + migrations.AlterField( + model_name='cart', + name='created_at', + field=models.DateTimeField(auto_now_add=True, verbose_name='생성일'), + ), + migrations.AlterField( + model_name='cart', + name='updated_at', + field=models.DateTimeField(auto_now=True, verbose_name='수정일'), + ), + migrations.AlterField( + model_name='cart', + name='user', + field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='사용자'), + ), + migrations.AlterField( + model_name='cartitem', + name='cart', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='cart_items', to='payments.cart', verbose_name='장바구니'), + ), + migrations.AlterField( + model_name='cartitem', + name='created_at', + field=models.DateTimeField(auto_now_add=True, verbose_name='생성일'), + ), + migrations.AlterField( + model_name='cartitem', + name='curriculum', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='courses.curriculum', verbose_name='커리큘럼'), + ), + migrations.AlterField( + model_name='cartitem', + name='quantity', + field=models.PositiveIntegerField(default=1, verbose_name='수량'), + ), + migrations.AlterField( + model_name='cartitem', + name='updated_at', + field=models.DateTimeField(auto_now=True, verbose_name='수정일'), + ), + migrations.CreateModel( + name='Coupon', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('code', models.CharField(max_length=50, unique=True, verbose_name='코드')), + ('discount_type', models.CharField(choices=[('fixed', '고정 금액 할인'), ('percentage', '정률 할인')], max_length=10, verbose_name='할인 유형')), + ('discount_value', models.PositiveIntegerField(verbose_name='할인 가치')), + ('expiration_date', models.DateTimeField(verbose_name='만료일')), + ('applicable_type', models.CharField(choices=[('curriculum', '커리큘럼'), ('course', '코스')], default='curriculum', max_length=10, verbose_name='적용 유형')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='생성일')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='수정일')), + ('course', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='courses.course', verbose_name='코스')), + ('curriculum', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='courses.curriculum', verbose_name='커리큘럼')), + ], + options={ + 'verbose_name': '쿠폰', + }, + ), + migrations.CreateModel( + name='Order', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('order_status', models.CharField(choices=[('pending', '대기 중'), ('completed', '완료됨'), ('canceled', '취소됨')], default='pending', max_length=10, verbose_name='주문 상태')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='생성일')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='수정일')), + ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='사용자')), + ], + options={ + 'verbose_name': '주문', + 'ordering': ['-created_at'], + }, + ), + migrations.CreateModel( + name='OrderItem', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('quantity', models.PositiveIntegerField(default=1, verbose_name='수량')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='생성일')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='수정일')), + ('course', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='courses.course', verbose_name='코스')), + ('curriculum', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='courses.curriculum', verbose_name='커리큘럼')), + ('order', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='order_items', to='payments.order', verbose_name='주문')), + ], + options={ + 'verbose_name': '주문 상품', + 'verbose_name_plural': '주문 상품들', + 'ordering': ['-created_at'], + }, + ), + migrations.CreateModel( + name='UserBillingAddress', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('country', models.CharField(max_length=50, verbose_name='국가')), + ('main_address', models.CharField(help_text='도로명 주소 또는 지번 주소를 입력하세요.', max_length=255, verbose_name='주소')), + ('detail_address', models.CharField(help_text='아파트 동/호수, 건물 이름과 호수를 입력하세요.(선택 사항)', max_length=255, verbose_name='상세 주소')), + ('postal_code', models.CharField(max_length=50, verbose_name='우편번호')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='생성일')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='수정일')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='사용자')), + ], + options={ + 'verbose_name': '사용자 청구 주소', + }, + ), + migrations.CreateModel( + name='UserBillingInfo', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('payment_method', models.CharField(choices=[('credit_card', '신용카드'), ('kakaopay', '카카오페이'), ('bank_transfer', '계좌 이체')], max_length=20, verbose_name='결제 방식')), + ('stripe_customer_id', models.CharField(blank=True, max_length=100, null=True, verbose_name='Stripe 고객 ID')), + ('stripe_payment_method_id', models.CharField(blank=True, max_length=100, null=True, verbose_name='Stripe 결제 방법 ID')), + ('kakaopay_tid', models.CharField(blank=True, max_length=100, null=True, verbose_name='KakaoPay 거래 ID')), + ('bank_account_number', models.CharField(blank=True, max_length=100, null=True, verbose_name='계좌 번호')), + ('bank_name', models.CharField(blank=True, max_length=100, null=True, verbose_name='은행 이름')), + ('bank_transfer_status', models.CharField(choices=[('pending', '대기 중'), ('completed', '완료')], default='pending', max_length=20, verbose_name='계좌 이체 상태')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='생성일')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='수정일')), + ('billing_address', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='billing_info', to='payments.userbillingaddress', verbose_name='청구 주소')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='사용자')), + ], + options={ + 'verbose_name': '사용자 청구 정보', + 'verbose_name_plural': '사용자 청구 정보들', + 'db_table': 'user_billing_info', + 'ordering': ['-created_at'], + }, + ), + migrations.CreateModel( + name='UserPurchaseHistory', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='생성일')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='수정일')), + ('order', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='purchase_history', to='payments.order', verbose_name='주문')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='사용자')), + ], + options={ + 'verbose_name': '사용자 구매 내역', + }, + ), + ] diff --git a/payments/migrations/0005_alter_order_options_alter_order_user_payment_and_more.py b/payments/migrations/0005_alter_order_options_alter_order_user_payment_and_more.py new file mode 100644 index 0000000..164db27 --- /dev/null +++ b/payments/migrations/0005_alter_order_options_alter_order_user_payment_and_more.py @@ -0,0 +1,44 @@ +# Generated by Django 5.1.1 on 2024-10-04 08:41 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0004_alter_cart_options_alter_cartitem_options_and_more'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AlterModelOptions( + name='order', + options={'ordering': ['-created_at'], 'verbose_name': '주문', 'verbose_name_plural': '주문들'}, + ), + migrations.AlterField( + model_name='order', + name='user', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='사용자'), + ), + migrations.CreateModel( + name='Payment', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('payment_status', models.CharField(choices=[('pending', '대기 중'), ('completed', '완료됨'), ('canceled', '취소됨')], default='pending', max_length=10, verbose_name='결제 상태')), + ('payment_method', models.CharField(choices=[('credit_card', '신용카드'), ('kakaopay', '카카오페이'), ('bank_transfer', '계좌 이체')], default='credit_card', max_length=20, verbose_name='결제 방식')), + ('payment_amount', models.PositiveIntegerField(verbose_name='결제 금액')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='생성일')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='수정일')), + ('order', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to='payments.order', verbose_name='주문')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='사용자')), + ], + options={ + 'verbose_name': '결제', + }, + ), + migrations.DeleteModel( + name='Coupon', + ), + ] diff --git a/payments/migrations/0006_remove_userpurchasehistory_order_and_more.py b/payments/migrations/0006_remove_userpurchasehistory_order_and_more.py new file mode 100644 index 0000000..3636c08 --- /dev/null +++ b/payments/migrations/0006_remove_userpurchasehistory_order_and_more.py @@ -0,0 +1,101 @@ +# Generated by Django 5.1.1 on 2024-10-06 16:09 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0005_alter_assignment_options_alter_course_options_and_more'), + ('payments', '0005_alter_order_options_alter_order_user_payment_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='userpurchasehistory', + name='order', + ), + migrations.RemoveField( + model_name='userpurchasehistory', + name='user', + ), + migrations.AlterModelOptions( + name='userbillingaddress', + options={'verbose_name': '사용자 청구 주소', 'verbose_name_plural': '사용자 청구 주소들'}, + ), + migrations.RenameField( + model_name='payment', + old_name='payment_amount', + new_name='amount', + ), + migrations.RemoveField( + model_name='payment', + name='created_at', + ), + migrations.RemoveField( + model_name='payment', + name='updated_at', + ), + migrations.AddField( + model_name='orderitem', + name='expiry_date', + field=models.DateTimeField(blank=True, null=True, verbose_name='만료일'), + ), + migrations.AddField( + model_name='payment', + name='paid_at', + field=models.DateTimeField(blank=True, null=True, verbose_name='결제 일시'), + ), + migrations.AddField( + model_name='payment', + name='transaction_id', + field=models.CharField(blank=True, max_length=255, null=True, verbose_name='거래 ID'), + ), + migrations.AddField( + model_name='userbillingaddress', + name='is_default', + field=models.BooleanField(default=False, verbose_name='기본 주소'), + ), + migrations.AlterField( + model_name='order', + name='order_status', + field=models.CharField(choices=[('pending', '대기 중'), ('completed', '완료됨'), ('refunded', '환불됨')], default='pending', max_length=10, verbose_name='주문 상태'), + ), + migrations.AlterField( + model_name='orderitem', + name='course', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='courses.course', verbose_name='코스'), + ), + migrations.AlterField( + model_name='orderitem', + name='curriculum', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='courses.curriculum', verbose_name='커리큘럼'), + ), + migrations.AlterField( + model_name='payment', + name='payment_status', + field=models.CharField(choices=[('pending', '대기 중'), ('completed', '완료됨'), ('refunded', '환불됨')], default='pending', max_length=10, verbose_name='결제 상태'), + ), + migrations.AlterField( + model_name='userbillingaddress', + name='detail_address', + field=models.CharField(blank=True, max_length=255, verbose_name='상세 주소'), + ), + migrations.AlterField( + model_name='userbillingaddress', + name='main_address', + field=models.CharField(max_length=255, verbose_name='주소'), + ), + migrations.AlterField( + model_name='userbillingaddress', + name='postal_code', + field=models.CharField(max_length=20, verbose_name='우편번호'), + ), + migrations.DeleteModel( + name='UserBillingInfo', + ), + migrations.DeleteModel( + name='UserPurchaseHistory', + ), + ] diff --git a/payments/migrations/0007_alter_payment_options_remove_payment_payment_method_and_more.py b/payments/migrations/0007_alter_payment_options_remove_payment_payment_method_and_more.py new file mode 100644 index 0000000..e75e173 --- /dev/null +++ b/payments/migrations/0007_alter_payment_options_remove_payment_payment_method_and_more.py @@ -0,0 +1,47 @@ +# Generated by Django 5.1.1 on 2024-10-07 09:25 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0006_remove_userpurchasehistory_order_and_more'), + ] + + operations = [ + migrations.AlterModelOptions( + name='payment', + options={'verbose_name': '결제', 'verbose_name_plural': '결제들'}, + ), + migrations.RemoveField( + model_name='payment', + name='payment_method', + ), + migrations.AddField( + model_name='payment', + name='billing_address', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='payments.userbillingaddress', verbose_name='청구 주소'), + ), + migrations.AddField( + model_name='payment', + name='cancelled_at', + field=models.DateTimeField(blank=True, null=True, verbose_name='취소 일시'), + ), + migrations.AddField( + model_name='payment', + name='fail_reason', + field=models.TextField(blank=True, null=True, verbose_name='실패 사유'), + ), + migrations.AlterField( + model_name='order', + name='order_status', + field=models.CharField(choices=[('pending', '대기 중'), ('completed', '완료됨'), ('failed', '실패함'), ('cancelled', '취소됨')], default='pending', max_length=10, verbose_name='주문 상태'), + ), + migrations.AlterField( + model_name='payment', + name='payment_status', + field=models.CharField(choices=[('pending', '대기 중'), ('completed', '완료됨'), ('failed', '실패함'), ('cancelled', '취소됨')], default='pending', max_length=10, verbose_name='결제 상태'), + ), + ] diff --git a/payments/migrations/0008_alter_cartitem_course_alter_cartitem_curriculum_and_more.py b/payments/migrations/0008_alter_cartitem_course_alter_cartitem_curriculum_and_more.py new file mode 100644 index 0000000..e81c2e7 --- /dev/null +++ b/payments/migrations/0008_alter_cartitem_course_alter_cartitem_curriculum_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 5.1.1 on 2024-10-07 13:17 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('courses', '0005_alter_assignment_options_alter_course_options_and_more'), + ('payments', '0007_alter_payment_options_remove_payment_payment_method_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='cartitem', + name='course', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='courses.course', verbose_name='코스'), + ), + migrations.AlterField( + model_name='cartitem', + name='curriculum', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='courses.curriculum', verbose_name='커리큘럼'), + ), + migrations.AlterField( + model_name='orderitem', + name='course', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='courses.course', verbose_name='코스'), + ), + migrations.AlterField( + model_name='orderitem', + name='curriculum', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='courses.curriculum', verbose_name='커리큘럼'), + ), + ] diff --git a/payments/migrations/0009_alter_payment_options_and_more.py b/payments/migrations/0009_alter_payment_options_and_more.py new file mode 100644 index 0000000..77ba7e8 --- /dev/null +++ b/payments/migrations/0009_alter_payment_options_and_more.py @@ -0,0 +1,31 @@ +# Generated by Django 5.1.1 on 2024-10-07 13:48 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0008_alter_cartitem_course_alter_cartitem_curriculum_and_more'), + ] + + operations = [ + migrations.AlterModelOptions( + name='payment', + options={'ordering': ['-created_at'], 'verbose_name': '결제', 'verbose_name_plural': '결제들'}, + ), + migrations.AlterModelOptions( + name='userbillingaddress', + options={'ordering': ['-created_at'], 'verbose_name': '사용자 청구 주소', 'verbose_name_plural': '사용자 청구 주소들'}, + ), + migrations.AddField( + model_name='payment', + name='created_at', + field=models.DateTimeField(auto_now_add=True, null=True, verbose_name='생성일'), + ), + migrations.AddField( + model_name='payment', + name='updated_at', + field=models.DateTimeField(auto_now=True, verbose_name='수정일'), + ), + ] diff --git a/payments/migrations/0010_alter_payment_billing_address_alter_payment_order.py b/payments/migrations/0010_alter_payment_billing_address_alter_payment_order.py new file mode 100644 index 0000000..485e0be --- /dev/null +++ b/payments/migrations/0010_alter_payment_billing_address_alter_payment_order.py @@ -0,0 +1,24 @@ +# Generated by Django 5.1.1 on 2024-10-09 13:13 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0009_alter_payment_options_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='payment', + name='billing_address', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='payments.userbillingaddress', verbose_name='청구 주소'), + ), + migrations.AlterField( + model_name='payment', + name='order', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='payments.order', verbose_name='주문'), + ), + ] diff --git a/payments/migrations/0011_alter_order_order_status_and_more.py b/payments/migrations/0011_alter_order_order_status_and_more.py new file mode 100644 index 0000000..6bebdf9 --- /dev/null +++ b/payments/migrations/0011_alter_order_order_status_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 5.1.1 on 2024-10-10 04:53 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0010_alter_payment_billing_address_alter_payment_order'), + ] + + operations = [ + migrations.AlterField( + model_name='order', + name='order_status', + field=models.CharField(choices=[('pending', '대기 중'), ('completed', '완료됨'), ('failed', '실패함'), ('cancelled', '취소됨'), ('refunded', '환불됨')], default='pending', max_length=10, verbose_name='주문 상태'), + ), + migrations.AlterField( + model_name='payment', + name='payment_status', + field=models.CharField(choices=[('pending', '대기 중'), ('completed', '완료됨'), ('failed', '실패함'), ('cancelled', '취소됨'), ('refunded', '환불됨')], default='pending', max_length=10, verbose_name='결제 상태'), + ), + ] diff --git a/payments/migrations/0012_alter_payment_order_alter_payment_user.py b/payments/migrations/0012_alter_payment_order_alter_payment_user.py new file mode 100644 index 0000000..3412c5d --- /dev/null +++ b/payments/migrations/0012_alter_payment_order_alter_payment_user.py @@ -0,0 +1,26 @@ +# Generated by Django 5.1.1 on 2024-10-10 05:21 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0011_alter_order_order_status_and_more'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AlterField( + model_name='payment', + name='order', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='payments', to='payments.order', verbose_name='주문'), + ), + migrations.AlterField( + model_name='payment', + name='user', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='payments', to=settings.AUTH_USER_MODEL, verbose_name='사용자'), + ), + ] diff --git a/payments/migrations/0013_payment_metadata_and_more.py b/payments/migrations/0013_payment_metadata_and_more.py new file mode 100644 index 0000000..c6f727b --- /dev/null +++ b/payments/migrations/0013_payment_metadata_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 5.1.1 on 2024-10-11 02:32 + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0012_alter_payment_order_alter_payment_user'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name='payment', + name='metadata', + field=models.JSONField(blank=True, default=dict, verbose_name='메타데이터'), + ), + migrations.AddIndex( + model_name='payment', + index=models.Index(fields=['order', 'payment_status'], name='payments_pa_order_i_01ef4c_idx'), + ), + migrations.AddIndex( + model_name='payment', + index=models.Index(fields=['transaction_id'], name='payments_pa_transac_8e9d99_idx'), + ), + ] diff --git a/payments/migrations/0014_alter_cart_options_remove_payment_fail_reason.py b/payments/migrations/0014_alter_cart_options_remove_payment_fail_reason.py new file mode 100644 index 0000000..300a9a5 --- /dev/null +++ b/payments/migrations/0014_alter_cart_options_remove_payment_fail_reason.py @@ -0,0 +1,21 @@ +# Generated by Django 5.1.1 on 2024-10-12 05:24 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0013_payment_metadata_and_more'), + ] + + operations = [ + migrations.AlterModelOptions( + name='cart', + options={'ordering': ['-created_at'], 'verbose_name': '장바구니', 'verbose_name_plural': '장바구니들'}, + ), + migrations.RemoveField( + model_name='payment', + name='fail_reason', + ), + ] diff --git a/payments/migrations/0015_alter_cart_options_alter_order_options_and_more.py b/payments/migrations/0015_alter_cart_options_alter_order_options_and_more.py new file mode 100644 index 0000000..61c0cee --- /dev/null +++ b/payments/migrations/0015_alter_cart_options_alter_order_options_and_more.py @@ -0,0 +1,29 @@ +# Generated by Django 5.1.1 on 2024-10-12 05:31 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0014_alter_cart_options_remove_payment_fail_reason'), + ] + + operations = [ + migrations.AlterModelOptions( + name='cart', + options={'ordering': ['-created_at'], 'verbose_name': '장바구니', 'verbose_name_plural': '장바구니 목록'}, + ), + migrations.AlterModelOptions( + name='order', + options={'ordering': ['-created_at'], 'verbose_name': '주문', 'verbose_name_plural': '주문 목록'}, + ), + migrations.AlterModelOptions( + name='payment', + options={'ordering': ['-created_at'], 'verbose_name': '결제', 'verbose_name_plural': '결제 목록'}, + ), + migrations.AlterModelOptions( + name='userbillingaddress', + options={'ordering': ['-created_at'], 'verbose_name': '사용자 청구 주소', 'verbose_name_plural': '사용자 청구 주소 목록'}, + ), + ] diff --git a/payments/mixins.py b/payments/mixins.py new file mode 100644 index 0000000..63553f5 --- /dev/null +++ b/payments/mixins.py @@ -0,0 +1,306 @@ +from django.db import transaction +from django.utils import timezone +from rest_framework import status +from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError +from rest_framework.response import Response + +from .models import Cart, CartItem, Order, Payment, UserBillingAddress +from .services import KakaoPayService + + +class GetObjectMixin: + def get_object_or_404(self, queryset, *filter_args, **filter_kwargs): + try: + obj = queryset.get(*filter_args, **filter_kwargs) + if hasattr(obj, "user") and obj.user != self.request.user: + raise PermissionDenied("이 객체에 접근할 권한이 없습니다.") + return obj + except queryset.model.DoesNotExist: + verbose_name = getattr( + queryset.model._meta, "verbose_name", queryset.model.__name__ + ) + raise NotFound(detail=f"{verbose_name}을(를) 찾을 수 없습니다.") + + +class CartMixin(GetObjectMixin): + def get_cart(self, user): + cart, _ = Cart.objects.get_or_create(user=user) + return cart + + def get_cart_item(self, cart, **kwargs): + return self.get_object_or_404(CartItem.objects.filter(cart=cart), **kwargs) + + @transaction.atomic + def add_to_cart(self, cart, serializer): + existing_item = CartItem.objects.filter( + cart=cart, + curriculum=serializer.validated_data.get("curriculum"), + course=serializer.validated_data.get("course"), + ).first() + + if existing_item: + return Response( + {"detail": "이 상품은 이미 장바구니에 있습니다."}, + status=status.HTTP_400_BAD_REQUEST, + ) + else: + serializer.save(cart=cart) + return Response( + { + "detail": "상품이 장바구니에 추가되었습니다.", + "data": serializer.data, + }, + status=status.HTTP_201_CREATED, + ) + + def remove_from_cart(self, cart_item): + cart_item.delete() + return Response( + {"detail": "상품이 장바구니에서 삭제되었습니다."}, + status=status.HTTP_204_NO_CONTENT, + ) + + +class OrderMixin(GetObjectMixin): + def get_order(self, user, **kwargs): + return self.get_object_or_404(Order.objects.filter(user=user), **kwargs) + + def create_order_from_cart(self, user, cart): + if not cart.cart_items.exists(): + raise ValidationError("장바구니가 비어있습니다.") + + if cart.get_total_price() > 50000: + raise ValidationError("상품의 총 가격이 50,000원을 초과할 수 없습니다.") + + order_items = [ + { + "curriculum_id": item.curriculum.id if item.curriculum else None, + "course_id": item.course.id if item.course else None, + "quantity": item.quantity, + "price": item.get_price(), + } + for item in cart.cart_items.all() + ] + + return { + "user_id": user.id, + "order_status": "pending", + "order_items": order_items, + } + + def create_new_order(self, user, order_data): + if not order_data.get("order_items"): + raise ValidationError("주문 항목이 없습니다.") + + return { + "user_id": user.id, + "order_status": "pending", + "order_items": order_data.get("order_items", []), + } + + +class UserBillingAddressMixin(GetObjectMixin): + def get_billing_address(self, user, **kwargs): + return self.get_object_or_404( + UserBillingAddress.objects.filter(user=user), **kwargs + ) + + def create_billing_address(self, user, serializer): + instance = serializer.save(user=user, is_default=True) + UserBillingAddress.objects.filter(user=user, is_default=True).exclude( + pk=instance.pk + ).update(is_default=False) + return Response( + {"detail": "청구 주소가 생성되었습니다."}, status=status.HTTP_201_CREATED + ) + + def update_billing_address(self, instance, serializer): + serializer.save() + return Response( + {"detail": "청구 주소가 수정되었습니다.", "data": serializer.data}, + status=status.HTTP_200_OK, + ) + + def delete_billing_address(self, instance): + instance.delete() + return Response( + {"detail": "청구 주소가 삭제되었습니다."}, + status=status.HTTP_204_NO_CONTENT, + ) + + +class PaymentMixin(GetObjectMixin): + kakao_pay_service = KakaoPayService() + + def get_payment(self, user, select_for_update=False, **kwargs): + queryset = Payment.objects.filter(user=user) + if select_for_update: + queryset = queryset.select_for_update() + return self.get_object_or_404(queryset, **kwargs) + + def validate_order(self, order): + if order.order_status != "pending": + raise ValidationError("결제 가능한 상태의 주문이 아닙니다.") + if order.get_total_price() > 50000: + raise ValidationError("결제 금액이 50,000원을 초과할 수 없습니다.") + + def create_payment(self, order, user): + self.validate_order(order) + + existing_payments = Payment.objects.filter( + order=order, payment_status="pending" + ) + if existing_payments.exists(): + # 모든 기존 pending payment를 취소 처리 + existing_payments.update(payment_status="cancelled") + + try: + kakao_response = self.kakao_pay_service.request_payment(order) + except Exception as e: + raise ValidationError( + "결제 요청 중 오류가 발생했습니다. 잠시 후 다시 시도해 주세요." + ) + + billing_address = UserBillingAddress.objects.filter( + user=user, is_default=True + ).first() + + payment = Payment.objects.create( + order=order, + user=user, + payment_status="pending", + amount=order.get_total_price(), + transaction_id=kakao_response["tid"], + billing_address=billing_address, + ) + + return payment, kakao_response + + def process_payment(self, order, payment, pg_token): + try: + self.kakao_pay_service.approve_payment(payment, pg_token) + except Exception as e: + payment.payment_status = "failed" + payment.save() + raise ValidationError( + "결제 승인 중 오류가 발생했습니다. 고객센터로 문의해 주세요." + ) + + payment.payment_status = "completed" + payment.paid_at = timezone.now() + payment.save() + order.order_status = "completed" + order.save() + + for order_item in order.order_items.all(): + order_item.save() + + def cancel_payment(self, order, payment): + payment.payment_status = "cancelled" + payment.cancelled_at = timezone.now() + payment.save() + order.order_status = "cancelled" + order.save() + + def fail_payment(self, payment): + payment.payment_status = "failed" + payment.save() + + def refund_payment(self, order, payment): + if order.order_status != "completed": + raise ValidationError("완료된 주문만 취소할 수 있습니다.") + + if not payment or payment.payment_status != "completed": + raise ValidationError("해당 주문에 대한 완료된 결제를 찾을 수 없습니다.") + + if payment.paid_at is None: + raise ValidationError("결제 완료 시간이 기록되지 않았습니다.") + + if timezone.now() - payment.paid_at > timezone.timedelta(days=7): + raise ValidationError("결제 후 7일이 지난 주문은 환불할 수 없습니다.") + + try: + self.kakao_pay_service.refund_payment(payment) + except Exception as e: + raise ValidationError( + "결제 취소 중 오류가 발생했습니다. 고객센터로 문의해 주세요." + ) + + payment.payment_status = "refunded" + payment.cancelled_at = timezone.now() + payment.save() + order.order_status = "refunded" + order.save() + + for order_item in order.order_items.all(): + order_item.expiry_date = None + order_item.save() + + +class ReceiptMixin(GetObjectMixin): + def get_receipt_list(self, user): + payments = Payment.objects.filter( + user=user, payment_status__in=["completed", "refunded"] + ).order_by("-paid_at") + + receipt_list = [ + { + "receipt_number": f"REC-{payment.id}", + "payment_status": payment.payment_status, + "amount": payment.amount, + "paid_at": ( + payment.paid_at.strftime("%Y-%m-%d %H:%M:%S") + if payment.paid_at + else None + ), + "order_id": payment.order.id, + } + for payment in payments + ] + + return receipt_list + + def get_receipt_detail(self, payment, user): + order = payment.order + billing_address = payment.billing_address + + receipt_data = { + "receipt_number": f"REC-{payment.id}", + "issue_date": timezone.now().strftime("%Y-%m-%d %H:%M:%S"), + "payment_info": { + "payment_id": payment.id, + "amount": payment.amount, + "payment_status": payment.payment_status, + "paid_at": ( + payment.paid_at.strftime("%Y-%m-%d %H:%M:%S") + if payment.paid_at + else None + ), + }, + "order_info": { + "order_id": order.id, + "order_status": order.order_status, + "total_items": order.get_total_items(), + "total_price": order.get_total_price(), + "items": [ + { + "name": item.get_item_name(), + "quantity": item.quantity, + "price": item.get_price(), + } + for item in order.order_items.all() + ], + }, + "customer_info": {"email": user.email}, + "billing_address": None, + } + + if billing_address: + receipt_data["billing_address"] = { + "country": billing_address.country, + "main_address": billing_address.main_address, + "detail_address": billing_address.detail_address, + "postal_code": billing_address.postal_code, + } + + return receipt_data diff --git a/payments/models.py b/payments/models.py index 802aa00..f9a97ab 100644 --- a/payments/models.py +++ b/payments/models.py @@ -1,4 +1,5 @@ from django.db import models +from django.utils import timezone from django.conf import settings from courses.models import Curriculum, Course @@ -6,15 +7,48 @@ class Cart(models.Model): """ - 유저의 장바구니 모델입니다. + 사용자의 장바구니 모델입니다. """ - user = models.OneToOneField(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) - created_at = models.DateTimeField(auto_now_add=True) - updated_at = models.DateTimeField(auto_now=True) + user = models.OneToOneField( + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, verbose_name="사용자" + ) + created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성일") + updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") + + def get_total_items(self): + return ( + self.cart_items.aggregate(total_items=models.Sum("quantity"))["total_items"] + or 0 + ) + + def get_total_price(self): + return ( + self.cart_items.aggregate( + total_price=models.Sum( + models.F("quantity") + * models.Case( + models.When( + curriculum__isnull=False, then=models.F("curriculum__price") + ), + models.When( + course__isnull=False, then=models.F("course__price") + ), + default=0, + output_field=models.DecimalField(), + ) + ) + )["total_price"] + or 0 + ) + + class Meta: + ordering = ["-created_at"] + verbose_name = "장바구니" + verbose_name_plural = "장바구니 목록" def __str__(self): - return f"{self.user.email}의 장바구니" + return f"{self.user.nickname}의 장바구니" class CartItem(models.Model): @@ -22,13 +56,302 @@ class CartItem(models.Model): 장바구니에 담긴 상품 모델입니다. """ - cart = models.ForeignKey(Cart, on_delete=models.CASCADE, related_name="cart_items") - curriculum = models.OneToOneField( - Curriculum, on_delete=models.CASCADE, null=True, blank=True + cart = models.ForeignKey( + Cart, + on_delete=models.CASCADE, + related_name="cart_items", + verbose_name="장바구니", + ) + curriculum = models.ForeignKey( + Curriculum, + on_delete=models.CASCADE, + null=True, + blank=True, + verbose_name="커리큘럼", + ) + course = models.ForeignKey( + Course, on_delete=models.CASCADE, null=True, blank=True, verbose_name="코스" + ) + quantity = models.PositiveIntegerField(default=1, verbose_name="수량") + created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성일") + updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") + + def get_item_name(self): + if self.curriculum: + return self.curriculum.name + elif self.course: + return self.course.title + else: + return "알 수 없는 상품" + + def get_price(self): + unit_price = 0 + if self.curriculum: + unit_price = self.curriculum.price + elif self.course: + unit_price = self.course.price + return unit_price * self.quantity + + def get_image_url(self): + if self.curriculum and hasattr(self.curriculum, "images"): + return self.curriculum.images.url + elif self.course and hasattr(self.course, "images"): + return self.course.images.url + return None + + class Meta: + ordering = ["-created_at"] + verbose_name = "장바구니 상품" + verbose_name_plural = "장바구니 상품들" + + def __str__(self): + try: + user_nickname = self.cart.user.nickname + item_name = ( + self.curriculum.name + if self.curriculum + else self.course.title if self.course else "알 수 없는 상품" + ) + return f"{user_nickname}의 장바구니에 있는 {item_name}" + except Exception as e: + return f"장바구니 상품 (ID: {self.id})" + + +class Order(models.Model): + """ + 주문 모델입니다. + """ + + class Status(models.TextChoices): + PENDING = "pending", "대기 중" + COMPLETED = "completed", "완료됨" + FAILED = "failed", "실패함" + CANCELLED = "cancelled", "취소됨" + REFUNDED = "refunded", "환불됨" + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, verbose_name="사용자" + ) + order_status = models.CharField( + max_length=10, + choices=Status.choices, + default=Status.PENDING, + verbose_name="주문 상태", + ) + created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성일") + updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") + + def get_total_items(self): + return ( + self.order_items.aggregate(total_items=models.Sum("quantity"))[ + "total_items" + ] + or 0 + ) + + def get_total_price(self): + return ( + self.order_items.aggregate( + total_price=models.Sum( + models.F("quantity") + * models.Case( + models.When( + curriculum__isnull=False, then=models.F("curriculum__price") + ), + models.When( + course__isnull=False, then=models.F("course__price") + ), + default=0, + output_field=models.DecimalField(), + ) + ) + )["total_price"] + or 0 + ) + + class Meta: + ordering = ["-created_at"] + verbose_name = "주문" + verbose_name_plural = "주문 목록" + + def __str__(self): + return f"{self.user.nickname}의 주문" + + +class OrderItem(models.Model): + """ + 주문 상품 모델입니다. + """ + + order = models.ForeignKey( + Order, on_delete=models.CASCADE, related_name="order_items", verbose_name="주문" + ) + curriculum = models.ForeignKey( + "courses.Curriculum", + on_delete=models.SET_NULL, + null=True, + blank=True, + verbose_name="커리큘럼", + ) + course = models.ForeignKey( + "courses.Course", + on_delete=models.SET_NULL, + null=True, + blank=True, + verbose_name="코스", + ) + quantity = models.PositiveIntegerField(default=1, verbose_name="수량") + created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성일") + updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") + expiry_date = models.DateTimeField(null=True, blank=True, verbose_name="만료일") + + def get_item_name(self): + if self.curriculum: + return self.curriculum.name + elif self.course: + return self.course.title + else: + return "알 수 없는 상품" + + def get_price(self): + unit_price = 0 + if self.curriculum: + unit_price = self.curriculum.price + elif self.course: + unit_price = self.course.price + return unit_price * self.quantity + + def get_image_url(self): + if self.curriculum and hasattr(self.curriculum, "images"): + return self.curriculum.images.url + elif self.course and hasattr(self.course, "images"): + return self.course.images.url + return None + + def set_expiry_date(self): + self.expiry_date = timezone.now() + timezone.timedelta(days=730) # 2년 + + def save(self, *args, **kwargs): + if self.order.order_status == "completed" and ( + not self.expiry_date or self.expiry_date < timezone.now() + ): + self.set_expiry_date() + super().save(*args, **kwargs) + + class Meta: + ordering = ["-created_at"] + verbose_name = "주문 상품" + verbose_name_plural = "주문 상품들" + + def __str__(self): + try: + user_nickname = self.order.user.nickname + item_name = ( + self.curriculum.name + if self.curriculum + else self.course.title if self.course else "알 수 없는 상품" + ) + return f"{user_nickname}의 주문에 있는 {item_name}" + except Exception as e: + return f"주문 상품 (ID: {self.id})" + + +class UserBillingAddress(models.Model): + """ + 사용자 청구 주소 모델입니다. + """ + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, verbose_name="사용자" + ) + country = models.CharField(max_length=50, verbose_name="국가") + main_address = models.CharField(max_length=255, verbose_name="주소") + detail_address = models.CharField( + max_length=255, verbose_name="상세 주소", blank=True ) - quantity = models.PositiveIntegerField(default=1) - created_at = models.DateTimeField(auto_now_add=True) - updated_at = models.DateTimeField(auto_now=True) + postal_code = models.CharField(max_length=20, verbose_name="우편번호") + is_default = models.BooleanField(default=False, verbose_name="기본 주소") + created_at = models.DateTimeField(auto_now_add=True, verbose_name="생성일") + updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") + + class Meta: + ordering = ["-created_at"] + verbose_name = "사용자 청구 주소" + verbose_name_plural = "사용자 청구 주소 목록" def __str__(self): - return f"{self.cart.user.email}의 장바구니에 담긴 {self.curriculum.name}" + return f"{self.user.nickname}의 청구 주소" + + def save(self, *args, **kwargs): + # 기본 주소가 설정되어 있으면, 사용자의 다른 기본 주소를 해제합니다. + if self.is_default: + UserBillingAddress.objects.filter(user=self.user, is_default=True).update( + is_default=False + ) + super().save(*args, **kwargs) + + +class Payment(models.Model): + """ + 결제 모델입니다. + """ + + class Status(models.TextChoices): + PENDING = "pending", "대기 중" + COMPLETED = "completed", "완료됨" + FAILED = "failed", "실패함" + CANCELLED = "cancelled", "취소됨" + REFUNDED = "refunded", "환불됨" + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + related_name="payments", + verbose_name="사용자", + ) + order = models.ForeignKey( + Order, on_delete=models.CASCADE, related_name="payments", verbose_name="주문" + ) + payment_status = models.CharField( + max_length=10, + choices=Status.choices, + default=Status.PENDING, + verbose_name="결제 상태", + ) + amount = models.PositiveIntegerField(verbose_name="결제 금액") + transaction_id = models.CharField( + max_length=255, verbose_name="거래 ID", blank=True, null=True + ) + created_at = models.DateTimeField( + auto_now_add=True, blank=True, null=True, verbose_name="생성일" + ) + updated_at = models.DateTimeField(auto_now=True, verbose_name="수정일") + paid_at = models.DateTimeField(null=True, blank=True, verbose_name="결제 일시") + cancelled_at = models.DateTimeField(null=True, blank=True, verbose_name="취소 일시") + billing_address = models.ForeignKey( + UserBillingAddress, + on_delete=models.SET_NULL, + null=True, + blank=True, + verbose_name="청구 주소", + ) + metadata = models.JSONField(default=dict, blank=True, verbose_name="메타데이터") + + class Meta: + ordering = ["-created_at"] + verbose_name = "결제" + verbose_name_plural = "결제 목록" + indexes = [ + models.Index(fields=["order", "payment_status"]), + models.Index(fields=["transaction_id"]), + ] + + def __str__(self): + return f"{self.user.nickname}의 결제 ({self.get_payment_status_display()})" + + def save(self, *args, **kwargs): + if not self.billing_address: + self.billing_address = UserBillingAddress.objects.filter( + user=self.user, is_default=True + ).first() + super().save(*args, **kwargs) diff --git a/payments/permissions.py b/payments/permissions.py new file mode 100644 index 0000000..5a64b13 --- /dev/null +++ b/payments/permissions.py @@ -0,0 +1,11 @@ +from rest_framework import permissions + + +class IsOwnerPermission(permissions.BasePermission): + def has_permission(self, request, view): + # 목록 조회나 생성 등의 작업에 대한 권한 검사 + return request.user.is_authenticated + + def has_object_permission(self, request, view, obj): + # 객체에 대한 권한 검사 + return obj.user == request.user diff --git a/payments/serializers.py b/payments/serializers.py index f9d5666..90b297e 100644 --- a/payments/serializers.py +++ b/payments/serializers.py @@ -1,34 +1,194 @@ from rest_framework import serializers -from .models import Cart, CartItem +from .models import Cart, CartItem, Order, OrderItem, Payment, UserBillingAddress class CartItemSerializer(serializers.ModelSerializer): """ - 상품 모델의 시리얼라이저입니다. 커리큘럼 이름을 포함합니다. + 장바구니에 담긴 상품 모델의 시리얼라이저입니다. """ - curriculum_name = serializers.SerializerMethodField() - class Meta: model = CartItem - fields = "__all__" + fields = [ + "id", + "cart", + "curriculum", + "course", + "quantity", + "created_at", + "updated_at", + "get_item_name", + "get_price", + "get_image_url", + ] + read_only_fields = [ + "id", + "cart", + "quantity", + "created_at", + "updated_at", + ] + + def validate(self, data): + # 커리큘럼과 코스 중 하나만 선택되었는지 확인합니다. + curriculum = data.get("curriculum") + course = data.get("course") + if not curriculum and not course: + raise serializers.ValidationError( + "커리큘럼 또는 코스 중 하나를 선택해야 합니다." + ) + if curriculum and course: + raise serializers.ValidationError( + "커리큘럼과 코스 중 하나만 선택해야 합니다." + ) + return data class CartSerializer(serializers.ModelSerializer): """ - 장바구니 모델의 시리얼라이저입니다. 상품의 수량 필드를 포함합니다. + 장바구니 모델의 시리얼라이저입니다. 총 상품 수량과 가격을 계산합니다. """ - items = CartItemSerializer(many=True, read_only=True) - total_items = serializers.SerializerMethodField() + cart_items = CartItemSerializer(many=True, read_only=True) class Meta: model = Cart - fields = "__all__" + fields = [ + "id", + "user", + "cart_items", + "get_total_items", + "get_total_price", + "created_at", + "updated_at", + ] + read_only_fields = [ + "id", + "created_at", + "updated_at", + ] + + +class OrderItemSerializer(serializers.ModelSerializer): + """ + 주문 상품 모델의 시리얼라이저입니다. + """ + + class Meta: + model = OrderItem + fields = [ + "id", + "order", + "curriculum", + "course", + "quantity", + "created_at", + "updated_at", + "expiry_date", + "get_item_name", + "get_price", + "get_image_url", + ] + read_only_fields = [ + "id", + "quantity", + "created_at", + "updated_at", + ] + + def validate(self, data): + # 커리큘럼과 코스 중 하나만 선택되었는지 확인합니다. + curriculum = data.get("curriculum") + course = data.get("course") + if not curriculum and not course: + raise serializers.ValidationError( + "커리큘럼 또는 코스 중 하나를 선택해야 합니다." + ) + if curriculum and course: + raise serializers.ValidationError( + "커리큘럼과 코스 중 하나만 선택해야 합니다." + ) + return data + + +class OrderSerializer(serializers.ModelSerializer): + """ + 주문 모델의 시리얼라이저입니다. 총 상품 수량과 가격을 계산합니다. + """ + + order_items = OrderItemSerializer(many=True, read_only=True) - def get_total_items(self, obj): - """ - 장바구니에 담긴 아이템의 총 개수를 반환합니다. - """ - return sum(item.quantity for item in obj.items.all()) + class Meta: + model = Order + fields = [ + "id", + "user", + "order_items", + "order_status", + "created_at", + "updated_at", + "get_total_items", + "get_total_price", + ] + read_only_fields = [ + "id", + "created_at", + "updated_at", + ] + + +class UserBillingAddressSerializer(serializers.ModelSerializer): + """ + 사용자의 결제 수단 모델의 시리얼라이저입니다. + """ + + class Meta: + model = UserBillingAddress + fields = [ + "id", + "user", + "country", + "main_address", + "detail_address", + "postal_code", + "is_default", + "created_at", + "updated_at", + ] + read_only_fields = ["id", "user", "created_at", "updated_at"] + + +class PaymentSerializer(serializers.ModelSerializer): + """ + 결제 모델의 시리얼라이저입니다. + """ + + class Meta: + model = Payment + fields = [ + "id", + "user", + "order", + "payment_status", + "amount", + "transaction_id", + "created_at", + "updated_at", + "paid_at", + "cancelled_at", + "billing_address", + ] + read_only_fields = [ + "id", + "user", + "order", + "payment_status", + "payment_method", + "amount", + "transaction_id", + "created_at", + "updated_at", + "paid_at", + "cancelled_at", + ] diff --git a/payments/services.py b/payments/services.py new file mode 100644 index 0000000..e62ca8b --- /dev/null +++ b/payments/services.py @@ -0,0 +1,82 @@ +import requests +from django.conf import settings + + +class KakaoPayService: + """ + 카카오페이 결제 서비스를 처리하는 클래스입니다. + """ + + def request_payment(self, order): + """ + 주어진 주문에 대해 카카오페이 결제 요청을 보냅니다. + """ + url = "https://open-api.kakaopay.com/online/v1/payment/ready" + headers = self._get_headers() + base_url = settings.BASE_URL.strip("'").split("#")[0].strip() + + payment_request = { + "cid": settings.KAKAOPAY_CID, + "partner_order_id": str(order.id), + "partner_user_id": str(order.user.id), + "item_name": f"Order #{order.id}", + "quantity": order.get_total_items(), + "total_amount": order.get_total_price(), + "tax_free_amount": 0, + "approval_url": f"{base_url}/api/payments/{order.id}/?result=success", + "cancel_url": f"{base_url}/api/payments/{order.id}/?result=cancel", + "fail_url": f"{base_url}/api/payments/{order.id}/?result=fail", + } + + response = requests.post(url, json=payment_request, headers=headers) + if response.status_code != 200: + raise Exception(f"카카오페이 결제 요청 실패: {response.text}") + + return response.json() + + def approve_payment(self, payment, pg_token): + """ + 주어진 주문에 대해 카카오페이 결제를 승인합니다. + """ + url = "https://open-api.kakaopay.com/online/v1/payment/approve" + headers = self._get_headers() + + approval_request = { + "cid": settings.KAKAOPAY_CID, + "tid": payment.transaction_id, + "partner_order_id": str(payment.order.id), + "partner_user_id": str(payment.user.id), + "pg_token": pg_token, + } + + response = requests.post(url, json=approval_request, headers=headers) + if response.status_code != 200: + raise Exception("카카오페이 결제 승인 실패") + + return response.json() + + def refund_payment(self, payment): + """ + 주어진 주문에 대해 카카오페이 결제를 환불합니다. + """ + url = "https://open-api.kakaopay.com/online/v1/payment/cancel" + headers = self._get_headers() + + refund_request = { + "cid": settings.KAKAOPAY_CID, + "tid": payment.transaction_id, + "cancel_amount": payment.amount, + "cancel_tax_free_amount": 0, + } + + response = requests.post(url, json=refund_request, headers=headers) + if response.status_code != 200: + raise Exception("카카오페이 결제 환불 실패") + + return response.json() + + def _get_headers(self): + return { + "Authorization": f"SECRET_KEY {settings.KAKAOPAY_SECRET_KEY}", + "Content-Type": "application/json", + } diff --git a/payments/tests.py b/payments/tests.py deleted file mode 100644 index 7ce503c..0000000 --- a/payments/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/payments/tests/conftest.py b/payments/tests/conftest.py new file mode 100644 index 0000000..a3a950f --- /dev/null +++ b/payments/tests/conftest.py @@ -0,0 +1,164 @@ +from unittest.mock import MagicMock + +import pytest +from django.conf import settings +from django.contrib.auth import get_user_model +from django.utils import timezone +from rest_framework.test import APIClient + +from courses.models import Course, Curriculum +from payments.models import ( + Cart, + CartItem, + Order, + OrderItem, + Payment, + UserBillingAddress, +) +from payments.services import KakaoPayService + + +@pytest.fixture +def api_client(): + return APIClient() + + +@pytest.fixture +def user(): + User = get_user_model() + return User.objects.create_user( + email="test@example.com", password="testpass123", nickname="testnick" + ) + + +@pytest.fixture +def staff_user(): + User = get_user_model() + return User.objects.create_user( + email="staff@example.com", + password="staffpass123", + is_staff=True, + nickname="staffnick", + ) + + +@pytest.fixture +def cart(user): + return Cart.objects.create(user=user) + + +@pytest.fixture +def course(staff_user): + return Course.objects.create( + title="Test Course", + author=staff_user, + price=10000, + description="This is a test course description", + ) + + +@pytest.fixture +def curriculum(staff_user): + return Curriculum.objects.create( + name="Test Curriculum", price=20000, author=staff_user + ) + + +@pytest.fixture +def cart_item(cart, course): + return CartItem.objects.create(cart=cart, course=course, quantity=1) + + +@pytest.fixture +def order(user): + return Order.objects.create(user=user, order_status="pending") + + +@pytest.fixture +def completed_order(user): + return Order.objects.create(user=user, order_status="completed") + + +@pytest.fixture +def order_item(order, course): + return OrderItem.objects.create(order=order, course=course, quantity=1) + + +@pytest.fixture +def user_billing_address(user): + return UserBillingAddress.objects.create( + user=user, + country="대한민국", + main_address="서울특별시 강남구", + detail_address="테헤란로 123", + postal_code="06234", + is_default=True, + ) + + +@pytest.fixture +def non_default_billing_address(db, user): + return UserBillingAddress.objects.create( + user=user, + country="KR", + main_address="부산시", + detail_address="해운대구", + postal_code="48099", + is_default=False, + ) + + +@pytest.fixture +def payment(user, order, user_billing_address): + return Payment.objects.create( + user=user, + order=order, + payment_status="pending", + amount=10000, + transaction_id="test_transaction", + billing_address=user_billing_address, + ) + + +@pytest.fixture +def completed_payment(user, completed_order, user_billing_address): + return Payment.objects.create( + user=user, + order=completed_order, + payment_status="completed", + amount=10000, + transaction_id="test_transaction", + billing_address=user_billing_address, + ) + + +@pytest.fixture +def completed_payment_with_time(user, completed_order, user_billing_address): + return Payment.objects.create( + user=user, + order=completed_order, + payment_status="completed", + amount=10000, + transaction_id="test_transaction_with_time", + billing_address=user_billing_address, + paid_at=timezone.now(), + ) + + +@pytest.fixture +def mock_kakao_pay_settings(settings): + settings.KAKAOPAY_SECRET_KEY = "test_secret_key" + settings.KAKAOPAY_CID = "test_cid" + settings.BASE_URL = "http://testserver" + + +@pytest.fixture +def mock_kakao_pay_service(): + return MagicMock(spec=KakaoPayService) + + +@pytest.fixture +def mock_request(): + request = MagicMock() + request.user = user() + return request diff --git a/payments/tests/test_payments_mixins_and_services.py b/payments/tests/test_payments_mixins_and_services.py new file mode 100644 index 0000000..2f42f82 --- /dev/null +++ b/payments/tests/test_payments_mixins_and_services.py @@ -0,0 +1,202 @@ +import pytest +from django.utils import timezone +from django.contrib.auth import get_user_model +from rest_framework.exceptions import NotFound, ValidationError, PermissionDenied +from unittest.mock import MagicMock, patch +from payments.mixins import ( + GetObjectMixin, + CartMixin, + OrderMixin, + UserBillingAddressMixin, + PaymentMixin, +) +from payments.services import KakaoPayService +from courses.models import Course +from payments.models import Order, UserBillingAddress + + +class TestGetObjectMixin: + @pytest.fixture + def mixin(self): + return GetObjectMixin() + + @pytest.mark.django_db + def test_get_object_or_404_성공(self, mixin, user, course): + obj = mixin.get_object_or_404(Course.objects.all(), id=course.id) + assert obj == course + + @pytest.mark.django_db + def test_get_object_or_404_실패_객체없음(self, mixin): + with pytest.raises(NotFound): + mixin.get_object_or_404(Course.objects.all(), id=9999) + + @pytest.mark.django_db + def test_get_object_or_404_실패_권한없음(self, mixin, user, order): + mixin.request = MagicMock(user=user) + other_user = get_user_model().objects.create_user( + email="other@example.com", password="pass", nickname="othernick" + ) + order.user = other_user + order.save() + with pytest.raises(PermissionDenied): + mixin.get_object_or_404(Order.objects.all(), id=order.id) + + +class TestCartMixin: + @pytest.fixture + def mixin(self): + return CartMixin() + + @pytest.mark.django_db + def test_get_cart_성공(self, mixin, user, cart): + assert mixin.get_cart(user) == cart + + @pytest.mark.django_db + def test_get_cart_item_성공(self, mixin, cart, cart_item): + assert mixin.get_cart_item(cart, id=cart_item.id) == cart_item + + @pytest.mark.django_db + def test_add_to_cart_성공(self, mixin, cart, course): + serializer = MagicMock() + serializer.validated_data = {"course": course} + serializer.data = {"id": 1, "course": course.id} + response = mixin.add_to_cart(cart, serializer) + assert response.status_code == 201 + assert "상품이 장바구니에 추가되었습니다" in response.data["detail"] + + @pytest.mark.django_db + def test_remove_from_cart_성공(self, mixin, cart_item): + response = mixin.remove_from_cart(cart_item) + assert response.status_code == 204 + + +class TestOrderMixin: + @pytest.fixture + def mixin(self): + mixin = OrderMixin() + mixin.request = MagicMock() + return mixin + + @pytest.mark.django_db + def test_get_order_성공(self, mixin, user, order): + mixin.request.user = user + assert mixin.get_order(user, id=order.id) == order + + @pytest.mark.django_db + def test_create_order_from_cart_성공(self, mixin, user, cart, cart_item): + order_data = mixin.create_order_from_cart(user, cart) + assert order_data["user_id"] == user.id + assert len(order_data["order_items"]) == 1 + + @pytest.mark.django_db + def test_create_order_from_cart_실패_장바구니_비어있음(self, mixin, user, cart): + with pytest.raises(ValidationError): + mixin.create_order_from_cart(user, cart) + + +class TestUserBillingAddressMixin: + @pytest.fixture + def mixin(self): + mixin = UserBillingAddressMixin() + mixin.request = MagicMock() + return mixin + + @pytest.mark.django_db + def test_get_billing_address_성공(self, mixin, user, user_billing_address): + mixin.request.user = user + assert ( + mixin.get_billing_address(user, id=user_billing_address.id) + == user_billing_address + ) + + @pytest.mark.django_db + def test_create_billing_address_성공(self, mixin, user): + serializer = MagicMock() + serializer.save.return_value = UserBillingAddress(id=1, user=user) + response = mixin.create_billing_address(user, serializer) + assert response.status_code == 201 + assert "청구 주소가 생성되었습니다" in response.data["detail"] + + +class TestPaymentMixin: + @pytest.fixture + def mixin(self): + mixin = PaymentMixin() + mixin.request = MagicMock() + return mixin + + @pytest.mark.django_db + def test_get_payment_성공(self, mixin, user, payment): + mixin.request.user = user + assert mixin.get_payment(user, id=payment.id) == payment + + @pytest.mark.django_db + def test_validate_order_성공(self, mixin, order): + mixin.validate_order(order) # 예외가 발생하지 않아야 함 + + @pytest.mark.django_db + def test_validate_order_실패_주문상태_부적절(self, mixin, completed_order): + with pytest.raises(ValidationError): + mixin.validate_order(completed_order) + + @pytest.mark.django_db + def test_create_payment_성공(self, mixin, order, user, mock_kakao_pay_service): + mixin.kakao_pay_service = mock_kakao_pay_service + mock_kakao_pay_service.request_payment.return_value = {"tid": "test_tid"} + payment, kakao_response = mixin.create_payment(order, user) + assert payment.order == order + assert payment.user == user + assert payment.payment_status == "pending" + + @pytest.mark.django_db + def test_process_payment_성공(self, mixin, order, payment, mock_kakao_pay_service): + mixin.kakao_pay_service = mock_kakao_pay_service + mixin.process_payment(order, payment, "test_pg_token") + assert payment.payment_status == "completed" + assert order.order_status == "completed" + + @pytest.mark.django_db + def test_cancel_payment_성공(self, mixin, order, payment): + mixin.cancel_payment(order, payment) + assert payment.payment_status == "cancelled" + assert order.order_status == "cancelled" + + @pytest.mark.django_db + def test_refund_payment_성공( + self, mixin, completed_order, completed_payment, mock_kakao_pay_service + ): + mixin.kakao_pay_service = mock_kakao_pay_service + completed_payment.paid_at = timezone.now() # 날짜 설정 + mixin.refund_payment(completed_order, completed_payment) + assert completed_payment.payment_status == "refunded" + assert completed_order.order_status == "refunded" + + +class TestKakaoPayService: + @pytest.fixture + def service(self): + return KakaoPayService() + + @pytest.mark.django_db + def test_request_payment_성공(self, service, order, mock_kakao_pay_settings): + with patch("requests.post") as mock_post: + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = {"tid": "test_tid"} + response = service.request_payment(order) + assert response["tid"] == "test_tid" + + @pytest.mark.django_db + def test_approve_payment_성공(self, service, payment, mock_kakao_pay_settings): + with patch("requests.post") as mock_post: + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = {"amount": {"total": 10000}} + response = service.approve_payment(payment, "test_pg_token") + assert response["amount"]["total"] == 10000 + + @pytest.mark.django_db + def test_refund_payment_성공(self, service, payment, mock_kakao_pay_settings): + with patch("requests.post") as mock_post: + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = {"amount": {"total": 10000}} + response = service.refund_payment(payment) + assert response["amount"]["total"] == 10000 diff --git a/payments/tests/test_payments_models.py b/payments/tests/test_payments_models.py new file mode 100644 index 0000000..ccd9d75 --- /dev/null +++ b/payments/tests/test_payments_models.py @@ -0,0 +1,117 @@ +import pytest +from payments.models import ( + Cart, + CartItem, + Order, + OrderItem, + Payment, + UserBillingAddress, +) + + +@pytest.mark.django_db +class TestCart: + def test_cart_생성(self, user): + cart = Cart.objects.create(user=user) + assert cart.user == user + assert cart.get_total_items() == 0 + assert cart.get_total_price() == 0 + + def test_cart_total_items_and_price(self, cart, course, curriculum): + CartItem.objects.create(cart=cart, course=course, quantity=1) + CartItem.objects.create(cart=cart, curriculum=curriculum, quantity=1) + assert cart.get_total_items() == 2 + assert cart.get_total_price() == course.price + curriculum.price + + +@pytest.mark.django_db +class TestOrder: + def test_order_생성(self, user): + order = Order.objects.create(user=user, order_status="pending") + assert order.user == user + assert order.order_status == "pending" + assert order.get_total_items() == 0 + assert order.get_total_price() == 0 + + def test_order_total_items_and_price(self, order, course, curriculum): + OrderItem.objects.create(order=order, course=course, quantity=1) + OrderItem.objects.create(order=order, curriculum=curriculum, quantity=1) + assert order.get_total_items() == 2 + assert order.get_total_price() == course.price + curriculum.price + + +@pytest.mark.django_db +class TestUserBillingAddress: + def test_user_billing_address_생성(self, user): + address = UserBillingAddress.objects.create( + user=user, + country="대한민국", + main_address="서울특별시 강남구", + detail_address="테헤란로 123", + postal_code="06234", + is_default=True, + ) + assert address.user == user + assert address.is_default == True + + def test_user_billing_address_default_설정(self, user): + address1 = UserBillingAddress.objects.create( + user=user, + country="대한민국", + main_address="서울특별시 강남구", + detail_address="테헤란로 123", + postal_code="06234", + is_default=True, + ) + address2 = UserBillingAddress.objects.create( + user=user, + country="대한민국", + main_address="서울특별시 서초구", + detail_address="반포대로 123", + postal_code="06548", + is_default=True, + ) + address1.refresh_from_db() + assert address1.is_default == False + assert address2.is_default == True + + +@pytest.mark.django_db +class TestPayment: + def test_payment_생성(self, user, order, user_billing_address): + payment = Payment.objects.create( + user=user, + order=order, + payment_status="pending", + amount=10000, + transaction_id="test_transaction", + billing_address=user_billing_address, + ) + assert payment.user == user + assert payment.order == order + assert payment.payment_status == "pending" + assert payment.amount == 10000 + assert payment.transaction_id == "test_transaction" + assert payment.billing_address == user_billing_address + + def test_payment_default_billing_address(self, user, order): + payment = Payment.objects.create( + user=user, + order=order, + payment_status="pending", + amount=10000, + transaction_id="test_transaction", + ) + assert payment.billing_address is None + + default_address = UserBillingAddress.objects.create( + user=user, + country="대한민국", + main_address="서울특별시 강남구", + detail_address="테헤란로 123", + postal_code="06234", + is_default=True, + ) + payment.save() + payment.refresh_from_db() + assert payment.billing_address == default_address diff --git a/payments/tests/test_payments_serializers.py b/payments/tests/test_payments_serializers.py new file mode 100644 index 0000000..253f7f1 --- /dev/null +++ b/payments/tests/test_payments_serializers.py @@ -0,0 +1,84 @@ +import pytest +from payments.serializers import ( + CartItemSerializer, + CartSerializer, + OrderItemSerializer, + OrderSerializer, + UserBillingAddressSerializer, + PaymentSerializer, +) +from payments.models import CartItem, OrderItem + + +@pytest.mark.django_db +class TestCartItemSerializer: + def test_cartitem_serializer_유효성검사(self, cart, course): + data = {"cart": cart.id, "course": course.id, "quantity": 1} + serializer = CartItemSerializer(data=data) + assert serializer.is_valid() + + def test_cartitem_serializer_유효성검사_실패(self, cart, course, curriculum): + data = { + "cart": cart.id, + "course": course.id, + "curriculum": curriculum.id, + "quantity": 1, + } + serializer = CartItemSerializer(data=data) + assert not serializer.is_valid() + assert "non_field_errors" in serializer.errors + + +@pytest.mark.django_db +class TestCartSerializer: + def test_cart_serializer(self, cart, course): + CartItem.objects.create(cart=cart, course=course, quantity=1) + serializer = CartSerializer(cart) + assert serializer.data["get_total_items"] == 1 + assert serializer.data["get_total_price"] == course.price + + +@pytest.mark.django_db +class TestOrderItemSerializer: + def test_orderitem_serializer_유효성검사(self, order, course): + data = {"order": order.id, "course": course.id, "quantity": 1} + serializer = OrderItemSerializer(data=data) + assert serializer.is_valid() + + def test_orderitem_serializer_유효성검사_실패(self, order, course, curriculum): + data = { + "order": order.id, + "course": course.id, + "curriculum": curriculum.id, + "quantity": 1, + } + serializer = OrderItemSerializer(data=data) + assert not serializer.is_valid() + assert "non_field_errors" in serializer.errors + + +@pytest.mark.django_db +class TestOrderSerializer: + def test_order_serializer(self, order, course): + OrderItem.objects.create(order=order, course=course, quantity=1) + serializer = OrderSerializer(order) + assert serializer.data["get_total_items"] == 1 + assert serializer.data["get_total_price"] == course.price + + +@pytest.mark.django_db +class TestUserBillingAddressSerializer: + def test_user_billing_address_serializer(self, user_billing_address): + serializer = UserBillingAddressSerializer(user_billing_address) + assert serializer.data["country"] == "대한민국" + assert serializer.data["main_address"] == "서울특별시 강남구" + assert serializer.data["is_default"] == True + + +@pytest.mark.django_db +class TestPaymentSerializer: + def test_payment_serializer(self, payment): + serializer = PaymentSerializer(payment) + assert serializer.data["payment_status"] == "pending" + assert serializer.data["amount"] == 10000 + assert serializer.data["transaction_id"] == "test_transaction" diff --git a/payments/tests/test_payments_views.py b/payments/tests/test_payments_views.py new file mode 100644 index 0000000..d56b3cb --- /dev/null +++ b/payments/tests/test_payments_views.py @@ -0,0 +1,117 @@ +import pytest +from django.utils import timezone +from django.urls import reverse +from rest_framework import status +from unittest.mock import patch + + +@pytest.mark.django_db +class Test장바구니뷰: + def test_장바구니_조회_성공(self, api_client, user, cart_item): + api_client.force_authenticate(user=user) + url = reverse("payments:cart-list-create") + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + assert len(response.data["cart_items"]) == 1 + + def test_장바구니_아이템_삭제_성공(self, api_client, user, cart_item): + api_client.force_authenticate(user=user) + url = reverse("payments:cart-item-delete", args=[cart_item.id]) + response = api_client.delete(url) + assert response.status_code == status.HTTP_204_NO_CONTENT + + +@pytest.mark.django_db +class Test주문뷰: + def test_진행중인_주문_조회_성공(self, api_client, user, order): + api_client.force_authenticate(user=user) + url = reverse("payments:order") + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + assert response.data["order_status"] == "pending" + + def test_주문_생성_실패_빈_장바구니(self, api_client, user): + api_client.force_authenticate(user=user) + url = reverse("payments:order") + data = {"from_cart": True} + response = api_client.post(url, data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "장바구니가 비어 있습니다" in response.data["detail"] + + +@pytest.mark.django_db +class Test결제뷰: + @patch("payments.mixins.KakaoPayService.request_payment") + def test_결제_요청(self, mock_request_payment, api_client, user, order): + mock_request_payment.return_value = { + "next_redirect_pc_url": "http://test-redirect-url.com", + "next_redirect_mobile_url": "http://test-redirect-url.com", + "next_redirect_app_url": "http://test-redirect-url.com", + "tid": "test_transaction_id", + } + api_client.force_authenticate(user=user) + url = reverse("payments:payment") + response = api_client.post(url) + print(f"\n결제 요청 테스트") + print(f"URL: {url}") + print(f"응답 상태 코드: {response.status_code}") + print(f"응답 데이터: {response.data}") + assert response.status_code == status.HTTP_201_CREATED + assert "payment" in response.data + assert "next_redirect_pc_url" in response.data + + @patch("payments.mixins.KakaoPayService.approve_payment") + def test_결제_승인_성공(self, mock_approve_payment, api_client, user, payment): + mock_approve_payment.return_value = {"amount": {"total": 10000}} + api_client.force_authenticate(user=user) + url = reverse("payments:payment") + response = api_client.get(url, {"result": "success", "pg_token": "test_token"}) + print(f"\n결제 승인 테스트") + print(f"URL: {url}") + print(f"응답 상태 코드: {response.status_code}") + print(f"응답 데이터: {response.data}") + assert response.status_code == status.HTTP_200_OK + assert response.data["detail"] == "결제가 성공적으로 완료되었습니다." + + def test_결제_취소(self, api_client, user, payment): + api_client.force_authenticate(user=user) + url = reverse("payments:payment") + response = api_client.get(url, {"result": "cancel"}) + print(f"\n결제 취소 테스트") + print(f"URL: {url}") + print(f"응답 상태 코드: {response.status_code}") + print(f"응답 데이터: {response.data}") + assert response.status_code == status.HTTP_200_OK + assert "결제 과정이 취소되었습니다" in response.data["detail"] + + @patch("payments.mixins.KakaoPayService.refund_payment") + def test_결제_환불(self, mock_refund_payment, api_client, user, completed_payment): + mock_refund_payment.return_value = {"status": "CANCEL_PAYMENT"} + completed_payment.paid_at = timezone.now() + completed_payment.save() + api_client.force_authenticate(user=user) + url = reverse("payments:payment-cancel", args=[completed_payment.order.id]) + response = api_client.delete(url) + print(f"\n결제 환불 테스트") + print(f"URL: {url}") + print(f"응답 상태 코드: {response.status_code}") + print(f"응답 데이터: {response.data}") + assert response.status_code == status.HTTP_200_OK + assert "결제가 성공적으로 환불되었습니다" in response.data["detail"] + + +@pytest.mark.django_db +class Test영수증뷰: + def test_영수증_목록_조회_성공(self, api_client, user, completed_payment): + api_client.force_authenticate(user=user) + url = reverse("payments:receipt-list") + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 1 + + def test_영수증_상세_조회_성공(self, api_client, user, completed_payment): + api_client.force_authenticate(user=user) + url = reverse("payments:receipt-detail", args=[completed_payment.id]) + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + assert response.data["id"] == completed_payment.id diff --git a/payments/urls.py b/payments/urls.py index 910894e..0a1bd00 100644 --- a/payments/urls.py +++ b/payments/urls.py @@ -1,8 +1,47 @@ from django.urls import path +from .views import ( + CartView, + OrderView, + UserBillingAddressView, + PaymentView, + ReceiptView, +) -from .views import CartItemView +app_name = "payments" urlpatterns = [ - path("cart/items/", CartItemView.as_view(), name="cart-items"), - path("cart/items//", CartItemView.as_view(), name="cart-item-detail"), + # 장바구니 관련 URLs + path("cart/", CartView.as_view(), name="cart-list-create"), + path("cart//", CartView.as_view(), name="cart-item-delete"), + # 주문 관련 URLs + path("orders/", OrderView.as_view(), name="order"), + # 청구 주소 관련 URLs + path( + "billing-addresses/", + UserBillingAddressView.as_view(), + name="billing-address-list-create", + ), + path( + "billing-addresses//", + UserBillingAddressView.as_view(), + name="billing-address-detail", + ), + # 결제 관련 URLs + path( + "payments/", + PaymentView.as_view(), + name="payment", + ), + path( + "payments//cancel/", + PaymentView.as_view(), + name="payment-cancel", + ), + # 영수증 관련 URLs + path("receipts/", ReceiptView.as_view(), name="receipt-list"), + path( + "receipts//", + ReceiptView.as_view(), + name="receipt-detail", + ), ] diff --git a/payments/views.py b/payments/views.py index 489d453..2884ed5 100644 --- a/payments/views.py +++ b/payments/views.py @@ -1,107 +1,414 @@ -from rest_framework import status -from rest_framework.views import APIView +from django.core.exceptions import ValidationError +from django.db import transaction +from drf_spectacular.utils import extend_schema, extend_schema_view +from rest_framework import generics, status +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response -from courses.models import Curriculum +from .mixins import ( + CartMixin, + OrderMixin, + PaymentMixin, + ReceiptMixin, + UserBillingAddressMixin, +) +from .models import CartItem, Order, Payment, UserBillingAddress +from .permissions import IsOwnerPermission +from .serializers import ( + CartItemSerializer, + CartSerializer, + OrderItemSerializer, + OrderSerializer, + PaymentSerializer, + UserBillingAddressSerializer, +) -from .models import Cart, CartItem -from .serializers import CartSerializer, CartItemSerializer +@extend_schema_view( + get=extend_schema( + summary="사용자의 장바구니를 조회하는 API", + description="사용자의 장바구니를 조회하거나 특정 상품을 조회합니다.", + responses={200: CartSerializer}, + ), + post=extend_schema( + summary="장바구니에 상품을 추가하는 API", + description="장바구니에 새로운 상품을 추가합니다.", + responses={201: CartItemSerializer}, + ), + delete=extend_schema( + summary="장바구니에서 상품을 삭제하는 API", + description="장바구니에서 특정 상품을 삭제합니다.", + responses={204: None}, + ), +) +class CartView(CartMixin, generics.GenericAPIView): + """ + 장바구니 관련 기능을 처리합니다. -class CartItemView(APIView): + [GET /cart/]: 사용자의 장바구니를 조회합니다. + [GET /cart/{cart_item_id}/]: 사용자의 장바구니에서 특정 상품을 조회합니다. + [POST /cart/]: 장바구니에 상품을 추가합니다. + [DELETE /cart/{cart_item_id}/]: 장바구니에서 특정 상품을 삭제합니다. + """ - def get_cart(self, user): - """ - 사용자별 장바구니를 조회합니다. 없으면 새로 생성합니다. - """ - cart, _ = Cart.objects.get_or_create(user=user) - return cart + serializer_class = CartItemSerializer + permission_classes = [IsAuthenticated] - def get_cart_item(self, cart, item_id): - """ - 사용자별 장바구니에 있는 특정 상품을 조회합니다. - """ + def get_queryset(self): + return CartItem.objects.filter(cart__user=self.request.user).select_related( + "cart", "cart__user" + ) + + def get(self, request, pk=None): + if pk: + cart_item = self.get_cart_item(self.get_cart(request.user), pk=pk) + serializer = self.get_serializer(cart_item) + return Response(serializer.data) + else: + cart = self.get_cart(request.user) + serializer = CartSerializer(cart) + return Response(serializer.data) + + def post(self, request): + cart = self.get_cart(request.user) + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + return self.add_to_cart(cart, serializer) + + def delete(self, request, pk): + cart_item = self.get_cart_item(self.get_cart(request.user), pk=pk) + return self.remove_from_cart(cart_item) + + +@extend_schema_view( + get=extend_schema( + summary="사용자의 진행 중인 주문을 조회하는 API", + description="사용자의 현재 진행 중인 (order_status=pending) 주문을 조회합니다.", + responses={200: OrderSerializer}, + ), + post=extend_schema( + summary="새로운 주문을 생성하는 API", + description="장바구니를 통해 주문을 생성하거나 직접 주문을 생성할 수 있습니다. 기존의 진행 중인 주문은 취소 처리됩니다.", + responses={201: OrderSerializer}, + ), +) +class OrderView(OrderMixin, CartMixin, generics.GenericAPIView): + """ + 주문 관련 기능을 처리합니다. + + [GET /orders/]: 사용자의 현재 진행 중인 (pending 상태의) 주문을 조회합니다. + [POST /orders/]: 새로운 주문을 생성합니다. + - from_cart=False: 직접 주문을 생성합니다. + - from_cart=True: 장바구니를 통해 주문을 생성합니다. + 주의: 새 주문 생성 시 기존의 진행 중인 주문은 자동으로 취소됩니다. + """ + + serializer_class = OrderSerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + return Order.objects.filter(user=self.request.user, order_status="pending") + + def get(self, request): + pending_order = self.get_queryset().first() + if pending_order: + serializer = self.get_serializer(pending_order) + return Response(serializer.data) + else: + return Response( + {"detail": "현재 진행 중인 주문이 없습니다.", "data": None}, + status=status.HTTP_200_OK, + ) + + @transaction.atomic + def post(self, request): try: - return CartItem.objects.get(id=item_id, cart=cart) - except CartItem.DoesNotExist: + # 기존의 pending 상태 주문을 cancelled로 변경 + Order.objects.filter(user=request.user, order_status="pending").update( + order_status="cancelled" + ) + + if request.data.get("from_cart", False): + cart = self.get_cart(request.user) + if not cart.cart_items.exists(): + raise ValidationError("장바구니가 비어 있습니다.") + order_data = self.create_order_from_cart(request.user, cart) + else: + if "order_items" not in request.data or not request.data["order_items"]: + raise ValidationError("주문 항목이 없습니다.") + order_data = self.create_new_order(request.user, request.data) + + order_data["user"] = request.user.id + serializer = self.get_serializer(data=order_data) + serializer.is_valid(raise_exception=True) + order = serializer.save() + + for item_data in order_data["order_items"]: + item_data["order"] = order.id + order_item_serializer = OrderItemSerializer(data=item_data) + order_item_serializer.is_valid(raise_exception=True) + order_item_serializer.save() + + if request.data.get("from_cart", False): + cart = self.get_cart(request.user) + cart.cart_items.all().delete() + return Response( - {"detail": "카트에 상품이 없습니다."}, status=status.HTTP_404_NOT_FOUND + { + "detail": "주문이 성공적으로 생성되었습니다.", + "data": serializer.data, + }, + status=status.HTTP_201_CREATED, ) + except ValidationError as e: + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) - def get(self, request, item_id=None): - """ - 장바구니에 있는 상품들을 조회하며, item_id가 주어지면 개별 상품을 조회합니다. - """ - cart = self.get_cart(request.user) - if item_id is None: - serializer = CartSerializer(cart) - else: - cart_item = self.get_cart_item(cart, item_id) - if cart_item is None: - return Response( - {"detail": "장바구니에 상품이 없습니다."}, - status=status.HTTP_404_NOT_FOUND, - ) - serializer = CartItemSerializer(cart_item) +@extend_schema_view( + get=extend_schema( + summary="사용자의 청구 주소 목록 또는 특정 청구 주소를 조회하는 API", + description="사용자의 모든 청구 주소 목록을 조회하거나, 특정 청구 주소를 조회합니다.", + responses={200: UserBillingAddressSerializer(many=True)}, + ), + post=extend_schema( + summary="새로운 청구 주소를 생성하는 API", + description="새로운 청구 주소를 생성합니다. 바로 기본 청구 주소로 설정됩니다.", + responses={201: UserBillingAddressSerializer}, + ), + put=extend_schema( + summary="특정 청구 주소를 수정하는 API", + description="청구 주소 ID를 기반으로 특정 청구 주소를 수정합니다.", + responses={200: UserBillingAddressSerializer}, + ), + delete=extend_schema( + summary="특정 청구 주소를 삭제하는 API", + description="청구 주소 ID를 기반으로 특정 청구 주소를 삭제합니다.", + responses={204: None}, + ), +) +class UserBillingAddressView(UserBillingAddressMixin, generics.GenericAPIView): + """ + 청구 주소 관련 기능을 처리합니다. + + [GET /billing-addresses/]: 사용자의 모든 청구 주소 목록을 조회합니다. + [GET /billing-addresses/{billing_address_id}/]: 사용자의 특정 청구 주소를 조회합니다. + [POST /billing-addresses/]: 새로운 청구 주소를 생성합니다. + [PUT /billing-addresses/{billing_address_id}/]: 특정 청구 주소를 수정합니다. + [DELETE /billing-addresses/{billing_address_id}/]: 특정 청구 주소를 삭제합니다. + """ + + serializer_class = UserBillingAddressSerializer + permission_classes = [IsAuthenticated] + def get_queryset(self): + return UserBillingAddress.objects.filter(user=self.request.user) + + def get(self, request, pk=None): + if pk: + instance = self.get_billing_address(request.user, pk=pk) + serializer = self.get_serializer(instance) + else: + queryset = UserBillingAddress.objects.filter(user=request.user) + serializer = self.get_serializer(queryset, many=True) return Response(serializer.data) def post(self, request): - """ - 장바구니에 새 상품을 추가합니다. - """ - cart = self.get_cart(request.user) - curriculum_id = request.data.get("curriculum_id") - quantity = int(request.data.get("quantity", 1)) + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + return self.create_billing_address(request.user, serializer) + + def put(self, request, pk): + instance = self.get_billing_address(request.user, pk=pk) + serializer = self.get_serializer(instance, data=request.data) + serializer.is_valid(raise_exception=True) + return self.update_billing_address(instance, serializer) + + def delete(self, request, pk): + instance = self.get_billing_address(request.user, pk=pk) + return self.delete_billing_address(instance) + + +@extend_schema_view( + post=extend_schema( + summary="결제를 생성하고 카카오페이 결제를 요청하는 API", + description="현재 진행 중인 (pending 상태의) 주문에 대한 결제를 생성하고 카카오페이 결제를 요청합니다.", + responses={201: PaymentSerializer}, + ), + get=extend_schema( + summary="카카오페이 결제 처리 API", + description="카카오페이 결제 결과를 처리합니다.", + responses={200: PaymentSerializer}, + ), + delete=extend_schema( + summary="결제 취소 및 환불 API", + description="결제를 취소하고 환불을 처리합니다.", + responses={200: PaymentSerializer}, + ), +) +class PaymentView(PaymentMixin, OrderMixin, generics.GenericAPIView): + """ + 결제 관련 기능을 처리합니다. + + [POST /payments/]: 현재 진행 중인 주문에 대한 결제를 생성하고 카카오페이 결제를 요청합니다. + [GET /payments/]: 카카오페이 결제 결과를 처리합니다. + [DELETE /payments//cancel/]: 결제를 취소하고 환불을 처리합니다. + """ + + serializer_class = PaymentSerializer + permission_classes = [IsOwnerPermission] + + def get_queryset(self): + return Order.objects.filter(user=self.request.user) - if quantity < 0: + @transaction.atomic + def post(self, request): + order = ( + self.get_queryset() + .filter(order_status="pending") + .select_for_update() + .first() + ) + if not order: return Response( - {"detail": "수량은 0 이상이어야 합니다."}, - status=status.HTTP_400_BAD_REQUEST, + {"detail": "진행 중인 주문이 없습니다."}, + status=status.HTTP_404_NOT_FOUND, ) try: - curriculum = Curriculum.objects.get(id=curriculum_id) - except Curriculum.DoesNotExist: + payment, kakao_response = self.create_payment(order, request.user) + except ValidationError as e: + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) + + serializer = self.get_serializer(payment) + return Response( + { + "payment": serializer.data, + "next_redirect_pc_url": kakao_response["next_redirect_pc_url"], + "next_redirect_mobile_url": kakao_response["next_redirect_mobile_url"], + "next_redirect_app_url": kakao_response["next_redirect_app_url"], + }, + status=status.HTTP_201_CREATED, + ) + + @transaction.atomic + def get(self, request): + order = ( + self.get_queryset() + .filter(order_status="pending") + .select_for_update() + .first() + ) + if not order: return Response( - {"detail": "해당되는 커리큘럼이 없습니다."}, + {"detail": "진행 중인 주문이 없습니다."}, status=status.HTTP_404_NOT_FOUND, ) - cart_item, created = CartItem.objects.update_or_create( - cart=cart, - curriculum=curriculum, - defaults={"quantity": quantity}, + payment = ( + Payment.objects.filter(order=order, payment_status="pending") + .order_by("-created_at") + .first() ) + if not payment: + return Response( + {"detail": "해당 주문에 대한 대기 중인 결제를 찾을 수 없습니다."}, + status=status.HTTP_404_NOT_FOUND, + ) - serializer = CartItemSerializer(cart_item) - response_data = serializer.data - if created: - response_data["message"] = "장바구니에 새로운 상품이 추가되었습니다." - else: - response_data["message"] = "장바구니의 상품이 업데이트되었습니다." - - return Response( - response_data, - status=status.HTTP_201_CREATED if created else status.HTTP_200_OK, - ) + result = request.GET.get("result") + pg_token = request.GET.get("pg_token") - def delete(self, request, item_id): - """ - 장바구니에 있는 상품을 삭제합니다. - """ - cart = self.get_cart(request.user) - cart_item = self.get_cart_item(cart, item_id) + if result == "success": + try: + self.process_payment(order, payment, pg_token) + serializer = self.get_serializer(payment) + return Response( + { + "detail": "결제가 성공적으로 완료되었습니다.", + "data": serializer.data, + }, + status=status.HTTP_200_OK, + ) + except ValidationError as e: + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) + elif result == "cancel": + self.cancel_payment(order, payment) + serializer = self.get_serializer(payment) + return Response( + {"detail": "결제 과정이 취소되었습니다.", "data": serializer.data}, + status=status.HTTP_200_OK, + ) + elif result == "fail": + self.fail_payment(payment) + serializer = self.get_serializer(payment) + return Response( + { + "detail": "결제 처리 중 오류가 발생했습니다. 나중에 다시 시도해 주세요.", + "data": serializer.data, + }, + status=status.HTTP_400_BAD_REQUEST, + ) + else: + return Response( + {"detail": "올바르지 않은 결제 결과입니다. 다시 시도해 주세요."}, + status=status.HTTP_400_BAD_REQUEST, + ) - if cart_item is None: + @transaction.atomic + def delete(self, request, order_id): + try: + order = ( + self.get_queryset().filter(order_status="completed").get(id=order_id) + ) + except Order.DoesNotExist: return Response( - {"detail": "해당되는 상품이 없습니다."}, + {"detail": "결제된 주문을 찾을 수 없습니다."}, status=status.HTTP_404_NOT_FOUND, ) - cart_item.delete() + payment = self.get_payment( + request.user, order=order, payment_status="completed" + ) + + try: + self.refund_payment(order, payment) + except ValidationError as e: + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) + + serializer = self.get_serializer(payment) return Response( - {"message": "장바구니에서 상품이 삭제되었습니다."}, - status=status.HTTP_204_NO_CONTENT, + {"detail": "결제가 성공적으로 환불되었습니다.", "data": serializer.data}, + status=status.HTTP_200_OK, ) + + +@extend_schema_view( + get=extend_schema( + summary="영수증 목록 조회 또는 상세 조회 API", + description="사용자가 결제 완료/환불 한 모든 영수증 목록을 조회하거나, 특정 결제에 대한 상세 영수증 정보를 조회합니다.", + responses={200: PaymentSerializer(many=True)}, + ), +) +class ReceiptView(ReceiptMixin, PaymentMixin, generics.GenericAPIView): + """ + 영수증 관련 기능을 처리합니다. + + [GET /receipts/]: 사용자의 모든 영수증 목록을 조회합니다. + [GET /receipts/{payment_id}/]: 특정 결제에 대한 상세 영수증 정보를 조회합니다. + """ + + serializer_class = PaymentSerializer + permission_classes = [IsOwnerPermission] + + def get_queryset(self): + return Order.objects.filter(user=self.request.user) + + def get(self, request, payment_id=None): + if payment_id is None: + receipt_list = self.get_receipt_list(request.user) + return Response(receipt_list) + else: + payment = self.get_payment(request.user, id=payment_id) + receipt_detail = self.get_receipt_detail(payment, request.user) + receipt_detail["id"] = payment.id + return Response(receipt_detail) diff --git a/requirements.txt b/requirements.txt index 6289c02..ab041d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,41 +1,66 @@ asgiref==3.8.1 attrs==24.2.0 +boto3==1.35.35 +botocore==1.35.35 +certifi==2024.8.30 cffi==1.17.1 +charset-normalizer==3.3.2 colorama==0.4.6 cryptography==43.0.1 +defusedxml==0.8.0rc2 +dj-rest-auth==6.0.0 Django==5.1.1 +django-allauth==65.0.2 django-appconf==1.0.6 django-cors-headers==4.4.0 +django-filter==24.3 +django-seed==0.3.1 django-storages==1.14.4 django-video-encoding==1.0.0 djangorestframework==3.15.2 drf-spectacular==0.27.2 drf-yasg==1.21.7 +Faker==30.3.0 +ffmpeg-python==0.2.0 +future==1.0.0 +idna==3.10 inflection==0.5.1 iniconfig==2.0.0 +jmespath==1.0.1 jsonschema==4.23.0 jsonschema-specifications==2023.12.1 -jwt==1.3.1 model-bakery==1.19.5 mypy==1.11.2 mypy-extensions==1.0.0 +numpy==2.1.2 +oauthlib==3.2.2 +opencv-python==4.10.0.84 packaging==24.1 pillow==10.4.0 pluggy==1.5.0 -pip==24.0 -PyJWT==2.9.0 psycopg==3.2.2 psycopg-binary==3.2.2 pycparser==2.22 +PyJWT==2.9.0 pytest==8.3.3 pytest-django==4.9.0 +python-dateutil==2.9.0.post0 python-dotenv==1.0.1 +python3-openid==3.2.0 pytz==2024.2 PyYAML==6.0.2 referencing==0.35.1 +requests==2.32.3 +requests-oauthlib==2.0.0 rpds-py==0.20.0 +s3transfer==0.10.2 setuptools==75.1.0 +six==1.16.0 +social-auth-app-django==5.4.2 +social-auth-core==4.5.4 sqlparse==0.5.1 +toposort==1.10 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 +urllib3==2.2.3 diff --git a/weaverse/settings.py b/weaverse/settings.py index 1ed7d5e..6b2c4b6 100644 --- a/weaverse/settings.py +++ b/weaverse/settings.py @@ -13,23 +13,46 @@ ALLOWED_HOSTS = os.getenv("ALLOWED_HOSTS", "").split(",") +# 카카오페이 연동 설정 +BASE_URL = os.environ.get("BASE_URL") +KAKAOPAY_CID = os.environ.get("KAKAOPAY_CID") +KAKAOPAY_SECRET_KEY = os.environ.get("KAKAOPAY_SECRET_KEY") + INSTALLED_APPS = [ + # 기본 장고 앱 "django.contrib.admin", "django.contrib.auth", "django.contrib.contenttypes", "django.contrib.sessions", "django.contrib.messages", "django.contrib.staticfiles", + # 써드 파티 앱 "rest_framework", - "accounts", "drf_spectacular", + "corsheaders", + "storages", + # 로컬 앱 + "accounts", "jwtauth", "courses", "materials", "payments", + "django_filters", + # social login + "social_django", + "django.contrib.sites", + "rest_framework.authtoken", + "allauth", + "allauth.account", + "allauth.socialaccount", + "allauth.socialaccount.providers.google", + "allauth.socialaccount.providers.kakao", + "dj_rest_auth", + "dj_rest_auth.registration", ] MIDDLEWARE = [ + "corsheaders.middleware.CorsMiddleware", "django.middleware.security.SecurityMiddleware", "django.contrib.sessions.middleware.SessionMiddleware", "django.middleware.common.CommonMiddleware", @@ -37,6 +60,7 @@ "django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", + "allauth.account.middleware.AccountMiddleware", ] ROOT_URLCONF = "weaverse.urls" @@ -58,16 +82,17 @@ ] REST_FRAMEWORK = { - "DEFAULT_RENDERER_CLASSES": ( + "DEFAULT_FILTER_BACKENDS": ["django_filters.rest_framework.DjangoFilterBackend"], + "DEFAULT_RENDERER_CLASSES": [ "rest_framework.renderers.JSONRenderer", - "rest_framework.renderers.BrowsableAPIRenderer", # 이 옵션이 있어야 브라우저에서 API를 시각적으로 볼 수 있음 - ), - "DEFAULT_AUTHENTICATION_CLASSES": ( - "rest_framework.authentication.TokenAuthentication", + "rest_framework.renderers.BrowsableAPIRenderer", + ], + "DEFAULT_AUTHENTICATION_CLASSES": [ "jwtauth.authentication.JWTAuthentication", - ), - "DEFAULT_PARSER_CLASSES": ("rest_framework.parsers.JSONParser",), - "DEFAULT_SCHEMA_CLASS": ("drf_spectacular.openapi.AutoSchema",), + ], + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", + "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination", + "PAGE_SIZE": 10, } WSGI_APPLICATION = "weaverse.wsgi.application" @@ -83,11 +108,6 @@ } } -REST_FRAMEWORK = { - "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", - "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination", - "PAGE_SIZE": 10, -} AUTH_PASSWORD_VALIDATORS = [ { @@ -112,6 +132,8 @@ USE_TZ = True + +# 정적 파일 설정 STATIC_URL = "static/" STATIC_ROOT = os.getenv("STATIC_ROOT", BASE_DIR / "static") @@ -121,10 +143,20 @@ else: STATICFILES_DIRS = [BASE_DIR / "staticfiles"] +# CSRF 설정 CSRF_TRUSTED_ORIGINS = [ "https://www.weaverse.site", ] +# CORS 설정 +if DEBUG: + CORS_ALLOWED_ORIGINS = [ + "https://www.weaverse.site", # 프로덕션 환경 + "http://localhost:3000", # 개발 환경 프론트엔드 + ] +else: + CORS_ALLOWED_ORIGINS = os.getenv("CORS_ALLOWED_ORIGINS", "").split(",") + DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" AUTH_USER_MODEL = "accounts.CustomUser" @@ -138,3 +170,71 @@ "SERVE_URLCONF": "weaverse.urls", "EXTERNAL_DOCS": {"description": "Weaverse GitHub", "url": ""}, } + +# S3 설정 +AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID") +AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY") +AWS_STORAGE_BUCKET_NAME = os.getenv("AWS_STORAGE_BUCKET_NAME") +AWS_S3_REGION_NAME = os.getenv("AWS_S3_REGION_NAME") +AWS_S3_CUSTOM_DOMAIN = f"{AWS_STORAGE_BUCKET_NAME}.s3.amazonaws.com" +AWS_S3_OBJECT_PARAMETERS = { + "CacheControl": "max-age=86400", +} + +# boto3 설정 +DEFAULT_FILE_STORAGE = "storages.backends.s3boto3.S3Boto3Storage" + +SITE_ID = 1 + +AUTHENTICATION_BACKENDS = [ + "django.contrib.auth.backends.ModelBackend", + "allauth.account.auth_backends.AuthenticationBackend", +] + +SOCIALACCOUNT_PROVIDERS = { + "google": { + "SCOPE": [ + "profile", + "email", + ], + "AUTH_PARAMS": { + "access_type": "online", + }, + "APP": { + "client_id": os.getenv("SOCIAL_AUTH_GOOGLE_CLIENT_ID"), + "secret": os.getenv("SOCIAL_AUTH_GOOGLE_SECRET"), + "key": "", + }, + }, + "kakao": { + "SCOPE": [ + "profile", + "account_email", + ], + "APP": { + "client_id": os.getenv("SOCIAL_AUTH_KAKAO_CLIENT_ID"), + "secret": "", + "key": "", + }, + }, +} + + +ACCOUNT_USER_MODEL_USERNAME_FIELD = None +ACCOUNT_AUTHENTICATION_METHOD = "email" +ACCOUNT_EMAIL_REQUIRED = True +ACCOUNT_USERNAME_REQUIRED = False + + +REST_USE_JWT = True +JWT_AUTH_COOKIE = "my-app-auth" +JWT_AUTH_REFRESH_COOKIE = "my-refresh-token" + +GOOGLE_CALLBACK_URL = "https://www.weaverse.site/social-login/google/" + +SOCIAL_AUTH_KAKAO_KEY = os.getenv("SOCIAL_AUTH_KAKAO_KEY") + +REDIRECT_URL = "https://www.weaverse.site" +LOGIN_REDIRECT_URL = "/dashboard/" +LOGOUT_REDIRECT_URL = "/" +MEDIA_URL = f"https://{AWS_S3_CUSTOM_DOMAIN}/"