dbtextdb.py

00001 #!/usr/bin/python
00002 #
00003 # Copyright 2008 Google Inc. All Rights Reserved.
00004 
00005 """SQL-like access layer for dbtext.
00006 
00007 This module provides the glue for kamctl to interact with dbtext files
00008 using basic SQL syntax thus avoiding special case handling of dbtext.
00009 
00010 """
00011 
00012 __author__ = 'herman@google.com (Herman Sheremetyev)'
00013 
00014 import fcntl
00015 import os
00016 import shutil
00017 import sys
00018 import tempfile
00019 import time
00020 
# Debug output is controlled by the DBTEXTDB_DEBUG environment variable.
# The raw string value is kept, so any non-empty value (even "0") turns
# debugging on; when the variable is absent DEBUG is the integer 0.
DEBUG = os.environ.get('DBTEXTDB_DEBUG', 0)
00025 
00026 
def Debug(msg):
  """Print msg on its own line when debugging is enabled.

  Debugging is enabled when the module-level DEBUG value is truthy.

  Args:
    msg: anything printable (strings, lists, dicts, ...).
  """
  if DEBUG:
    # print() with a single argument prints the same thing on Python 2 and
    # Python 3, unlike the Python 2-only print statement.
    print(msg)
00031 
00032 
00033 class DBText(object):
00034   """Provides connection to a dbtext database."""
00035 
00036   RESERVED_WORDS = ['SELECT', 'DELETE', 'UPDATE', 'INSERT', 'SET',
00037                     'VALUES', 'INTO', 'FROM', 'ORDER', 'BY', 'WHERE',
00038                     'COUNT', 'CONCAT', 'AND', 'AS']
00039   ALL_COMMANDS = ['SELECT', 'DELETE', 'UPDATE', 'INSERT']
00040   WHERE_COMMANDS = ['SELECT', 'DELETE', 'UPDATE']
00041 
00042   def __init__(self, location):
00043     self.location = location  # location of dbtext tables
00044     self.tokens = []          # query broken up into tokens
00045     self.conditions = {}      # args to the WHERE clause
00046     self.columns = []         # columns requested by SELECT
00047     self.table = ''           # name of the table being queried
00048     self.header = {}          # table header
00049     self.orig_data = []       # original table data used to diff after updates
00050     self.data = []            # table data as a list of dicts
00051     self.count = False        # where or not using COUNT()
00052     self.aliases = {}         # column aliases (SELECT AS)
00053     self.targets = {}         # target columns-value pairs for INSERT/UPDATE
00054     self.args = ''            # query arguments preceeding the ;
00055     self.command = ''         # which command are we executing
00056     self.strings = []         # list of string literals parsed from the query
00057     self.parens = []          # list of parentheses parsed from the query
00058     self._str_placeholder = '__DBTEXTDB_PARSED_OUT_STRING__'
00059     self._paren_placeholder = '__DBTEXTDB_PARSED_OUT_PARENS__'
00060     if not os.path.isdir(location):
00061       raise ParseError(location + ' is not a directory')
00062 
00063   def _ParseOrderBy(self):
00064     """Parse out the column name to be used for ordering the dataset.
00065 
00066     Raises:
00067       ParseError: Invalid ORDER BY clause
00068     """
00069     self.order_by = ''
00070     if 'ORDER' in self.tokens:
00071       order_index = self.tokens.index('ORDER')
00072       if order_index != len(self.tokens) - 3:
00073         raise ParseError('ORDER must be followed with BY and column name')
00074       if self.tokens[order_index + 1] != 'BY':
00075         raise ParseError('ORDER must be followed with BY')
00076       self.order_by = self.tokens[order_index + 2]
00077 
00078       # strip off the order by stuff
00079       self.tokens.pop()  # column name
00080       self.tokens.pop()  # BY
00081       self.tokens.pop()  # ORDER
00082 
00083     elif 'BY' in self.tokens:
00084       raise ParseError('BY must be preceeded by ORDER')
00085 
00086     Debug('Order by: ' + self.order_by)
00087 
00088   def _ParseConditions(self):
00089     """Parse out WHERE clause.
00090 
00091     Take everything after the WHERE keyword and convert it to a dict of
00092     name value pairs corresponding to the columns and their values that
00093     should be matched.
00094 
00095     Raises:
00096       ParseError: Invalid WHERE clause
00097       NotSupportedError: Unsupported syntax
00098     """
00099     self.conditions = {}
00100     Debug('self.tokens = %s' % self.tokens)
00101     if 'WHERE' not in self.tokens:
00102       return
00103 
00104     if self.command not in self.WHERE_COMMANDS:
00105       raise ParseError(self.command + ' cannot have a WHERE clause')
00106     if 'OR' in self.tokens:
00107       raise NotSupportedError('WHERE clause does not support OR operator')
00108 
00109     where_clause = self.tokens[self.tokens.index('WHERE') + 1:]
00110     self.conditions = self._ParsePairs(' '.join(where_clause), 'AND')
00111     for cond in self.conditions:
00112       self.conditions[cond] = self._EscapeChars(self.conditions[cond])
00113     Debug('Conditions are [%s]' % self.conditions)
00114 
00115     # pop off where clause
00116     a = self.tokens.pop()
00117     while a != 'WHERE':
00118       a = self.tokens.pop()
00119 
00120     Debug('self.tokens: %s' % self.tokens)
00121 
00122   def _ParseColumns(self):
00123     """Parse out the columns that need to be selected.
00124 
00125     Raises:
00126       ParseError: Invalid SELECT syntax
00127     """
00128     self.columns = []
00129     self.count = False
00130     self.aliases = {}
00131     col_end = 0
00132     # this is only valid for SELECT
00133     if self.command != 'SELECT':
00134       return
00135 
00136     if 'FROM' not in self.tokens:
00137       raise ParseError('SELECT must be followed by FROM')
00138 
00139     col_end = self.tokens.index('FROM')
00140     if not col_end:  # col_end == 0
00141       raise ParseError('SELECT must be followed by column name[s]')
00142 
00143     cols_str = ' '.join(self.tokens[0:col_end])
00144     # check if there is a function modifier on the columns
00145     if self.tokens[0] == 'COUNT':
00146       self.count = True
00147       if col_end == 1:
00148         raise ParseError('COUNT must be followed by column name[s]')
00149       if not self.tokens[1].startswith(self._paren_placeholder):
00150         raise ParseError('COUNT must be followed by ()')
00151       cols_str = self._ReplaceParens(self.tokens[1])
00152 
00153     cols = cols_str.split(',')
00154     for col in cols:
00155       if not col.strip():
00156         raise ParseError('Extra comma in columns')
00157       col_split = col.split()
00158       if col_split[0] == 'CONCAT':
00159         # found a concat statement, do the same overall steps for those cols
00160         self._ParseColumnsConcatHelper(col_split)
00161       else:
00162         col_split = col.split()
00163         if len(col_split) > 2 and col_split[1] != 'AS':
00164           raise ParseError('multiple columns must be separated by a comma')
00165         elif len(col_split) == 3:
00166           if col_split[1] != 'AS':
00167             raise ParseError('Invalid column alias, use AS')
00168           my_key = self._ReplaceStringLiterals(col_split[2], quotes=True)
00169           my_val = self._ReplaceStringLiterals(col_split[0], quotes=True)
00170           self.aliases[my_key] = [my_val]
00171           self.columns.append(my_key)
00172         elif len(col_split) > 3:
00173           raise ParseError('multiple columns must be separated by a comma')
00174         elif len(col_split) == 2:  # alias
00175           my_key = self._ReplaceStringLiterals(col_split[1], quotes=True)
00176           my_val = self._ReplaceStringLiterals(col_split[0], quotes=True)
00177           self.aliases[my_key] = [my_val]
00178           self.columns.append(my_key)
00179         else:
00180           col = self._ReplaceStringLiterals(col, quotes=True).strip()
00181           if not col:  # col == ''
00182             raise ParseError('empty column name not allowed')
00183 
00184           self.columns.append(col)
00185 
00186     # pop off all the columns related junk
00187     self.tokens = self.tokens[col_end + 1:]
00188 
00189     Debug('Columns: %s' % self.columns)
00190     Debug('Aliases: %s' % self.aliases)
00191     Debug('self.tokens: %s' % self.tokens)
00192 
00193   def _ParseColumnsConcatHelper(self, col_split):
00194     """Handles the columns being CONCAT'd together.
00195 
00196     Args:
00197       col_split: ['column', 'column']
00198 
00199     Raises:
00200       ParseError: invalid CONCAT()
00201     """
00202     concat_placeholder = '_'
00203     split_len = len(col_split)
00204     if split_len == 1:
00205       raise ParseError('CONCAT() must be followed by column name[s]')
00206     if not col_split[1].startswith(self._paren_placeholder):
00207       raise ParseError('CONCAT must be followed by ()')
00208     if split_len > 2:
00209       if split_len == 4 and col_split[2] != 'AS':
00210         raise ParseError('CONCAT() must be followed by an AS clause')
00211       if split_len > 5:
00212         raise ParseError('CONCAT() AS clause takes exactly 1 arg. '
00213                          'Extra args: [%s]' % (col_split[4:]))
00214       else:
00215         concat_placeholder = self._ReplaceStringLiterals(col_split[-1],
00216                                                          quotes=True)
00217 
00218     # make sure this place hodler is unique
00219     while concat_placeholder in self.aliases:
00220       concat_placeholder += '_'
00221     concat_cols_str = self._ReplaceParens(col_split[1])
00222     concat_cols = concat_cols_str.split(',')
00223     concat_col_list = []
00224     for concat_col in concat_cols:
00225       if ' ' in concat_col.strip():
00226         raise ParseError('multiple columns must be separated by a'
00227                          ' comma inside CONCAT()')
00228       concat_col = self._ReplaceStringLiterals(concat_col, quotes=True).strip()
00229       if not concat_col:
00230         raise ParseError('Attempting to CONCAT empty set')
00231       concat_col_list.append(concat_col)
00232 
00233     self.aliases[concat_placeholder] = concat_col_list
00234     self.columns.append(concat_placeholder)
00235 
00236   def _ParseTable(self):
00237     """Parse out the table name (multiple table names not supported).
00238 
00239     Raises:
00240       ParseError: Unable to parse table name
00241     """
00242     table_name = ''
00243     if (not self.tokens or  # len == 0
00244         (self.tokens[0] in self.RESERVED_WORDS and
00245          self.tokens[0] not in ['FROM', 'INTO'])):
00246       raise ParseError('Missing table name')
00247 
00248     # SELECT
00249     if self.command == 'SELECT':
00250       table_name = self.tokens.pop(0)
00251 
00252     # INSERT
00253     elif self.command == 'INSERT':
00254       table_name = self.tokens.pop(0)
00255       if table_name == 'INTO':
00256         table_name = self.tokens.pop(0)
00257 
00258     # DELETE
00259     elif self.command == 'DELETE':
00260       if self.tokens[0] != 'FROM':
00261         raise ParseError('DELETE command must be followed by FROM')
00262 
00263       self.tokens.pop(0)  # FROM
00264       table_name = self.tokens.pop(0)
00265 
00266     # UPDATE
00267     elif self.command == 'UPDATE':
00268       table_name = self.tokens.pop(0)
00269 
00270     if not self.table:
00271       self.table = table_name
00272 
00273     else:  # multiple queries detected, make sure they're against same table
00274       if self.table != table_name:
00275         raise ParseError('Table changed between queries! %s -> %s' %
00276                          (self.table, table_name))
00277     Debug('Table is [%s]' % self.table)
00278     Debug('self.tokens is %s' % self.tokens)
00279 
00280   def _ParseTargets(self):
00281     """Parse out name value pairs of columns and their values.
00282 
00283     Raises:
00284       ParseError: Unable to parse targets
00285     """
00286     self.targets = {}
00287     # UPDATE
00288     if self.command == 'UPDATE':
00289       if self.tokens.pop(0) != 'SET':
00290         raise ParseError('UPDATE command must be followed by SET')
00291 
00292       self.targets = self._ParsePairs(' '.join(self.tokens), ',')
00293 
00294     # INSERT
00295     if self.command == 'INSERT':
00296       if self.tokens[0] == 'SET':
00297         self.targets = self._ParsePairs(' '.join(self.tokens[1:]), ',')
00298 
00299       elif len(self.tokens) == 3 and self.tokens[1] == 'VALUES':
00300         if not self.tokens[0].startswith(self._paren_placeholder):
00301           raise ParseError('INSERT column names must be inside parens')
00302         if not self.tokens[2].startswith(self._paren_placeholder):
00303           raise ParseError('INSERT values must be inside parens')
00304 
00305         cols = self._ReplaceParens(self.tokens[0]).split(',')
00306         vals = self._ReplaceParens(self.tokens[2]).split(',')
00307 
00308         if len(cols) != len(vals):
00309           raise ParseError('INSERT column and value numbers must match')
00310         if not cols:  # len == 0
00311           raise ParseError('INSERT column number must be greater than 0')
00312 
00313         i = 0
00314         while i < len(cols):
00315           val = vals[i].strip()
00316           if not val:  # val == ''
00317             raise ParseError('INSERT values cannot be empty')
00318           if ' ' in val:
00319             raise ParseError('INSERT values must be comma separated')
00320           self.targets[cols[i].strip()] = self._ReplaceStringLiterals(val)
00321           i += 1
00322 
00323       else:
00324         raise ParseError('Unable to parse INSERT targets')
00325 
00326     for target in self.targets:
00327       self.targets[target] = self._EscapeChars(self.targets[target])
00328 
00329     Debug('Targets are [%s]' % self.targets)
00330 
00331   def _EscapeChars(self, value):
00332     """Escape necessary chars before inserting into dbtext.
00333 
00334     Args:
00335       value: 'string'
00336 
00337     Returns:
00338       escaped: 'string' with chars escaped appropriately
00339     """
00340     # test that the value is string, if not return it as is
00341     try:
00342       value.find('a')
00343     except:
00344       return value
00345 
00346     escaped = value
00347     escaped = escaped.replace('\\', '\\\\').replace('\0', '\\0')
00348     escaped = escaped.replace(':', '\\:').replace('\n', '\\n')
00349     escaped = escaped.replace('\r', '\\r').replace('\t', '\\t')
00350     return escaped
00351 
00352   def _UnEscapeChars(self, value):
00353     """Un-escape necessary chars before returning to user.
00354 
00355     Args:
00356       value: 'string'
00357 
00358     Returns:
00359       escaped: 'string' with chars escaped appropriately
00360     """
00361     # test that the value is string, if not return it as is
00362     try:
00363       value.find('a')
00364     except:
00365       return value
00366 
00367     escaped = value
00368     escaped = escaped.replace('\\:', ':').replace('\\n', '\n')
00369     escaped = escaped.replace('\\r', '\r').replace('\\t', '\t')
00370     escaped = escaped.replace('\\0', '\0').replace('\\\\', '\\')
00371     return escaped
00372 
00373   def Execute(self, query, writethru=True):
00374     """Parse and execute the query.
00375 
00376     Args:
00377       query: e.g. 'select * from table;'
00378       writethru: bool
00379 
00380     Returns:
00381       dataset: [{col: val, col: val}, {col: val}, {col: val}]
00382 
00383     Raises:
00384       ExecuteError: unable to execute query
00385     """
00386     # parse the query
00387     self.ParseQuery(query)
00388 
00389     # get lock and execute the query
00390     self.OpenTable()
00391     Debug('Running ' + self.command)
00392     dataset = []
00393     if self.command == 'SELECT':
00394       dataset = self._RunSelect()
00395     elif self.command == 'UPDATE':
00396       dataset = self._RunUpdate()
00397     elif self.command == 'INSERT':
00398       dataset = self._RunInsert()
00399     elif self.command == 'DELETE':
00400       dataset = self._RunDelete()
00401 
00402     if self.command != 'SELECT' and writethru:
00403       self.WriteTempTable()
00404       self.MoveTableIntoPlace()
00405 
00406     Debug(dataset)
00407     return dataset
00408 
00409   def CleanUp(self):
00410     """Reset the internal variables (for multiple queries)."""
00411     self.tokens = []          # query broken up into tokens
00412     self.conditions = {}      # args to the WHERE clause
00413     self.columns = []         # columns requested by SELECT
00414     self.table = ''           # name of the table being queried
00415     self.header = {}          # table header
00416     self.orig_data = []       # original table data used to diff after updates
00417     self.data = []            # table data as a list of dicts
00418     self.count = False        # where or not using COUNT()
00419     self.aliases = {}         # column aliases (SELECT AS)
00420     self.targets = {}         # target columns-value pairs for INSERT/UPDATE
00421     self.args = ''            # query arguments preceeding the ;
00422     self.command = ''         # which command are we executing
00423     self.strings = []         # list of string literals parsed from the query
00424     self.parens = []          # list of parentheses parsed from the query
00425 
00426   def ParseQuery(self, query):
00427     """External wrapper for the query parsing routines.
00428 
00429     Args:
00430       query: string
00431 
00432     Raises:
00433       ParseError: Unable to parse query
00434     """
00435     self.args = query.split(';')[0]
00436     self._Tokenize()
00437     self._ParseCommand()
00438     self._ParseOrderBy()
00439     self._ParseConditions()
00440     self._ParseColumns()
00441     self._ParseTable()
00442     self._ParseTargets()
00443 
00444   def _ParseCommand(self):
00445     """Determine the command: SELECT, UPDATE, DELETE or INSERT.
00446 
00447     Raises:
00448       ParseError: unable to parse command
00449     """
00450     self.command = self.tokens[0]
00451     # Check that command is valid
00452     if self.command not in self.ALL_COMMANDS:
00453       raise ParseError('Unsupported command: ' + self.command)
00454 
00455     self.tokens.pop(0)
00456     Debug('Command is: %s' % self.command)
00457     Debug('self.tokens: %s' % self.tokens)
00458 
00459   def _Tokenize(self):
00460     """Turn the string query into a list of tokens.
00461 
00462     Split on '(', ')', ' ', ';', '=' and ','.
00463     In addition capitalize any SQL keywords found.
00464     """
00465     # horrible hack to handle now()
00466     time_now = '%s' % int(time.time())
00467     time_now = time_now[0:-2] + '00'  # round off the seconds for unittesting
00468     while 'now()' in self.args.lower():
00469       start = self.args.lower().find('now()')
00470       self.args = ('%s%s%s' % (self.args[0:start], time_now,
00471                                self.args[start + 5:]))
00472     # pad token separators with spaces
00473     pad = self.args.replace('(', ' ( ').replace(')', ' ) ')
00474     pad = pad.replace(',', ' , ').replace(';', ' ; ').replace('=', ' = ')
00475     self.args = pad
00476     # parse out all the blocks (string literals and parens)
00477     self._ParseOutBlocks()
00478     # split remaining into tokens
00479     self.tokens = self.args.split()
00480 
00481     # now capitalize
00482     i = 0
00483     while i < len(self.tokens):
00484       if self.tokens[i].upper() in self.RESERVED_WORDS:
00485         self.tokens[i] = self.tokens[i].upper()
00486 
00487       i += 1
00488 
00489     Debug('Tokens: %s' % self.tokens)
00490 
00491   def _ParseOutBlocks(self):
00492     """Parse out string literals and parenthesized values."""
00493     self.strings = []
00494     self.parens = []
00495 
00496     # set str placeholder to a value that's not present in the string
00497     while self._str_placeholder in self.args:
00498       self._str_placeholder = '%s_' % self._str_placeholder
00499 
00500     # set paren placeholder to a value that's not present in the string
00501     while self._paren_placeholder in self.args:
00502       self._paren_placeholder = '%s_' % self._paren_placeholder
00503 
00504     self.strings = self._ParseOutHelper(self._str_placeholder, ["'", '"'],
00505                                         'quotes')
00506     self.parens = self._ParseOutHelper(self._paren_placeholder, ['(', ')'],
00507                                        'parens')
00508     Debug('Strings: %s' % self.strings)
00509     Debug('Parens: %s' % self.parens)
00510 
00511   def _ParseOutHelper(self, placeholder, delims, mode):
00512     """Replace all text within delims with placeholders.
00513 
00514     Args:
00515       placeholder: string
00516       delims: list of strings
00517       mode: string
00518           'parens': if there are 2 delims treat the first as opening
00519                     and second as closing, such as with ( and )
00520           'quotes': treat each delim as either opening or
00521                     closing and require the same one to terminate the block,
00522                     such as with ' and "
00523 
00524     Returns:
00525       list: [value1, value2, ...]
00526 
00527     Raises:
00528       ParseError: unable to parse out delims
00529       ExecuteError: Invalid usage
00530     """
00531     if mode not in ['quotes', 'parens']:
00532       raise ExecuteError('_ParseOutHelper: invalid mode ' + mode)
00533     if mode == 'parens' and len(delims) != 2:
00534       raise ExecuteError('_ParseOutHelper: delims must have 2 values '
00535                          'in "parens" mode')
00536     values = []
00537     started = 0
00538     new_args = ''
00539     string = ''
00540     my_id = 0
00541     delim = ''
00542     for c in self.args:
00543       if c in delims:
00544         if not started:
00545           if mode == 'parens' and c != delims[0]:
00546             raise ParseError('Found closing delimeter %s before '
00547                              'corresponding %s' % (c, delims[0]))
00548           started += 1
00549           delim = c
00550         else:
00551           if ((mode == 'parens' and c == delim) or
00552               (mode == 'quotes' and c != delim)):
00553             string = '%s%s' % (string, c)
00554             continue  # wait for matching delim
00555 
00556           started -= 1
00557           if not started:
00558             values.append(string)
00559             new_args = '%s %s' % (new_args, '%s%d' % (placeholder, my_id))
00560             my_id += 1
00561             string = ''
00562 
00563       else:
00564         if not started:
00565           new_args = '%s%s' % (new_args, c)
00566         else:
00567           string = '%s%s' % (string, c)
00568 
00569     if started:
00570       if mode == 'parens':
00571         waiting_for = delims[1]
00572       else:
00573         waiting_for = delim
00574       raise ParseError('Unterminated block, waiting for ' + waiting_for)
00575 
00576     self.args = new_args
00577     Debug('Values: %s' % values)
00578     return values
00579 
00580   def _ReplaceStringLiterals(self, s, quotes=False):
00581     """Replaces string placeholders with real values.
00582 
00583     If quotes is set to True surround the returned value with single quotes
00584 
00585     Args:
00586       s: string
00587       quotes: bool
00588 
00589     Returns:
00590       s: string
00591     """
00592     if s.strip().startswith(self._str_placeholder):
00593       str_index = int(s.split(self._str_placeholder)[1])
00594       s = self.strings[str_index]
00595       if quotes:
00596         s = "'" + s + "'"
00597 
00598     return s
00599 
00600   def _ReplaceParens(self, s):
00601     """Replaces paren placeholders with real values.
00602 
00603     Args:
00604       s: string
00605 
00606     Returns:
00607       s: string
00608     """
00609     if s.strip().startswith(self._paren_placeholder):
00610       str_index = int(s.split(self._paren_placeholder)[1])
00611       s = self.parens[str_index].strip()
00612 
00613     return s
00614 
00615   def _RunDelete(self):
00616     """Run the DELETE command.
00617 
00618     Go through the rows in self.data matching them
00619     against the conditions, if they fit delete the row leaving a placeholder
00620     value (in order to keep the iteration process sane).  Afterward clean up
00621     any empty values.
00622 
00623     Returns:
00624       dataset: [number of affected rows]
00625     """
00626     i = 0
00627     length = len(self.data)
00628     affected = 0
00629     while i < length:
00630       if self._MatchRow(self.data[i]):
00631         self.data[i] = None
00632         affected += 1
00633 
00634       i += 1
00635 
00636     # clean out the placeholders
00637     while None in self.data:
00638       self.data.remove(None)
00639 
00640     return [affected]
00641 
00642   def _RunUpdate(self):
00643     """Run the UPDATE command.
00644 
00645     Find the matching rows and update based on self.targets
00646 
00647     Returns:
00648       affected: [int]
00649     Raises:
00650       ExecuteError: failed to run UPDATE
00651     """
00652     i = 0
00653     length = len(self.data)
00654     affected = 0
00655     while i < length:
00656       if self._MatchRow(self.data[i]):
00657         for target in self.targets:
00658           if target not in self.header:
00659             raise ExecuteError(target + ' is an invalid column name')
00660           if self.header[target]['auto']:
00661             raise ExecuteError(target + ' is type auto and connot be updated')
00662 
00663           self.data[i][target] = self._TypeCheck(self.targets[target], target)
00664         affected += 1
00665 
00666       i += 1
00667 
00668     return [affected]
00669 
00670   def _RunInsert(self):
00671     """Run the INSERT command.
00672 
00673     Build up the row based on self.targets and table defaults, then append to
00674     self.data
00675 
00676     Returns:
00677       affected: [int]
00678     Raises:
00679       ExecuteError: failed to run INSERT
00680     """
00681     new_row = {}
00682     cols = self._SortHeaderColumns()
00683     for col in cols:
00684       if col in self.targets:
00685         if self.header[col]['auto']:
00686           raise ExecuteError(col + ' is type auto: cannot be modified')
00687         new_row[col] = self.targets[col]
00688 
00689       elif self.header[col]['null']:
00690         new_row[col] = ''
00691 
00692       elif self.header[col]['auto']:
00693         new_row[col] = self._GetNextAuto(col)
00694 
00695       else:
00696         raise ExecuteError(col + ' cannot be empty or null')
00697 
00698     self.data.append(new_row)
00699     return [1]
00700 
00701   def _GetNextAuto(self, col):
00702     """Figure out the next value for col based on existing values.
00703 
00704     Scan all the current values and return the highest one + 1.
00705 
00706     Args:
00707       col: string
00708 
00709     Returns:
00710       next: int
00711 
00712     Raises:
00713       ExecuteError: Failed to get auto inc
00714     """
00715     highest = 0
00716     seen = []
00717     for row in self.data:
00718       if row[col] > highest:
00719         highest = row[col]
00720 
00721       if row[col] not in seen:
00722         seen.append(row[col])
00723       else:
00724         raise ExecuteError('duplicate value %s in %s' % (row[col], col))
00725 
00726     return highest + 1
00727 
00728   def _RunSelect(self):
00729     """Run the SELECT command.
00730 
00731     Returns:
00732       dataset: []
00733 
00734     Raises:
00735       ExecuteError: failed to run SELECT
00736     """
00737     dataset = []
00738     if ['*'] == self.columns:
00739       self.columns = self._SortHeaderColumns()
00740 
00741     for row in self.data:
00742       if self._MatchRow(row):
00743         match = []
00744         for col in self.columns:
00745           if col in self.aliases:
00746             concat = ''
00747             for concat_col in self.aliases[col]:
00748               if concat_col.startswith("'") and concat_col.endswith("'"):
00749                 concat += concat_col.strip("'")
00750               elif concat_col not in self.header.keys():
00751                 raise ExecuteError('Table %s does not have a column %s' %
00752                                    (self.table, concat_col))
00753               else:
00754                 concat = '%s%s' % (concat, row[concat_col])
00755 
00756             if not concat.strip():
00757               raise ExecuteError('Empty CONCAT statement')
00758 
00759             my_match = concat
00760 
00761           elif col.startswith("'") and col.endswith("'"):
00762             my_match = col.strip("'")
00763           elif col not in self.header.keys():
00764             raise ExecuteError('Table %s does not have a column %s' %
00765                                (self.table, col))
00766           else:
00767             my_match = row[col]
00768 
00769           match.append(self._UnEscapeChars(my_match))
00770 
00771         dataset.append(match)
00772 
00773     if self.count:
00774       Debug('Dataset: %s' % dataset)
00775       dataset = [len(dataset)]
00776 
00777     if self.order_by:
00778       if self.order_by not in self.header.keys():
00779         raise ExecuteError('Unknown column %s in ORDER BY clause' %
00780                            self.order_by)
00781       pos = self._PositionByCol(self.order_by)
00782       dataset = self._SortMatrixByCol(dataset, pos)
00783 
00784     return dataset
00785 
00786   def _SortMatrixByCol(self, dataset, pos):
00787     """Sorts the matrix (array or arrays) based on a given column value.
00788 
00789     That is, if given matrix that looks like:
00790 
00791     [[1, 2, 3], [6, 5, 4], [3, 2, 1]]
00792 
00793     given pos = 0 produce:
00794 
00795     [[1, 2, 3], [3, 2, 1], [6, 5, 4]]
00796 
00797     given pos = 1 produce:
00798 
00799     [[1, 2, 3], [3, 2, 1], [6, 5, 4]]
00800 
00801     given pos = 2 produce:
00802 
00803     [[3, 2, 1], [1, 2, 3], [6, 5, 4]]
00804 
00805     Works for both integer and string values of column.
00806 
00807     Args:
00808       dataset: [[], [], ...]
00809       pos: int
00810 
00811     Returns:
00812       sorted: [[], [], ...]
00813     """
00814     # prepend value in pos to the beginning of every row
00815     i = 0
00816     while i < len(dataset):
00817       dataset[i].insert(0, dataset[i][pos])
00818       i += 1
00819 
00820     # sort the matrix, which is done on the row we just prepended
00821     dataset.sort()
00822 
00823     # strip away the first value
00824     i = 0
00825     while i < len(dataset):
00826       dataset[i].pop(0)
00827       i += 1
00828 
00829     return dataset
00830 
00831   def _MatchRow(self, row):
00832     """Matches the row against self.conditions.
00833 
00834     Args:
00835       row: ['val', 'val']
00836 
00837     Returns:
00838       Bool
00839     """
00840     match = True
00841     # when there are no conditions we match everything
00842     if not self.conditions:
00843       return match
00844 
00845     for condition in self.conditions:
00846       cond_val = self.conditions[condition]
00847       if condition not in self.header.keys():
00848         match = False
00849         break
00850       else:
00851         if cond_val != row[condition]:
00852           match = False
00853           break
00854 
00855     return match
00856 
00857   def _ProcessHeader(self):
00858     """Parse out the header information.
00859 
00860     Returns:
00861       {col_name: {'type': string, 'null': string, 'auto': string, 'pos': int}}
00862     """
00863     header = self.fd.readline().strip()
00864     cols = {}
00865     pos = 0
00866     for col in header.split():
00867       col_name = col.split('(')[0]
00868       col_type = col.split('(')[1].split(')')[0].split(',')[0]
00869       col_null = False
00870       col_auto = False
00871       if ',' in col.split('(')[1].split(')')[0]:
00872         if col.split('(')[1].split(')')[0].split(',')[1].lower() == 'null':
00873           col_null = True
00874         if col.split('(')[1].split(')')[0].split(',')[1].lower() == 'auto':
00875           col_auto = True
00876 
00877       cols[col_name] = {}
00878       cols[col_name]['type'] = col_type
00879       cols[col_name]['null'] = col_null
00880       cols[col_name]['auto'] = col_auto
00881       cols[col_name]['pos'] = pos
00882       pos += 1
00883 
00884     return cols
00885 
00886   def _GetData(self):
00887     """Reads table data into memory as a list of dicts keyed on column names.
00888 
00889     Returns:
00890       data: [{row}, {row}, ...]
00891     Raises:
00892       ExecuteError: failed to get data
00893     """
00894     data = []
00895     row_num = 0
00896     for row in self.fd:
00897       row = row.rstrip('\n')
00898       row_dict = {}
00899       i = 0
00900       field_start = 0
00901       field_num = 0
00902       while i < len(row):
00903         if row[i] == ':':
00904           # the following block is executed again after the while is done
00905           val = row[field_start:i]
00906           col = self._ColByPosition(field_num)
00907           val = self._TypeCheck(val, col)
00908           row_dict[col] = val
00909 
00910           field_start = i + 1  # skip the colon itself
00911           field_num += 1
00912         if row[i] == '\\':
00913           i += 2  # skip the next char since it's escaped
00914         else:
00915           i += 1
00916 
00917       # handle the last field since we won't hit a : at the end
00918       # sucks to duplicate the code outside the loop but I can't think
00919       # of a better way :(
00920 
00921       val = row[field_start:i]
00922       col = self._ColByPosition(field_num)
00923       val = self._TypeCheck(val, col)
00924       row_dict[col] = val
00925 
00926       # verify that all columns were created
00927       for col in self.header:
00928         if col not in row_dict:
00929           raise ExecuteError('%s is missing from row %d in %s' %
00930                              (col, row_num, self.table))
00931 
00932       row_num += 1
00933       data.append(row_dict)
00934 
00935     return data
00936 
00937   def _TypeCheck(self, val, col):
00938     """Verify type of val based on the header.
00939 
00940     Make sure the value is returned in quotes if it's a string
00941     and as '' when it's empty and Null
00942 
00943     Args:
00944       val: string
00945       col: string
00946 
00947     Returns:
00948       val: string
00949 
00950     Raises:
00951       ExecuteError: invalid value or column
00952     """
00953     if not val and not self.header[col]['null']:
00954       raise ExecuteError(col + ' cannot be empty or null')
00955 
00956     if (self.header[col]['type'].lower() == 'int' or
00957         self.header[col]['type'].lower() == 'double'):
00958       try:
00959         if val:
00960           val = eval(val)
00961       except NameError, e:
00962         raise ExecuteError('Failed to parse %s in %s '
00963                            '(unable to convert to type %s): %s' %
00964                            (col, self.table, self.header[col]['type'], e))
00965       except SyntaxError, e:
00966         raise ExecuteError('Failed to parse %s in %s '
00967                            '(unable to convert to type %s): %s' %
00968                            (col, self.table, self.header[col]['type'], e))
00969 
00970     return val
00971 
00972   def _ColByPosition(self, pos):
00973     """Returns column name based on position.
00974 
00975     Args:
00976       pos: int
00977 
00978     Returns:
00979       column: string
00980 
00981     Raises:
00982       ExecuteError: invalid column
00983     """
00984     for col in self.header:
00985       if self.header[col]['pos'] == pos:
00986         return col
00987 
00988     raise ExecuteError('Header does not contain column %d' % pos)
00989 
00990   def _PositionByCol(self, col):
00991     """Returns position of the column based on the name.
00992 
00993     Args:
00994       col: string
00995 
00996     Returns:
00997       pos: int
00998 
00999     Raises:
01000       ExecuteError: invalid column
01001     """
01002     if col not in self.header.keys():
01003       raise ExecuteError(col + ' is not a valid column name')
01004 
01005     return self.header[col]['pos']
01006 
01007   def _SortHeaderColumns(self):
01008     """Sort column names by position.
01009 
01010     Returns:
01011       sorted: [col1, col2, ...]
01012 
01013     Raises:
01014       ExecuteError: unable to sort header
01015     """
01016     cols = self.header.keys()
01017     sorted_cols = [''] * len(cols)
01018     for col in cols:
01019       pos = self.header[col]['pos']
01020       sorted_cols[pos] = col
01021 
01022     if '' in sorted_cols:
01023       raise ExecuteError('Unable to sort header columns: %s' % cols)
01024 
01025     return sorted_cols
01026 
01027   def OpenTable(self):
01028     """Opens the table file and places its content into memory.
01029 
01030     Raises:
01031       ExecuteError: unable to open table
01032     """
01033     # if we already have a header assume multiple queries on same table
01034     # (can't use self.data in case the table was empty to begin with)
01035     if self.header:
01036       return
01037 
01038     try:
01039       self.fd = open(os.path.join(self.location, self.table), 'r')
01040       self.header = self._ProcessHeader()
01041 
01042       if self.command in ['INSERT', 'DELETE', 'UPDATE']:
01043         fcntl.flock(self.fd, fcntl.LOCK_EX)
01044 
01045       self.data = self._GetData()
01046       self.orig_data = self.data[:]  # save a copy of the data before modifying
01047 
01048     except IOError, e:
01049       raise ExecuteError('Unable to open table %s: %s' % (self.table, e))
01050 
01051     Debug('Header is: %s' % self.header)
01052 
01053     # type check the conditions
01054     for cond in self.conditions:
01055       if cond not in self.header.keys():
01056         raise ExecuteError('unknown column %s in WHERE clause' % cond)
01057       self.conditions[cond] = self._TypeCheck(self.conditions[cond], cond)
01058 
01059     # type check the targets
01060     for target in self.targets:
01061       if target not in self.header.keys():
01062         raise ExecuteError('unknown column in targets:  %s' % target)
01063       self.targets[target] = self._TypeCheck(self.targets[target], target)
01064 
01065     Debug('Type checked conditions: %s' % self.conditions)
01066 
01067     Debug('Data is:')
01068     for row in self.data:
01069       Debug('=======================')
01070       Debug(row)
01071     Debug('=======================')
01072 
01073   def WriteTempTable(self):
01074     """Write table header and data.
01075 
01076     First write header and data to a temp file,
01077     then move the tmp file to replace the original table file.
01078     """
01079     self.temp_file = tempfile.NamedTemporaryFile()
01080     Debug('temp_file: ' + self.temp_file.name)
01081     # write header
01082     columns = self._SortHeaderColumns()
01083     header = ''
01084     for col in columns:
01085       header = '%s %s' % (header, col)
01086       header = '%s(%s' % (header, self.header[col]['type'])
01087       if self.header[col]['null']:
01088         header = '%s,null)' % header
01089       elif self.header[col]['auto']:
01090         header = '%s,auto)' % header
01091       else:
01092         header = '%s)' % header
01093 
01094     self.temp_file.write(header.strip() + '\n')
01095 
01096     # write data
01097     for row in self.data:
01098       row_str = ''
01099       for col in columns:
01100         row_str = '%s:%s' % (row_str, row[col])
01101 
01102       self.temp_file.write(row_str[1:] + '\n')
01103 
01104     self.temp_file.flush()
01105 
01106   def MoveTableIntoPlace(self):
01107     """Replace the real table with the temp one.
01108 
01109     Diff the new data against the original and replace the table when they are
01110     different.
01111     """
01112     if self.data != self.orig_data:
01113       temp_file = self.temp_file.name
01114       table_file = os.path.join(self.location, self.table)
01115       Debug('Copying %s to %s' % (temp_file, table_file))
01116       shutil.copy(self.temp_file.name, self.location + '/' + self.table)
01117 
01118   def _ParsePairs(self, s, delimeter):
01119     """Parses out name value pairs from a string.
01120 
01121     String contains name=value pairs
01122     separated by a delimiter (such as "and" or ",")
01123 
01124     Args:
01125       s: string
01126       delimeter: string
01127 
01128     Returns:
01129       my_dict: dictionary
01130 
01131     Raises:
01132       ParseError: unable to parse pairs
01133     """
01134     my_dict = {}
01135     Debug('parse pairs: [%s]' % s)
01136     pairs = s.split(delimeter)
01137     for pair in pairs:
01138       if '=' not in pair:
01139         raise ParseError('Invalid condition pair: ' + pair)
01140 
01141       split = pair.split('=')
01142       Debug('split: %s' % split)
01143       if len(split) != 2:
01144         raise ParseError('Invalid condition pair: ' + pair)
01145 
01146       col = split[0].strip()
01147       if not col or not split[1].strip() or ' ' in col:
01148         raise ParseError('Invalid condition pair: ' + pair)
01149 
01150       val = self._ReplaceStringLiterals(split[1].strip())
01151       my_dict[col] = val
01152 
01153     return my_dict
01154 
01155 
class Error(Exception):
  """Base class of every dbtextdb error; main() reports and exits on it."""


