Source code for graphtools.dbgraph

import gengraph as gg
from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey, func, create_engine, sql

def make_db(arrows_list, name=None):
    """
    Make a SQLite database in the correct format for :class:`DBGraph`.

    Parameters
    __________

    :param list arrows_list: the arrows of the graph as a list of lists of
        two ints, the first int representing the tail of the arrow and the
        second the head. An empty list yields empty tables.
    :param str name: the name of the database; if no name is given, the
        database will be in-memory-only

    Returns
    _______

    :returns: the users table, arrows table, database connection and engine
        as two `sqlalchemy Table`_ objects, a `sqlalchemy Connection`_
        object, and a `sqlalchemy Engine`_ object, respectively
    :rtype: tuple

    For example,

    >>> from graphtools.dbgraph import make_db
    >>> users, arrows, conn, eng = make_db([[1, 2], [2, 3]])
    >>> from sqlalchemy.sql import select
    >>> result = conn.execute(select([users]))
    >>> result.fetchall()
    [(1, 0), (2, 0), (3, 0)]
    >>> result = conn.execute(select([arrows]))
    >>> result.fetchall()
    [(1, 1, 2), (2, 2, 3)]

    .. _sqlalchemy Table: http://docs.sqlalchemy.org/en/rel_0_9/core/metadata.html#sqlalchemy.schema.Table
    .. _sqlalchemy Connection: http://docs.sqlalchemy.org/en/rel_0_9/core/connections.html#sqlalchemy.engine.Connection
    .. _sqlalchemy Engine: http://docs.sqlalchemy.org/en/rel_0_9/core/connections.html#sqlalchemy.engine.Engine
    """
    #make the engine and connection
    if name is None:
        engine = create_engine('sqlite://')
    else:
        engine = create_engine('sqlite:///{}.db'.format(name))
    conn = engine.connect()

    #setup the tables
    metadata = MetaData()
    users = Table(
        'users', metadata,
        Column('user_id', Integer, primary_key=True),
        Column('rank', Integer)
    )
    arrows = Table(
        'arrows', metadata,
        Column('id', Integer, primary_key=True),
        Column('follow_id', Integer, ForeignKey('users.user_id')),
        Column('lead_id', Integer, ForeignKey('users.user_id')),
        sqlite_autoincrement=True
    )

    #create the tables
    metadata.create_all(engine)

    #add the vertices: every endpoint of every arrow, each inserted once
    vertices = set()
    for arrow in arrows_list:
        vertices.add(arrow[0])
        vertices.add(arrow[1])
    user_values = [{'user_id': vertex, 'rank': 0} for vertex in vertices]
    # An empty parameter list is executed as a single parameterless INSERT,
    # which would add a spurious all-default row, so skip the insert when
    # there is nothing to add.
    if user_values:
        conn.execute(users.insert(), user_values)

    #add the arrows
    arrow_values = [{'follow_id': arrow[0], 'lead_id': arrow[1]}
                    for arrow in arrows_list]
    if arrow_values:
        conn.execute(arrows.insert(), arrow_values)

    return users, arrows, conn, engine


#make a test table
# Example database built when the module is imported (no name is passed, so
# make_db uses the in-memory 'sqlite://' engine); the get_tables doctest
# imports testeng from this module.
testusers, testarrows, testconn, testeng = make_db([[1, 2], [2, 3]])
def get_tables(engine, userstablename='users', arrowstablename='arrows'):
    """
    Look up the tables and open the connection that :class:`DBGraph` needs,
    starting from an existing `sqlalchemy Engine`_ object.

    The database schema is reflected from the engine and the two tables are
    then picked out of the reflected metadata by name.

    Parameters
    __________

    :param sqlalchemy.engine.Engine engine: the database engine
    :param str userstablename: the name of the users table
    :param str arrowstablename: the name of the arrows table

    Returns
    _______

    :returns: the users table, arrows table and database connection as two
        `sqlalchemy Table`_ objects and a `sqlalchemy Connection`_ object,
        respectively
    :rtype: tuple

    For example,

    >>> from graphtools.dbgraph import testeng, get_tables
    >>> users, arrows, conn = get_tables(testeng)
    >>> from sqlalchemy.sql import select
    >>> result = conn.execute(select([users]))
    >>> result.fetchall()
    [(1, 0), (2, 0), (3, 0)]
    >>> result = conn.execute(select([arrows]))
    >>> result.fetchall()
    [(1, 1, 2), (2, 2, 3)]

    .. _sqlalchemy Table: http://docs.sqlalchemy.org/en/rel_0_9/core/metadata.html#sqlalchemy.schema.Table
    .. _sqlalchemy Engine: http://docs.sqlalchemy.org/en/rel_0_9/core/connections.html#sqlalchemy.engine.Engine
    .. _sqlalchemy Connection: http://docs.sqlalchemy.org/en/rel_0_9/core/connections.html#sqlalchemy.engine.Connection
    """
    # Load the existing schema from the database behind the engine.
    reflected = MetaData()
    reflected.reflect(bind=engine)
    tables_by_name = reflected.tables
    # Tables are looked up before the connection is opened, users first.
    return (
        tables_by_name[userstablename],
        tables_by_name[arrowstablename],
        engine.connect(),
    )
