Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""add credit_balances table

Revision ID: a1b2c3d4e5f6
Revises: 8ece21fbeb47
Create Date: 2025-07-14 00:00:00.000000

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import func


# revision identifiers, used by Alembic.
revision = 'a1b2c3d4e5f6'
down_revision = '8ece21fbeb47'
branch_labels = None
depends_on = None


def upgrade() -> None:
# Create credit_balances table
op.create_table(
'credit_balances',
sa.Column('id', sa.BigInteger(), autoincrement=True),
sa.Column('address', sa.String(), nullable=False),
sa.Column('amount', sa.BigInteger(), nullable=False),
sa.Column('ratio', sa.DECIMAL(), nullable=True),
sa.Column('tx_hash', sa.String(), nullable=True),
sa.Column('token', sa.String(), nullable=True),
sa.Column('chain', sa.String(), nullable=True),
sa.Column('provider', sa.String(), nullable=True),
sa.Column('origin', sa.String(), nullable=True),
sa.Column('payment_ref', sa.String(), nullable=True),
sa.Column('payment_method', sa.String(), nullable=True),
sa.Column('distribution_ref', sa.String(), nullable=False),
sa.Column('distribution_index', sa.Integer(), nullable=False),
sa.Column('expiration_date', sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column('last_update', sa.TIMESTAMP(timezone=True), nullable=False,
server_default=func.now(), onupdate=func.now()),
sa.PrimaryKeyConstraint('distribution_ref', 'distribution_index'),
)

# Create index on address for efficient lookups
op.create_index(op.f('ix_credit_balances_address'), 'credit_balances', ['address'], unique=False)

# Add unique constraint on tx_hash to prevent duplicate credit lines (when tx_hash is not null)
op.execute(
"""
ALTER TABLE credit_balances ADD CONSTRAINT credit_balances_tx_hash_uindex
UNIQUE (tx_hash)
"""
)


def downgrade() -> None:
# Drop the credit_balances table and its constraints
op.drop_index('ix_credit_balances_address', 'credit_balances')
op.drop_constraint('credit_balances_tx_hash_uindex', 'credit_balances')
op.drop_table('credit_balances')
8 changes: 8 additions & 0 deletions src/aleph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def get_defaults():
# POST message type for balance updates.
"post_type": "balances-update",
},
"credit_balances": {
# Addresses allowed to publish credit balance updates.
"addresses": [
"0x214061ffe7e365cA37956D091C807757B4d23427",
],
# POST message types for credit balance updates.
"post_types": ["aleph_credit_distribution", "aleph_credit_airdrop"],
},
"jobs": {
"pending_messages": {
# Maximum number of retries for a message.
Expand Down
202 changes: 200 additions & 2 deletions src/aleph/db/accessors/balances.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import datetime as dt
from decimal import Decimal
from io import StringIO
from typing import Dict, Mapping, Optional, Sequence
from typing import Any, Dict, Mapping, Optional, Sequence

from aleph_message.models import Chain
from sqlalchemy import func, select
from sqlalchemy.sql import Select

from aleph.db.models import AlephBalanceDb
from aleph.db.models import AlephBalanceDb, AlephCreditBalanceDb
from aleph.toolkit.timestamp import utc_now
from aleph.types.db_session import DbSession

Expand Down Expand Up @@ -187,3 +187,201 @@ def get_updated_balance_accounts(session: DbSession, last_update: dt.datetime):
.distinct()
)
return (session.execute(select_stmt)).scalars().all()


def get_credit_balance(session: DbSession, address: str) -> int:
now = utc_now()

# Sum all non-expired credit balances for the address
result = session.execute(
select(func.sum(AlephCreditBalanceDb.amount)).where(
(AlephCreditBalanceDb.address == address)
& (
(AlephCreditBalanceDb.expiration_date.is_(None))
| (AlephCreditBalanceDb.expiration_date > now)
)
)
).scalar()

return result if result is not None else 0


def get_credit_balances(
session: DbSession,
page: int = 1,
pagination: int = 100,
min_balance: int = 0,
**kwargs,
):
now = utc_now()

# Get aggregated non-expired credit balances by address
subquery = (
select(
AlephCreditBalanceDb.address,
func.sum(AlephCreditBalanceDb.amount).label("credits"),
)
.where(
(AlephCreditBalanceDb.expiration_date.is_(None))
| (AlephCreditBalanceDb.expiration_date > now)
)
.group_by(AlephCreditBalanceDb.address)
).subquery()

query = select(subquery.c.address, subquery.c.credits)

if min_balance > 0:
query = query.filter(subquery.c.credits >= min_balance)

query = query.offset((page - 1) * pagination)

if pagination:
query = query.limit(pagination)

return session.execute(query).all()


def count_credit_balances(session: DbSession, min_balance: int = 0, **kwargs):
from aleph.toolkit.timestamp import utc_now

now = utc_now()

# Count unique addresses with non-expired credit balances
subquery = (
select(AlephCreditBalanceDb.address)
.where(AlephCreditBalanceDb.expiration_date > now)
.group_by(AlephCreditBalanceDb.address)
)

if min_balance > 0:
subquery = subquery.having(func.sum(AlephCreditBalanceDb.amount) >= min_balance)

query = select(func.count()).select_from(subquery.subquery())

return session.execute(query).scalar_one()


def update_credit_balances(
session: DbSession,
credits_list: Sequence[Dict[str, Any]],
token: str,
chain: str,
message_hash: str,
) -> None:
"""
Updates multiple credit balances at the same time, efficiently.

Similar to update_balances, this uses a temporary table and bulk operations
for better performance.
"""

last_update = utc_now()

session.execute(
"CREATE TEMPORARY TABLE temp_credit_balances AS SELECT * FROM credit_balances WITH NO DATA" # type: ignore[arg-type]
)

conn = session.connection().connection
cursor = conn.cursor()

# Prepare an in-memory CSV file for use with the COPY operator
csv_rows = []
for index, credit_entry in enumerate(credits_list):
address = credit_entry["address"]
amount = int(credit_entry["amount"]) # Cast to integer
ratio = Decimal(credit_entry["ratio"])
tx_hash = credit_entry["tx_hash"]
provider = credit_entry["provider"]

# Extract optional fields from each credit entry
expiration_timestamp = credit_entry.get("expiration", "")
origin = credit_entry.get("origin", "")
payment_ref = credit_entry.get("ref", "")
payment_method = credit_entry.get("payment_method", "")

# Convert expiration timestamp to datetime

expiration_date = (
dt.datetime.fromtimestamp(expiration_timestamp / 1000, tz=dt.timezone.utc)
if expiration_timestamp != ""
else None
)

csv_rows.append(
f"{address};{amount};{ratio};{tx_hash};{expiration_date or ''};{token};{chain};{origin};{provider};{payment_ref};{payment_method};{message_hash};{index};{last_update}"
)

csv_credit_balances = StringIO("\n".join(csv_rows))
cursor.copy_expert(
"COPY temp_credit_balances(address, amount, ratio, tx_hash, expiration_date, token, chain, origin, provider, payment_ref, payment_method, distribution_ref, distribution_index, last_update) FROM STDIN WITH CSV DELIMITER ';'",
csv_credit_balances,
)
session.execute(
"""
INSERT INTO credit_balances(address, amount, ratio, tx_hash, expiration_date, token, chain, origin, provider, payment_ref, payment_method, distribution_ref, distribution_index, last_update)
(SELECT address, amount, ratio, tx_hash, expiration_date, token, chain,
NULLIF(origin, ''), provider, NULLIF(payment_ref, ''), NULLIF(payment_method, ''), distribution_ref, distribution_index, last_update FROM temp_credit_balances)
ON CONFLICT ON CONSTRAINT credit_balances_tx_hash_uindex DO NOTHING
""" # type: ignore[arg-type]
)

# Drop the temporary table
session.execute("DROP TABLE temp_credit_balances") # type: ignore[arg-type]


def update_credit_balances_airdrop(
session: DbSession,
credits_list: Sequence[Dict[str, Any]],
message_hash: str,
) -> None:
"""
Updates multiple credit balances from airdrop messages.

Similar to update_credit_balances, this uses a temporary table and bulk operations
for better performance. The airdrop schema doesn't include token/chain/ratio fields.
"""

last_update = utc_now()

session.execute(
"CREATE TEMPORARY TABLE temp_credit_balances AS SELECT * FROM credit_balances WITH NO DATA" # type: ignore[arg-type]
)

conn = session.connection().connection
cursor = conn.cursor()

# Prepare an in-memory CSV file for use with the COPY operator
csv_rows = []
for index, credit_entry in enumerate(credits_list):
address = credit_entry["address"]
amount = int(credit_entry["amount"]) # Cast to integer
origin = credit_entry.get("origin", "")
expiration_timestamp = credit_entry.get("expiration", 0)

# Convert expiration timestamp to datetime
expiration_date = (
dt.datetime.fromtimestamp(expiration_timestamp / 1000, tz=dt.timezone.utc)
if expiration_timestamp > 0
else None
)

csv_rows.append(
f"{address};{amount};;{expiration_date or ''};;;;{origin};;;;;{message_hash};{index};{last_update}"
)

csv_credit_balances = StringIO("\n".join(csv_rows))
cursor.copy_expert(
"COPY temp_credit_balances(address, amount, ratio, tx_hash, expiration_date, token, chain, origin, provider, payment_ref, payment_method, distribution_ref, distribution_index, last_update) FROM STDIN WITH CSV DELIMITER ';'",
csv_credit_balances,
)
session.execute(
"""
INSERT INTO credit_balances(address, amount, ratio, tx_hash, expiration_date, token, chain, origin, provider, payment_ref, payment_method, distribution_ref, distribution_index, last_update)
(SELECT address, amount, NULLIF(ratio, ''), NULLIF(tx_hash, ''), expiration_date, NULLIF(token, ''), NULLIF(chain, ''),
NULLIF(origin, ''), NULLIF(provider, ''), NULLIF(payment_ref, ''), NULLIF(payment_method, ''), distribution_ref, distribution_index, last_update FROM temp_credit_balances)
ON CONFLICT ON CONSTRAINT credit_balances_distribution_ref_index_uindex DO NOTHING
""" # type: ignore[arg-type]
)

# Drop the temporary table
session.execute("DROP TABLE temp_credit_balances") # type: ignore[arg-type]
32 changes: 32 additions & 0 deletions src/aleph/db/models/balances.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,35 @@ class AlephBalanceDb(Base):
"address", "chain", "dapp", name="balances_address_chain_dapp_uindex"
),
)


class AlephCreditBalanceDb(Base):
__tablename__ = "credit_balances"

id: int = Column(BigInteger, autoincrement=True)

address: str = Column(String, nullable=False, index=True)
amount: int = Column(BigInteger, nullable=False)
ratio: Optional[Decimal] = Column(DECIMAL, nullable=True)
tx_hash: Optional[str] = Column(String, nullable=True)
token: Optional[str] = Column(String, nullable=True)
chain: Optional[str] = Column(String, nullable=True)
provider: Optional[str] = Column(String, nullable=True)
origin: Optional[str] = Column(String, nullable=True)
payment_ref: Optional[str] = Column(String, nullable=True)
payment_method: Optional[str] = Column(String, nullable=True)
distribution_ref: str = Column(String, nullable=False, primary_key=True)
distribution_index: int = Column(Integer, nullable=False, primary_key=True)
expiration_date: Optional[dt.datetime] = Column(
TIMESTAMP(timezone=True), nullable=True
)
last_update: dt.datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=func.now(),
onupdate=func.now(),
)

__table_args__ = (
UniqueConstraint("tx_hash", name="credit_balances_tx_hash_uindex"),
)
Loading
Loading