diff --git a/ecommerce/core/tests/test_create_demo_data.py b/ecommerce/core/tests/test_create_demo_data.py index 68fb787b78a..ed4c118d5d4 100644 --- a/ecommerce/core/tests/test_create_demo_data.py +++ b/ecommerce/core/tests/test_create_demo_data.py @@ -59,7 +59,13 @@ def test_handle_with_existing_course(self): partner=self.partner ) - seat_attrs = {'certificate_type': '', 'expires': None, 'price': 0.00, 'id_verification_required': False} + seat_attrs = { + 'certificate_type': '', + 'expires': None, + 'price': 0.00, + 'id_verification_required': False, + 'variant_id': '00000000-0000-0000-0000-000000000000' + } course.create_or_update_seat(**seat_attrs) with mock.patch.object(Course, 'publish_to_lms', return_value=None) as mock_publish: diff --git a/ecommerce/courses/models.py b/ecommerce/courses/models.py index 62c10ed6c80..e5fccc3450c 100644 --- a/ecommerce/courses/models.py +++ b/ecommerce/courses/models.py @@ -155,6 +155,7 @@ def create_or_update_seat( remove_stale_modes=True, create_enrollment_code=False, sku=None, + variant_id=None, ): """ Creates and updates course seat products. @@ -218,6 +219,7 @@ def create_or_update_seat( seat.attr.certificate_type = certificate_type seat.attr.course_key = course_id seat.attr.id_verification_required = id_verification_required + seat.attr.variant_id = variant_id if certificate_type in ENROLLMENT_CODE_SEAT_TYPES and create_enrollment_code: self._create_or_update_enrollment_code( certificate_type, id_verification_required, self.partner, price, expires diff --git a/ecommerce/courses/tests/test_models.py b/ecommerce/courses/tests/test_models.py index a875b230f21..fbc293e281f 100644 --- a/ecommerce/courses/tests/test_models.py +++ b/ecommerce/courses/tests/test_models.py @@ -97,7 +97,7 @@ def test_save_creates_parent_seat(self): self.assertEqual(parent.attr.course_key, course.id) def assert_course_seat_valid(self, seat, course, certificate_type, id_verification_required, price, - credit_provider=None, credit_hours=None): + credit_provider=None, credit_hours=None, variant_id=None): """ Ensure the given seat has the correct attribute values. """ self.assertEqual(seat.structure, Product.CHILD) # pylint: disable=protected-access @@ -108,6 +108,9 @@ def assert_course_seat_valid(self, seat, course, certificate_type, id_verificati self.assertEqual(seat.attr.id_verification_required, id_verification_required) self.assertEqual(seat.stockrecords.first().price_excl_tax, price) + if variant_id: + self.assertEqual(seat.attr.variant_id, variant_id) + if credit_provider: self.assertEqual(seat.attr.credit_provider, credit_provider) @@ -132,7 +135,8 @@ def test_create_or_update_seat(self): # Test seat update price = 10 course.create_or_update_seat( - certificate_type, id_verification_required, price, sku=seat.stockrecords.first().partner_sku + certificate_type, id_verification_required, price, sku=seat.stockrecords.first().partner_sku, + variant_id='00000000-0000-0000-0000-000000000000' ) # Again, only two seats with one being the parent seat product. @@ -202,12 +206,14 @@ def test_update_credit_seat(self): certificate_type = 'credit' id_verification_required = True price = 10 + variant_id = '00000000-0000-0000-0000-000000000000' credit_seat = course.create_or_update_seat( certificate_type, id_verification_required, price, credit_provider=credit_provider, - credit_hours=credit_hours + credit_hours=credit_hours, + variant_id=variant_id ) credit_hours = 4 price = 100 @@ -218,6 +224,7 @@ def test_update_credit_seat(self): credit_provider=credit_provider, credit_hours=credit_hours, sku=credit_seat.stockrecords.first().partner_sku, + variant_id=variant_id ) self.assert_course_seat_valid( credit_seat, @@ -226,7 +233,8 @@ def test_update_credit_seat(self): id_verification_required, price, credit_provider=credit_provider, - credit_hours=credit_hours + credit_hours=credit_hours, + variant_id=variant_id ) def test_collision_avoidance(self): diff --git a/ecommerce/extensions/api/serializers.py b/ecommerce/extensions/api/serializers.py index 11825a2c891..3ea92f9c0de 100644 --- a/ecommerce/extensions/api/serializers.py +++ b/ecommerce/extensions/api/serializers.py @@ -771,6 +771,7 @@ def save(course, product, create_enrollment_code): stockrecords = product.get('stockrecords', []) if stockrecords: sku = stockrecords[0].get('partner_sku') + variant_id = attrs.get('variant_id') seat = course.create_or_update_seat( certificate_type, @@ -781,6 +782,7 @@ def save(course, product, create_enrollment_code): credit_hours=credit_hours, create_enrollment_code=create_enrollment_code, sku=sku, + variant_id=variant_id ) # As a convenience to our caller, provide the SKU in the returned product serialization.