class DBGraph(gg.GenGraph):
    r"""
    A subclass of GenGraph for graphs stored in databases.

    The vertices are stored in a table called **users** with columns

    * *user_id* (any type) and
    * *rank* (int)

    (the name comes from the original motivation which was Twitter user
    subgraphs). The table may also have a column *group* (any type)
    specifying the particular graph that the user belongs to if the database
    contains multiple graphs. If the table has no *group* column, *user_id*
    should be a unique identifier; if there is a *group* column, *user_id*
    and *group* together should be unique.

    The arrows are stored in a table called **arrows** with columns

    * *follow_id* and
    * *lead_id*

    both referring to **users**.\ *user_id*. If **users** has a *group*
    column then **arrows** should have a corresponding *group* column.

    Parameters
    __________

    :param sqlalchemy.schema.Table users: the table of vertices, described
        above
    :param sqlalchemy.schema.Table arrows: the table of arrows, described
        above
    :param sqlalchemy.engine.Connection conn: a connection to the database
    :param group: an optional identifier, described above

    The initializer sets **users**.\ *rank* to 0 for every vertex of the
    graph (only the rows of the given group, when a group is given).

    For example,

    >>> from graphtools.dbgraph import make_db, DBGraph
    >>> users, arrows, conn, eng = make_db([[1, 2], [2, 3]])
    >>> graph = DBGraph(users=users, arrows=arrows, conn=conn)
    >>> print graph.get_num_arrows()
    2
    >>> set(graph.get_vert_list()) == set([1, 2, 3])
    True
    >>> print graph.get_rank(1)
    0
    >>> graph.set_rank(3,2)
    >>> graph.get_rank(3)
    2
    >>> graph.reset_ranks()
    >>> graph.descend(2)
    >>> graph.descent(20)
    >>> hl = graph.hierarchy_list #get the list of hierarchy scores
    >>> print len(hl) #descend has been run 21 times, plus the initial score
    22
    >>> print hl[0] #the first score is always 0
    0
    >>> print hl[-1] #the score after 21 descends will probably be 1.0
    1.0
    """

    def __init__(self, users, arrows, conn, group = None):
        """Initialize the object and reset every rank of the graph to 0."""
        gg.GenGraph.__init__(self)
        self.users = users
        self.arrows = arrows
        self.conn = conn
        self.group = group
        if self.group is None:
            # No group to filter on: the queries below use True as their
            # group condition.
            self.ucheckgroup = True
            self.acheckgroup = True
        else:
            # Restrict every query to the rows of this graph's group.
            self.ucheckgroup = self.users.c.group == self.group
            self.acheckgroup = self.arrows.c.group == self.group
        self.reset_ranks()

    def reset_ranks(self):
        """Set *rank* to 0 for every vertex of this graph."""
        stmt = self.users.update().where(self.ucheckgroup).values(rank = 0)
        self.conn.execute(stmt)

    def get_num_arrows(self):
        """Return the number of rows of the arrows table in this graph."""
        countarrows = (
            sql.select([func.count()])
            .select_from(self.arrows)
            .where(self.acheckgroup)
        )
        result = self.conn.execute(countarrows)
        return result.fetchone()[0]

    def get_vert_list(self):
        """Return the *user_id* of every vertex of this graph as a list."""
        getuserids = sql.select([self.users.c.user_id]).where(self.ucheckgroup)
        results = self.conn.execute(getuserids)
        return [result[0] for result in results.fetchall()]

    def check_id(self, vert):
        """Return the SQL condition ``users.user_id == vert``."""
        return self.users.c.user_id == vert

    def get_rank(self, vert):
        """Return the *rank* stored for the vertex *vert* of this graph."""
        stmt = sql.select([self.users.c.rank]).where(
            (self.check_id(vert)) & (self.ucheckgroup)
        )
        result = self.conn.execute(stmt)
        return result.fetchone()[0]

    def set_rank(self, vert, newrank):
        """Store *newrank* as the *rank* of the vertex *vert* of this graph."""
        stmt = self.users.update().where(
            (self.check_id(vert)) & (self.ucheckgroup)
        ).values(rank = newrank)
        self.conn.execute(stmt)

    def count_neighbors(self, vert, out=True, cond=False, less=True, cutoff=0):
        """
        Count the arrows between *vert* and its neighbors.

        :param vert: the *user_id* of the vertex
        :param bool out: if true, count arrows whose *follow_id* is *vert*
            (the neighbor is the arrow's *lead_id*); otherwise count arrows
            whose *lead_id* is *vert* (the neighbor is the *follow_id*)
        :param bool cond: if true, only count neighbors whose rank passes
            the *cutoff* test selected by *less*
        :param bool less: with *cond*, keep neighbors with
            ``rank <= cutoff`` if true and ``rank >= cutoff`` otherwise
        :param int cutoff: the rank threshold used when *cond* is true
        :returns: the number of matching arrows
        :rtype: int
        """
        stmt = (
            sql.select([func.count()])
            .select_from(self.users)
            .select_from(self.arrows)
        )
        # Join each arrow touching vert to the users row of its other end.
        if out:
            stmt = stmt.where(self.arrows.c.follow_id == vert)
            stmt = stmt.where(self.users.c.user_id == self.arrows.c.lead_id)
        else:
            stmt = stmt.where(self.arrows.c.lead_id == vert)
            stmt = stmt.where(self.users.c.user_id == self.arrows.c.follow_id)
        if self.group is not None:
            # user_id is only unique within a group, so without these
            # filters arrows and users of other graphs stored in the same
            # tables would be counted as well.
            stmt = stmt.where(self.ucheckgroup).where(self.acheckgroup)
        if cond:
            if less:
                stmt = stmt.where(self.users.c.rank <= cutoff)
            else:
                stmt = stmt.where(self.users.c.rank >= cutoff)
        result = self.conn.execute(stmt)
        return result.fetchone()[0]