Skip to content

Commit

Permalink
simplify get_query using class attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
sampsyo committed Sep 10, 2013
1 parent f70ddfb commit 4ee4169
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
35 changes: 18 additions & 17 deletions beets/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,9 @@
}
SQLITE_KEY_TYPE = 'INTEGER PRIMARY KEY'

# Default search fields for various granularities.
ARTIST_DEFAULT_FIELDS = ('artist',)
# Default search fields for each model.
ALBUM_DEFAULT_FIELDS = ('album', 'albumartist', 'genre')
ITEM_DEFAULT_FIELDS = ARTIST_DEFAULT_FIELDS + ALBUM_DEFAULT_FIELDS + \
('title', 'comments')
ITEM_DEFAULT_FIELDS = ALBUM_DEFAULT_FIELDS + ('artist', 'title', 'comments')

# Special path format key.
PF_KEY_DEFAULT = 'default'
Expand Down Expand Up @@ -350,6 +348,11 @@ class LibModel(FlexModel):
strings.
"""

_search_fields = ()
"""The fields that should be queried by default by unqualified query
terms.
"""

def __init__(self, lib=None, **values):
self._lib = lib
super(LibModel, self).__init__(**values)
Expand Down Expand Up @@ -424,6 +427,7 @@ class Item(LibModel):
_fields = ITEM_KEYS
_table = 'items'
_flex_table = 'item_attributes'
_search_fields = ITEM_DEFAULT_FIELDS

@classmethod
def from_path(cls, path):
Expand Down Expand Up @@ -610,7 +614,7 @@ class Query(object):
def clause(self):
"""Generate an SQLite expression implementing the query.
Return a clause string, a sequence of substitution values for
the clause, and a Query object representing the "remainder"
the clause, and a Query object representing the "remainder"
Returns (clause, subvals) where clause is a valid sqlite
WHERE clause implementing the query and subvals is a list of
items to be substituted for ?s in the clause.
Expand Down Expand Up @@ -1101,18 +1105,13 @@ def construct_query_part(query_part, default_fields, all_keys):
else:
return query_class(key.lower(), pattern, key in all_keys)

def get_query(val, album=False):
def get_query(val, model_cls):
"""Takes a value which may be None, a query string, a query string
list, or a Query object, and returns a suitable Query object. album
determines whether the query is to match items or albums.
list, or a Query object, and returns a suitable Query object.
`model_cls` is the subclass of LibModel indicating which entity this
is a query for (i.e., Album or Item) and is used to determine which
fields are searched.
"""
if album:
default_fields = ALBUM_DEFAULT_FIELDS
all_keys = ALBUM_KEYS
else:
default_fields = ITEM_DEFAULT_FIELDS
all_keys = ITEM_KEYS

# Convert a single string into a list of space-separated
# criteria.
if isinstance(val, basestring):
Expand All @@ -1121,7 +1120,8 @@ def get_query(val, album=False):
if val is None:
return TrueQuery()
elif isinstance(val, list) or isinstance(val, tuple):
return AndQuery.from_strings(val, default_fields, all_keys)
return AndQuery.from_strings(val, model_cls._search_fields,
model_cls._fields)
elif isinstance(val, Query):
return val
else:
Expand Down Expand Up @@ -1513,7 +1513,7 @@ def _fetch(self, model_cls, order_by, query):
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything).
"""
query = get_query(query, model_cls is Album)
query = get_query(query, model_cls)

where, subvals = query.clause()
with self.transaction() as tx:
Expand Down Expand Up @@ -1608,6 +1608,7 @@ class Album(LibModel):
_fields = ALBUM_KEYS
_table = 'albums'
_flex_table = 'album_attributes'
_search_fields = ALBUM_DEFAULT_FIELDS

def __setitem__(self, key, value):
"""Set the value of an album attribute."""
Expand Down
2 changes: 1 addition & 1 deletion beetsplug/mbsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _print_and_apply_changes(lib, item, old_data, move, pretend, write):
def mbsync_singletons(lib, query, move, pretend, write):
"""Synchronize matching singleton items.
"""
singletons_query = library.get_query(query, False)
singletons_query = library.get_query(query, library.Item)
singletons_query.subqueries.append(library.SingletonQuery(True))
for s in lib.items(singletons_query):
if not s.mb_trackid:
Expand Down

0 comments on commit 4ee4169

Please sign in to comment.