from typing import Union
import six
from django.utils.translation import gettext_lazy as _
from shuup.core.models import AnonymousContact, ProductCatalogDiscountedPrice, ShopProduct
from shuup.core.pricing import DiscountModule, PriceInfo
from shuup.discounts.utils import get_potential_discounts_for_product, get_price_expiration, index_shop_product_price
[docs]
class ProductDiscountModule(DiscountModule):
    """
    Discount module that applies ``shuup.discounts`` product discounts to prices.

    For every potential discount matching the product in the pricing context a
    candidate price is computed; the lowest candidate wins, but the result is
    never allowed to go below the shop product's minimum price.
    """

    identifier = "product_discounts"
    name = _("Product Discounts")

    def discount_price(self, context, product, price_info):
        """
        Return a new ``PriceInfo`` with the best matching product discount applied.

        :param context: Pricing context; ``context.shop`` is used to create
          prices and to look up the shop product.
        :param product: The product, or its primary key.
        :param price_info: The undiscounted price info. ``price_info.price`` is
          treated as the total for ``price_info.quantity`` items (per-item
          discount values are multiplied by the quantity below).
        :return: A new ``PriceInfo``. When any discount applies its ``price`` is
          the lowest candidate price, floored at ``minimum price * quantity``.
          Its ``expires_on`` is the earlier of the original expiration and the
          discount price expiration.
        """
        shop = context.shop
        quantity = price_info.quantity
        zero_price = shop.create_price(0)

        potential_discounts = get_potential_discounts_for_product(context, product).values_list(
            "discounted_price_value",
            "discount_amount_value",
            "discount_percentage",
        )

        discounted_prices = []
        for discounted_price_value, discount_amount_value, discount_percentage in potential_discounts:
            if discounted_price_value:
                # New per-item price: never above the current total, never negative.
                discounted_prices.append(
                    min(
                        price_info.price,
                        max(shop.create_price(discounted_price_value) * quantity, zero_price),
                    )
                )
            if discount_amount_value:
                # Fixed amount taken off each item.
                discounted_prices.append(
                    max(price_info.price - shop.create_price(discount_amount_value) * quantity, zero_price)
                )
            if discount_percentage:
                # Fraction of the price taken off.
                discounted_prices.append(
                    max(price_info.price - price_info.price * discount_percentage, zero_price)
                )

        new_price_info = PriceInfo(
            price=price_info.price,
            base_price=price_info.base_price,
            quantity=quantity,
            expires_on=price_info.expires_on,
        )

        if discounted_prices:
            # The minimum price is stored per item while the candidate prices
            # are totals for ``quantity`` items, so scale it before comparing.
            minimum_total_price = self._get_minimum_price(shop, product) * quantity
            new_price_info.price = max(min(discounted_prices), minimum_total_price)

        price_expiration = get_price_expiration(context, product)
        if price_expiration and (not price_info.expires_on or price_expiration < price_info.expires_on):
            new_price_info.expires_on = price_expiration

        return new_price_info

    def _get_minimum_price(self, shop, product):
        """Return the per-item minimum price of the product in ``shop`` (zero when unset)."""
        product_id = product if isinstance(product, six.integer_types) else product.pk
        minimum_price_value = (
            ShopProduct.objects.filter(product_id=product_id, shop=shop)
            .values_list("minimum_price_value", flat=True)
            .first()
        )
        return shop.create_price(minimum_price_value or 0)

    def index_shop_product(self, shop_product: Union["ShopProduct", int], **kwargs):
        """
        Index the shop product discounts. This is a heavy procedure, use with precaution
        and through some background task.

        Variation parents are not indexed themselves; each of their child shop
        products in the same shop is indexed instead. For any other product the
        previously indexed discounted prices of this module are deleted and the
        prices are re-indexed per supplier for every contact group that has an
        active discount, plus the anonymous contact's default group.

        :param shop_product: The shop product, or its primary key.
        """
        if isinstance(shop_product, int):
            shop_product = ShopProduct.objects.select_related("product", "shop").get(pk=shop_product)

        if shop_product.product.is_variation_parent():
            # index the discounted price of all children shop products
            children_shop_products = ShopProduct.objects.select_related("product", "shop").filter(
                shop=shop_product.shop, product__variation_parent=shop_product.product
            )
            for child_shop_product in children_shop_products:
                self.index_shop_product(child_shop_product)
            return

        # Drop this module's previously indexed prices before re-indexing.
        ProductCatalogDiscountedPrice.objects.filter(
            catalog_rule__module_identifier=self.identifier,
            shop=shop_product.shop,
            product=shop_product.product,
        ).delete()

        # NOTE(review): local import, presumably to avoid a circular import — confirm.
        from shuup.discounts.models import Discount

        # get the different contact groups ids to index the prices;
        # ``order_by()`` clears any default ordering so ``distinct()`` really
        # deduplicates on the selected column only.
        discounts_groups_ids = list(
            Discount.objects.filter(
                shop=shop_product.shop,
                active=True,
                contact_group__isnull=False,
            )
            .order_by()
            .values_list("contact_group__id", flat=True)
            .distinct()
        )
        default_group_id = AnonymousContact.get_default_group().pk
        if default_group_id not in discounts_groups_ids:
            # Avoid indexing the default group twice when a discount targets it.
            discounts_groups_ids.append(default_group_id)

        for supplier in shop_product.suppliers.all().only("pk"):
            index_shop_product_price(shop_product, supplier, discounts_groups_ids)