class ParseError(Error):
  """Raised when a query cannot be parsed.

  For example, _ParsePairs raises it for a malformed name=value pair.
  """


class NotSupportedError(Error):
  """Raised for query features that dbtextdb does not implement.

  NOTE(review): description inferred from the name; confirm at raise sites.
  """


class ExecuteError(Error):
  """Raised when a parsed query cannot be carried out against a table.

  Covers failures such as an unreadable table file, malformed table data,
  or an unknown column or invalid value in the query.
  """
01170 
01171 
def main(argv):
  """Command-line entry point: run one query against DBTEXT_PATH.

  Args:
    argv: argument vector; argv[1:] are joined with spaces into the query.

  Prints result rows for SELECT, or a rows-affected line per returned count
  for other commands. Exits with status 1 on a usage error, an unset or
  empty DBTEXT_PATH, or any Error raised while executing the query.
  """
  if len(argv) < 2:
    print('Usage %s query' % argv[0])
    sys.exit(1)

  location = os.environ.get('DBTEXT_PATH')
  if not location:
    print('DBTEXT_PATH must be set')
    sys.exit(1)

  try:
    conn = DBText(location)
    dataset = conn.Execute(' '.join(argv[1:]))
    for row in dataset or []:
      if conn.command == 'SELECT':
        print(row)
      else:
        print('Updated %s, rows affected: %d' % (conn.table, row))
  except Error as e:
    print(e)
    sys.exit(1)


if __name__ == '__main__':
  main(sys.argv)