Comments (10)

ankane commented on August 30, 2024

Hi @lucasgadams, it returns a numpy.ndarray in the tests, but if you can create a failing test case, happy to look into it more.

avg = session.query(func.avg(Item.embedding)).first()[0]
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))

from pgvector-python.

lucasgadams commented on August 30, 2024

OK, I will try using query instead. I did find a way to get the right type by hacking around in the SQLAlchemy source code. In sqlalchemy/sql/functions.py I added:

class avg(ReturnTypeFromArgs[_T]):  # noqa: A001
    """The SQL AVG() aggregate function."""

    inherit_cache = True

and:

class _FunctionGenerator:
....
        @property
        def avg(self) -> Type[avg[Any]]:  # noqa: A001
            ...
            

lucasgadams commented on August 30, 2024

@ankane I feel like we might want to reopen this. The query API is kind of deprecated in favor of the select API, no? And in async mode, I don't think the query API is available at all. I'm fairly sure that test case would fail if you tried it with async or with the select API.

lucasgadams commented on August 30, 2024

The SQLAlchemy docs suggest wrapping a sync function to access the query API with an async session, like:

def query(session):
    return session.query(func.avg(models.DocumentChunkModel.embedding)).all()

And then invoking it with await async_session.run_sync(query). I tried that and still got a string back, so I'm not sure how it is supposed to work with async. Adding those specific function typings to sql/functions.py is the only thing that has worked so far.

ankane commented on August 30, 2024

Looks like it's only the case with asyncio, so may want to report it to SQLAlchemy (if it hasn't already been).

lucasgadams commented on August 30, 2024

I could try that.

For anyone looking for a workaround, just define the avg function directly and provide the proper typing information:

import numpy as np
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.functions import ReturnTypeFromArgs


# ReturnTypeFromArgs takes the return type from the argument (here the pgvector Vector type)
# The np.ndarray type allows proper return type hints when used in session
class avg(ReturnTypeFromArgs[np.ndarray]):
    inherit_cache = True
    # Allows accessing from func.vector.avg, otherwise you can just use this class directly
    # like select(avg(SomeModel.embedding))
    package = "vector"


async def get_avg_embedding(session: AsyncSession):
    # models.DocumentChunkModel is the application's own mapped class with a Vector column
    return await session.scalar(select(avg(models.DocumentChunkModel.embedding)))
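
For anyone who wants to check the type propagation without a database connection, here is a minimal sketch. It assumes SQLAlchemy is installed. DemoVector and the items table are hypothetical stand-ins (not pgvector's Vector or the models above), and the class is left unsubscripted so the sketch does not depend on NumPy. It only inspects the .type attribute of the two function expressions.

```python
from sqlalchemy import Column, Integer, MetaData, Table, func
from sqlalchemy.sql.functions import ReturnTypeFromArgs
from sqlalchemy.types import NullType, UserDefinedType


class DemoVector(UserDefinedType):
    """Hypothetical stand-in for a vector column type (assumption, not pgvector's)."""

    cache_ok = True

    def get_col_spec(self, **kw):
        return "VECTOR"


class avg(ReturnTypeFromArgs):
    # Registered under the "vector" package, so it is reached as func.vector.avg
    # and the plain func.avg lookup is left untouched.
    inherit_cache = True
    package = "vector"


# Hypothetical table used only to build expressions; nothing is executed.
items = Table(
    "items",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("embedding", DemoVector()),
)

generic = func.avg(items.c.embedding)       # unregistered name: no type information
typed = func.vector.avg(items.c.embedding)  # type taken from the argument

print(type(generic.type).__name__)
print(type(typed.type).__name__)
```

If the second expression carries the column's type, SQLAlchemy can apply that type's result processing to the aggregate, which is the behavior the workaround relies on.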

lucasgadams commented on August 30, 2024

Quick question @ankane: we are using asyncpg with SQLAlchemy, so our connection string uses drivername="postgresql+asyncpg". Do we need to initialize both SQLAlchemy and asyncpg with the pgvector types? So far we have only used the Vector type from the SQLAlchemy extension in a mapped column, and it has worked well up to this point. But now the avg function doesn't work correctly. Should we also run register_vector from the asyncpg extension?

lucasgadams commented on August 30, 2024

Aha, that looks to be the issue. We are using asyncpg as the driver for SQLAlchemy. When we query an embedding column directly, SQLAlchemy invokes the pgvector custom type, which converts the string to an np array. A function like sum should also work, I believe, because the built-in sum function in SQLAlchemy inherits from ReturnTypeFromArgs, so it takes the Vector type from its args and again uses the custom type to convert the result. However, avg is not registered as a built-in function (not sure why, but it is not in the source code; maybe it's not as universal?), so as we are currently using it, SQLAlchemy does not know what type the result should be.

You can fix that by defining an avg function class that inherits from ReturnTypeFromArgs. Alternatively, you could give the asyncpg driver the type codec information, so that it understands at a lower level what type it is supposed to be. You can do that with SQLAlchemy (ref) by:

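from sqlalchemy import event
from sqlalchemy.engine import AdaptedConnection
from sqlalchemy.ext.asyncio import create_async_engine
from pgvector.asyncpg import register_vector

            self._engine = create_async_engine(
                ....
            )

            # Register the pgvector codec on each new asyncpg connection
            @event.listens_for(self._engine.sync_engine, "connect")
            def register_custom_types(dbapi_connection: AdaptedConnection, *args):
                dbapi_connection.run_async(register_vector)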

I just tested that, and it worked even without the custom avg function class. My main question, @ankane: is it a good idea to use both the Vector type for SQLAlchemy and additionally configure the asyncpg driver? My main concern is that the SQLAlchemy type uses a string representation while the asyncpg configuration seems to use a binary representation.

lucasgadams commented on August 30, 2024

I believe there is currently an issue when using them both together. In terms of the call stack when querying and committing:

querying: calls from_db_binary (asyncpg) then from_db (sqlalchemy). This is fine because from_db checks whether the value is already an np array and, if so, just returns it.

committing: calls to_db (sqlalchemy) then to_db_binary (asyncpg). This breaks with the error ValueError: could not convert string to float: '[0.24479502...

So I do think there are things we can do in the pgvector library to make this work correctly. I don't think this is an error on the SQLAlchemy side.
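
The failure on the write path can be reproduced in isolation. The sketch below uses simplified, hypothetical stand-ins for the two encoders (plain Python lists and struct instead of NumPy; these are not pgvector's actual functions) to show why text encoding followed by binary encoding fails. The to_db_binary_tolerant guard is only one possible direction and is an assumption, not something the library is confirmed to do.

```python
import struct


def to_db_text(value):
    """Stand-in for the SQLAlchemy-level encoder: floats -> '[0.25,0.5]'."""
    if value is None:
        return value
    return "[" + ",".join(str(float(v)) for v in value) + "]"


def to_db_binary(value):
    """Stand-in for the driver-level encoder: floats -> packed big-endian bytes."""
    if value is None:
        return value
    # A string is treated as a single scalar, so float() is applied to the
    # whole text form and raises ValueError.
    items = [value] if isinstance(value, str) else list(value)
    floats = [float(v) for v in items]
    header = struct.pack(">HH", len(floats), 0)
    return header + struct.pack(f">{len(floats)}f", *floats)


def to_db_binary_tolerant(value):
    """Hypothetical guard: parse the text form back into floats before packing."""
    if isinstance(value, str):
        value = [float(v) for v in value.strip()[1:-1].split(",")]
    return to_db_binary(value)


# What the first encoder hands to the second when both layers are active.
text = to_db_text([0.25, 0.5])

try:
    to_db_binary(text)
except ValueError as exc:
    print("double encoding fails:", exc)

# With the guard, the text form and the raw list pack to the same bytes.
print(to_db_binary_tolerant(text) == to_db_binary([0.25, 0.5]))
```

The read path does not hit this problem because the second decoder passes an already-decoded value through unchanged. The guard gives the write path a comparable check.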
