Designing Machine Learning Systems - Chip Huyen

  • ML systems are both complex and unique. They are complex because they consist of many different components (ML algorithms, data, business logic, evaluation metrics, underlying infrastructure, etc.) and involve many different stakeholders (data scientists, ML engineers, business leaders, users, even society at large). ML systems are unique because they are data dependent, and data varies wildly from one use case to the next.
  • Ops in MLOps comes from DevOps, short for Development and Operations. To operationalize something means to bring it into production, which includes deploying, monitoring, and maintaining it. MLOps is a set of tools and best practices for bringing ML into production.
  • Fraud detection is among the oldest applications of ML in the enterprise world. If your product or service involves transactions of any value, it’ll be susceptible to fraud.
  • Price optimization is the process of estimating a price at a certain time period to maximize a defined objective function, such as the company’s margin, revenue, or growth rate.
  • Reducing customer acquisition costs by a small amount can result in a large increase in profit. This can be done through better identifying potential customers, showing better-targeted ads, giving discounts at the right time, etc.—all of which are suitable tasks for ML.
  • Acquiring a new user is estimated to cost 5 to 25 times more than retaining an existing one.12 Churn prediction is predicting when a specific customer is about to stop using your products or services so that you can take appropriate actions to win them back.
  • People involved in a research and leaderboard project often align on one single objective. The most common objective is model performance—develop a model that achieves the state-of-the-art results on benchmark datasets. To edge out a small improvement in performance, researchers often resort to techniques that make models too complex to be useful. There are many stakeholders involved in bringing an ML system into production. Each stakeholder has their own requirements. Having different, often conflicting, requirements can make it difficult to design, develop, and select an ML model that satisfies all the requirements.
  • For example, ensembling is a technique popular among the winners of many ML competitions, including the famed $1 million Netflix Prize, and yet it’s not widely used in production. Ensembling combines “multiple learning algorithms to obtain better predictive performance than could be obtained from any of the constituent learning algorithms alone.”15 While it can give your ML system a small performance improvement, ensembling tends to make a system too complex to be useful in production, e.g., slower to make predictions or harder to interpret the results.
  • One corollary of this is that research prioritizes high throughput whereas production prioritizes low latency. In case you need a refresher, latency refers to the time it takes from receiving a query to returning the result. Throughput refers to how many queries are processed within a specific period of time.
  • In SWE, there’s an underlying assumption that code and data are separated. In fact, in SWE, we want to keep things as modular and separate as possible (see the Wikipedia page on separation of concerns). In contrast, ML systems are part code, part data, and part artifacts created from the two. The trend in the last decade shows that applications developed with the most/best data win. Instead of focusing on improving ML algorithms, most companies will focus on improving their data. Because data can change quickly, ML applications need to adapt to those changes.
  • Before we develop an ML system, we must understand why this system is needed. If this system is built for a business, it must be driven by business objectives, which will need to be translated into ML objectives to guide the development of ML models.
  • One of the reasons why predicting ad click-through rates and fraud detection are among the most popular use cases for ML today is that it’s easy to map ML models’ performance to business metrics: every increase in click-through rate results in actual ad revenue, and every fraudulent transaction stopped results in actual money saved. Many companies create their own metrics to map business metrics to ML metrics. For example, Netflix measures the performance of their recommender system using take-rate: the number of quality plays divided by the number of recommendations a user sees.4 The higher the take-rate, the better the recommender system. Netflix also put a recommender system’s take-rate in the context of their other business metrics like total streaming hours and subscription cancellation rate. They found that a higher take-rate also results in higher total streaming hours and lower subscription cancellation rates.5
  • The specified requirements for an ML system vary from use case to use case. However, most systems should have these four characteristics: reliability, scalability, maintainability, and adaptability.
  • Reliability: The system should continue to perform the correct function at the desired level of performance even in the face of adversity (hardware or software faults, and even human error).
  • It’s important to structure your workloads and set up your infrastructure in such a way that different contributors can work using tools that they are comfortable with, instead of one group of contributors forcing their tools onto other groups. Code should be documented. Code, data, and artifacts should be versioned. Models should be sufficiently reproducible so that even when the original authors are not around, other contributors have sufficient context to build on their work. When a problem occurs, different contributors should be able to work together to identify the problem and implement a solution without finger-pointing.
  • Adaptability: To adapt to shifting data distributions and business requirements, the system should have some capacity for both discovering aspects for performance improvement and allowing updates without service interruption. Because ML systems are part code, part data, and data can change quickly, ML systems need to be able to evolve quickly.
  • Choosing an objective function is usually straightforward, though not because objective functions are easy. Coming up with meaningful objective functions requires algebra knowledge, so most ML engineers just use common loss functions like RMSE or MAE (mean absolute error) for regression, logistic loss (also log loss) for binary classification, and cross entropy for multiclass classification.
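    The common loss functions named above can be written out in a few lines. This is a minimal pure-Python sketch for intuition, not a replacement for a library implementation; the function names and the `eps` clipping constant are illustrative assumptions.

```python
import math

def rmse(y_true, y_pred):
    """Root mean squared error, for regression."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))

def mae(y_true, y_pred):
    """Mean absolute error, for regression."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def log_loss(y_true, p_pred, eps=1e-12):
    """Binary log loss. y_true holds 0/1 labels, p_pred holds P(y = 1)."""
    total = 0.0
    for t, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)

def cross_entropy(y_true, probs, eps=1e-12):
    """Multiclass cross entropy. y_true holds class indices, probs holds per-class probabilities."""
    return -sum(math.log(max(p[t], eps)) for t, p in zip(y_true, probs)) / len(y_true)
```

    With two classes, cross entropy reduces to log loss, which is why the two are often mentioned together.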
  • More data beats algorithms
    Collecting more data is one of the regularization methods
  • Progress in the last decade shows that the success of an ML system depends largely on the data it was trained on. Instead of focusing on improving ML algorithms, most companies focus on managing and improving their data.16 Despite the success of models using massive amounts of data, many are skeptical of the emphasis on data as the way forward. In the last five years, at every academic conference I attended, there were always some public debates on the power of mind versus data. Mind might be disguised as inductive biases or intelligent architectural designs. Data might be grouped together with computation since more data tends to require more computation. In theory, you can both pursue architectural designs and leverage large data and computation, but spending time on one often takes time away from another.17 In the mind-over-data camp, there’s Dr. Judea Pearl, a Turing Award winner best known for his work on causal inference and Bayesian networks. The introduction to his book The Book of Why is entitled “Mind over Data,” in which he emphasizes: “Data is profoundly dumb.”
  • How do I store multimodal data, e.g., a sample that might contain both images and texts? Where do I store my data so that it’s cheap and still fast to access? How do I store complex models so that they can be loaded and run correctly on different hardware?
  • The two formats that are common and represent two distinct paradigms are CSV and Parquet. CSV (comma-separated values) is row-major, which means consecutive elements in a row are stored next to each other in memory. Parquet is column-major, which means consecutive elements in a column are stored next to each other. Because modern computers process sequential data more efficiently than nonsequential data, accessing a row-major table by rows is expected to be faster than accessing it by columns, and the reverse holds for column-major tables.
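    The difference between the two layouts can be seen by flattening a small table both ways. This is a toy in-memory sketch with made-up ride data; real CSV and Parquet files carry much more structure (delimiters, metadata, compression) than a flat list.

```python
# A hypothetical 3-row, 3-column table: (ride_id, price, city).
rows = [
    ("ride_1", 12.5, "SF"),
    ("ride_2", 30.0, "NYC"),
    ("ride_3", 8.0, "SF"),
]
n_rows, n_cols = len(rows), len(rows[0])

# Row-major: all fields of one row sit next to each other.
row_major = [value for row in rows for value in row]

# Column-major: all values of one column sit next to each other.
column_major = [value for column in zip(*rows) for value in column]

# Reading the price column (index 1):
# in column-major it is one contiguous slice,
prices_from_columns = column_major[1 * n_rows : 2 * n_rows]
# in row-major it needs a strided walk that skips over the other fields.
prices_from_rows = row_major[1::n_cols]

# Reading the second row is the mirror image: contiguous in row-major.
second_row = row_major[1 * n_cols : 2 * n_cols]
```

    Both reads return the same prices; what differs is whether the values are adjacent in storage, which is what makes column-major formats attractive for column aggregations.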
  • With certain added features, SQL can be Turing-complete, which means that, in theory, SQL can be used to solve any computation problem (without making any guarantee about the time or memory required).
  • NoSQL The relational data model has been able to generalize to a lot of use cases, from ecommerce to finance to social networks. However, for certain use cases, this model can be restrictive. For example, it demands that your data follows a strict schema, and schema management is painful. In a survey by Couchbase in 2014, frustration with schema management was the #1 reason for the adoption of their nonrelational database.16 It can also be difficult to write and execute SQL queries for specialized applications. The latest movement against the relational data model is NoSQL. Originally started as a hashtag for a meetup to discuss nonrelational databases, NoSQL has been retroactively reinterpreted as Not Only SQL,17 as many NoSQL data systems also support relational models. Two major types of nonrelational models are the document model and the graph model. The document model targets use cases where data comes in self-contained documents and relationships between one document and another are rare. The graph model goes in the opposite direction, targeting use cases where relationships between data items are common and important. We’ll examine each of these two models, starting with the document model.
  • A repository for storing structured data is called a data warehouse. A repository for storing unstructured data is called a data lake. Data lakes are usually used to store raw data before processing. Data warehouses are used to store data that has been processed into formats ready to be used. Table 3-5 shows a summary of the key differences between structured and unstructured data.
  • Transactional and Analytical Processing Traditionally, a transaction refers to the action of buying or selling something. In the digital world, a transaction refers to any kind of action: tweeting, ordering a ride through a ride-sharing service, uploading a new model, watching a YouTube video, and so on. Even though these different transactions involve different types of data, the way they’re processed is similar across applications. The transactions are inserted as they are generated, and occasionally updated when something changes, or deleted when they are no longer needed.19 This type of processing is known as online transaction processing (OLTP). Because these transactions often involve users, they need to be processed fast (low latency) so that they don’t keep users waiting. The processing method needs to have high availability—that is, the processing system needs to be available any time a user wants to make a transaction. If your system can’t process a transaction, that transaction won’t go through. Transactional databases are designed to process online transactions and satisfy the low latency, high availability requirements. When people hear transactional databases, they usually think of ACID (atomicity, consistency, isolation, durability). Here are their definitions for those needing a quick reminder:
  • Because each transaction is often processed as a unit separately from other transactions, transactional databases are often row-major. This also means that transactional databases might not be efficient for questions such as “What’s the average price for all the rides in September in San Francisco?” This kind of analytical question requires aggregating data in columns across multiple rows of data. Analytical databases are designed for this purpose. They are efficient with queries that allow you to look at data from different viewpoints. We call this type of processing online analytical processing (OLAP).
  • An interesting paradigm in the last decade has been to decouple storage from processing (also known as compute), as adopted by many data vendors including Google’s BigQuery, Snowflake, IBM, and Teradata.21 In this paradigm, the data can be stored in the same place, with a processing layer on top that can be optimized for different types of queries.
  • Most of the time, in production, you don’t have a single process but multiple. A question arises: how do we pass data between different processes that don’t share memory? When data is passed from one process to another, we say that the data flows from one process to another, which gives us a dataflow. There are three main modes of dataflow: (1) data passing through databases; (2) data passing through services using requests such as the requests provided by REST and RPC APIs (e.g., POST/GET requests); and (3) data passing through a real-time transport like Apache Kafka and Amazon Kinesis. We’ll go over each of them in this section.
  • To put the microservice architecture in the context of ML systems, imagine you’re an ML engineer working on the price optimization problem for a company that owns a ride-sharing application like Lyft. In reality, Lyft has hundreds of services in its microservice architecture.
  • Microservice architecture example
  • The most popular styles of requests used for passing data through networks are REST (representational state transfer) and RPC (remote procedure call). Their detailed analysis is beyond the scope of this book, but one major difference is that REST was designed for requests over networks, whereas RPC “tries to make a request to a remote network service look the same as calling a function or method in your programming language.” Because of this, “REST seems to be the predominant style for public APIs. The main focus of RPC frameworks is on requests between services owned by the same organization, typically within the same data center.”25 Implementations of a REST architecture are said to be RESTful. Even though many people think of REST as HTTP, REST doesn’t exactly mean HTTP because HTTP is just an implementation of REST.26
  • Request-driven architecture works well for systems that rely more on logic than on data. Event-driven architecture works better for systems that are data-heavy.
  • Because batch processing happens much less frequently than stream processing, in ML, batch processing is usually used to compute features that change less often, such as drivers’ ratings (if a driver has had hundreds of rides, their rating is less likely to change significantly from one day to the next). Batch features—features extracted through batch processing—are also known as static features. Stream processing is used to compute features that change quickly, such as how many drivers are available right now, how many rides have been requested in the last minute, how many rides will be finished in the next two minutes, the median price of the last 10 rides in this area, etc. Features about the current state of the system like these are important to make the optimal price predictions. Streaming features—features extracted through stream processing—are also known as dynamic features.
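    A streaming feature such as "how many rides have been requested in the last minute" can be sketched as a sliding-window counter. The class name and its API are hypothetical, and the sketch assumes events arrive in timestamp order (in seconds); a production system would typically compute this inside a stream processing engine.

```python
from collections import deque

class SlidingWindowCounter:
    """Counts events that fall inside the most recent time window."""

    def __init__(self, window_seconds=60):
        self.window_seconds = window_seconds
        self.timestamps = deque()

    def add(self, timestamp):
        # Assumes timestamps are added in non-decreasing order.
        self.timestamps.append(timestamp)

    def count(self, now):
        # Evict events that are at least one full window old.
        while self.timestamps and self.timestamps[0] <= now - self.window_seconds:
            self.timestamps.popleft()
        return len(self.timestamps)

ride_requests = SlidingWindowCounter(window_seconds=60)
for t in (0, 30, 59):
    ride_requests.add(t)
```

    The feature value changes every time it is read, which is what distinguishes it from a batch feature recomputed once a day.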
  • We use the term “training data” instead of “training dataset” because “dataset” denotes a set that is finite and stationary. Data in production is neither finite nor stationary, a phenomenon that we will cover in the section “Data Distribution Shifts”. Like other steps in building ML systems, creating training data is an iterative process. As your model evolves through a project lifecycle, your training data will likely also evolve.
  • In this section, we will discuss the challenge of obtaining labels for your data. We’ll first discuss the labeling method that usually comes to data scientists’ minds first when talking about labeling: hand-labeling. We will then discuss tasks with natural labels, which are tasks where labels can be inferred from the system without requiring human annotations, followed by what to do when natural and hand labels are lacking.
  • It’s good practice to keep track of the origin of each of your data samples as well as its labels, a technique known as data lineage. Data lineage helps you both flag potential biases in your data and debug your models. For example, if your model fails mostly on the recently acquired data samples, you might want to look into how the new data was acquired. On more than one occasion, we’ve discovered that the problem wasn’t with our model, but with the unusually high number of wrong labels in the data that we’d acquired recently.
  • The canonical example of tasks with natural labels is recommender systems. The goal of a recommender system is to recommend to users items relevant to them. Whether a user clicks on the recommended item or not can be seen as the feedback for that recommendation. A recommendation that gets clicked on can be presumed to be good (i.e., the label is POSITIVE) and a recommendation that doesn’t get clicked on after a period of time, say 10 minutes, can be presumed to be bad (i.e., the label is NEGATIVE). Many tasks can be framed as recommendation tasks. For example, you can frame the task of predicting ads’ click-through rates as recommending the most relevant ads to users based on their activity histories and profiles. Natural labels that are inferred from user behaviors like clicks and ratings are also known as behavioral labels.
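    The click-based labeling rule can be sketched as a small function. The 10-minute window comes from the example above; the function name, the use of seconds, and the choice to treat a click that arrives after the window as NEGATIVE are assumptions made for illustration.

```python
CLICK_WINDOW_SECONDS = 10 * 60  # the "say 10 minutes" feedback window

def infer_label(shown_at, clicked_at, now):
    """Infer a behavioral label for one recommendation.

    shown_at and now are timestamps in seconds; clicked_at is a timestamp or None.
    Returns "POSITIVE", "NEGATIVE", or None while the window is still open.
    """
    if clicked_at is not None and clicked_at - shown_at <= CLICK_WINDOW_SECONDS:
        return "POSITIVE"
    if now - shown_at > CLICK_WINDOW_SECONDS:
        return "NEGATIVE"  # no click within the window (assumption: late clicks count as negative)
    return None  # too early to tell
```

    The window length is a trade-off: a shorter window yields labels faster but mislabels slow clickers as negatives.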
  • Weak supervision: If hand labeling is so problematic, what if we don’t use hand labels at all? One approach that has gained popularity is weak supervision. One of the most popular open source tools for weak supervision is Snorkel, developed at the Stanford AI Lab.11 The insight behind weak supervision is that people rely on heuristics, which can be developed with subject matter expertise, to label data. For example, a doctor might use the following heuristics to decide whether a patient’s case should be prioritized as emergent: If the nurse’s note mentions a serious condition like pneumonia, the patient’s case should be given priority consideration.
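    The doctor's heuristic can be encoded as a labeling function. This is a plain-Python sketch of the idea and deliberately does not use Snorkel's API; the label constants, the keyword tuple, and the function name are hypothetical.

```python
EMERGENT = 1
ABSTAIN = -1  # the heuristic has no opinion on this sample

# Hypothetical keyword list; a real one would come from subject matter experts.
SERIOUS_CONDITIONS = ("pneumonia",)

def lf_serious_condition(nurse_note):
    """Label a case EMERGENT if the nurse's note mentions a serious condition."""
    note = nurse_note.lower()
    if any(condition in note for condition in SERIOUS_CONDITIONS):
        return EMERGENT
    return ABSTAIN
```

    A single heuristic is noisy and covers few samples, so in practice several labeling functions are combined and their votes reconciled.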
  • A classic semi-supervision method is self-training. You start by training a model on your existing set of labeled data and use this model to make predictions for unlabeled samples. Assuming that predictions with high raw probability scores are correct, you add the labels predicted with high probability to your training set and train a new model on this expanded training set. This goes on until you’re happy with your model performance.
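    The self-training loop can be sketched with a toy one-dimensional model. Everything here is an illustrative assumption: the nearest-centroid "model", the pseudo-probability formula, the 0.9 confidence threshold, and the stopping rule (stop when no prediction is confident or after `max_rounds`), which stands in for "until you're happy with your model performance".

```python
import math

def fit_centroids(xs, ys):
    """Toy 1-D model: the mean of each class (both classes must be present)."""
    mean = lambda values: sum(values) / len(values)
    return (mean([x for x, y in zip(xs, ys) if y == 0]),
            mean([x for x, y in zip(xs, ys) if y == 1]))

def predict_proba(model, x):
    """Pseudo-probability of class 1, based on relative distance to the centroids."""
    c0, c1 = model
    return 1 / (1 + math.exp(abs(x - c1) - abs(x - c0)))

def self_train(xs, ys, unlabeled, threshold=0.9, max_rounds=10):
    xs, ys, pool = list(xs), list(ys), list(unlabeled)
    model = fit_centroids(xs, ys)
    for _ in range(max_rounds):
        confident, remaining = [], []
        for x in pool:
            p = predict_proba(model, x)
            if p >= threshold:
                confident.append((x, 1))
            elif p <= 1 - threshold:
                confident.append((x, 0))
            else:
                remaining.append(x)
        if not confident:
            break  # nothing new to learn from
        for x, y in confident:
            xs.append(x)
            ys.append(y)
        pool = remaining
        model = fit_centroids(xs, ys)  # retrain on the expanded training set
    return model, xs, ys, pool

model, xs, ys, pool = self_train(
    xs=[0.0, 1.0, 9.0, 10.0], ys=[0, 0, 1, 1], unlabeled=[-3.0, 13.0, 5.0]
)
```

    In this run the two far-out points are pseudo-labeled and added, while the ambiguous point at 5.0 stays unlabeled. The method's weakness is visible in the code: a confidently wrong prediction is added as if it were ground truth.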
  • Class imbalance can also happen with regression tasks where the labels are continuous. Consider the task of estimating health-care bills.25 Health-care bills are highly skewed—the median bill is low, but the 95th percentile bill is astronomical.
  • For example, misclassification on an X-ray with cancerous cells is much more dangerous than misclassification on an X-ray of a normal lung. If your loss function isn’t configured to address this asymmetry, your model will treat all samples the same way.
  • A popular method of undersampling low-dimensional data that was developed back in 1976 is Tomek links.38 With this technique, you find pairs of samples from opposite classes that are close in proximity and remove the sample of the majority class in each pair.
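    A brute-force sketch of Tomek-link removal follows. It takes "close in proximity" to mean that the two samples are each other's nearest neighbor (the usual definition of a Tomek link), assumes a binary problem with Euclidean distance, and uses hypothetical function names.

```python
import math

def nearest_neighbor(i, points):
    """Index of the point closest to points[i], excluding itself."""
    return min((j for j in range(len(points)) if j != i),
               key=lambda j: math.dist(points[i], points[j]))

def remove_tomek_links(points, labels, majority_label):
    """Drop the majority-class sample from every Tomek link (binary labels assumed)."""
    nn = [nearest_neighbor(i, points) for i in range(len(points))]
    to_drop = set()
    for i, j in enumerate(nn):
        # A Tomek link: mutual nearest neighbors that belong to opposite classes.
        if nn[j] == i and labels[i] != labels[j]:
            to_drop.add(i if labels[i] == majority_label else j)
    keep = [k for k in range(len(points)) if k not in to_drop]
    return [points[k] for k in keep], [labels[k] for k in keep]

points = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.1, 5.0), (5.2, 5.1)]
labels = [0, 1, 0, 0, 0]
kept_points, kept_labels = remove_tomek_links(points, labels, majority_label=0)
```

    Only the majority sample at (0.0, 0.0), which sits right next to the lone minority sample, is removed. The pairwise search is quadratic in the number of samples, which is acceptable for a sketch but not for large datasets.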
  • A popular method of oversampling low-dimensional data is SMOTE (synthetic minority oversampling technique).39 It synthesizes novel samples of the minority class through sampling convex combinations of existing data points within the minority class.40 Both SMOTE and Tomek links have only been proven effective in low-dimensional data.
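    The "convex combinations" idea can be sketched as interpolating between a minority sample and one of its nearest minority neighbors. This is a simplified illustration rather than a faithful SMOTE implementation; the function name, the default `k`, and the fixed seed are assumptions.

```python
import math
import random

def smote_sketch(minority, n_new, k=2, seed=0):
    """Create n_new synthetic points between minority samples and their neighbors."""
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        base = rng.choice(minority)
        # The k nearest other minority samples to the chosen base point.
        neighbors = sorted((p for p in minority if p is not base),
                           key=lambda p: math.dist(base, p))[:k]
        neighbor = rng.choice(neighbors)
        lam = rng.random()  # interpolation weight in [0, 1)
        # Convex combination: a point on the segment from base to neighbor.
        synthetic.append(tuple(b + lam * (n - b) for b, n in zip(base, neighbor)))
    return synthetic

minority = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
new_points = smote_sketch(minority, n_new=5)
```

    Because every synthetic point lies on a segment between two real minority samples, the new points never leave the region the minority class already occupies.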
  • One such technique is two-phase learning.42 You first train your model on the resampled data. This resampled data can be achieved by randomly undersampling large classes until each class has only N instances. You then fine-tune your model on the original data. Another technique is dynamic sampling: oversample the low-performing classes and undersample the high-performing classes during the training process. Introduced by Pouyanfar et al.,43 the method aims to show the model less of what it has already learned and more of what it has not.
  • Cost-sensitive learning: Back in 2001, based on the insight that misclassification of different classes incurs different costs, Elkan proposed cost-sensitive learning in which the individual loss function is modified to take into account this varying cost.
  • Class-balanced loss: What might happen with a model trained on an imbalanced dataset is that it’ll be biased toward majority classes and make wrong predictions on minority classes. What if we punish the model for making wrong predictions on minority classes to correct this bias?
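    One simple way to "punish" mistakes on minority classes is to weight each sample's loss by the inverse frequency of its class. The weighting scheme (total samples divided by class count) and the function names are assumptions chosen for illustration; other weighting schemes exist.

```python
import math
from collections import Counter

def class_weights(labels):
    """Inverse-frequency weights: rarer classes get larger weights."""
    counts = Counter(labels)
    n = len(labels)
    return {cls: n / count for cls, count in counts.items()}

def weighted_log_loss(y_true, p_pred, weights, eps=1e-12):
    """Binary log loss where each sample is scaled by its class weight."""
    total = 0.0
    for t, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        p_correct = p if t == 1 else 1 - p
        total += -weights[t] * math.log(p_correct)
    return total / len(y_true)

labels = [0] * 9 + [1]           # 90% majority, 10% minority
weights = class_weights(labels)  # class 1 weighs 9x as much as class 0
```

    With these weights, an equally wrong prediction costs nine times more on the minority class than on the majority class.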
  • In this section, we will cover three main types of data augmentation: simple label-preserving transformations; perturbation, which is a term for “adding noises”; and data synthesis. In each type, we’ll go over examples for both computer vision and NLP.
  • Before deep learning, when given a piece of text, you would have to manually apply classical text processing techniques such as lemmatization, expanding contractions, removing punctuation, and lowercasing everything. After that, you might want to split your text into n-grams with n values of your choice.
  • Several of the most important operations that you might want to consider while engineering features from your data include handling missing values, scaling, discretization, encoding categorical features, and generating the old-school but still very effective cross features as well as the newer and exciting positional features.
  • One way to delete is column deletion: if a variable has too many missing values, just remove that variable. For example, in the example above, over 50% of the values for the variable “Marital status” are missing, so you might be tempted to remove this variable from your model.
  • Another way to delete is row deletion: if a sample has missing value(s), just remove that sample. This method can work when the missing values are completely at random (MCAR) and the number of examples with missing values is small, such as less than 0.1%. You don’t want to do row deletion if that means 10% of your data samples are removed.
  • Before inputting features into models, it’s important to scale them to similar ranges. This process is called feature scaling.
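    Two common ways to scale a feature are min-max scaling and standardization. In this minimal sketch the statistics are computed on the train split only and then reused for other splits; the function names are hypothetical, and the code assumes the train values are not all identical (otherwise the divisions would fail).

```python
def fit_min_max(train_values):
    """Learn the scaling range from the train split only."""
    return min(train_values), max(train_values)

def min_max_scale(values, lo, hi, a=0.0, b=1.0):
    """Map [lo, hi] onto [a, b]; values outside the train range fall outside [a, b]."""
    return [a + (v - lo) * (b - a) / (hi - lo) for v in values]

def fit_standardizer(train_values):
    """Learn the mean and standard deviation from the train split only."""
    mean = sum(train_values) / len(train_values)
    variance = sum((v - mean) ** 2 for v in train_values) / len(train_values)
    return mean, variance ** 0.5

def standardize(values, mean, std):
    return [(v - mean) / std for v in values]

train, test = [10.0, 20.0, 30.0], [40.0]
lo, hi = fit_min_max(train)
scaled_train = min_max_scale(train, lo, hi)
scaled_test = min_max_scale(test, lo, hi)  # reuses train statistics
```

    The test value scales to 1.5 rather than 1.0 because it lies outside the range seen during training, which is expected when statistics are not leaked from the test split.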
  • In practice, ML models tend to struggle with features that follow a skewed distribution. To help mitigate the skewness, a technique commonly used is log transformation: apply the log function to your feature. An example of how the log transformation can make your data less skewed is shown in Figure 5-3. While this technique can yield performance gain in many cases, it doesn’t work for all cases, and you should be wary of the analysis performed on log-transformed data instead of the original data.5
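    A small numeric illustration of how a log transformation compresses a long tail. The bill amounts are made up, and `log1p` (log of 1 + x) is used instead of a plain log as an assumption so that zero values would not break the transform.

```python
import math
import statistics

# Hypothetical, heavily right-skewed feature values.
bills = [120, 150, 200, 300, 250_000]
log_bills = [math.log1p(b) for b in bills]

def max_to_median(values):
    """A crude skew indicator: how far the largest value sits from the median."""
    return max(values) / statistics.median(values)
```

    On the raw values the largest bill is over a thousand times the median; after the transform it is only a few times the median.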
  • However, in production, categories change. Imagine you’re building a recommender system to predict what products users might want to buy from Amazon. One of the features you want to use is the product brand. When looking at Amazon’s historical data, you realize that there are a lot of brands. Even back in 2019, there were already over two million brands on Amazon!6
  • If you want to predict whether a comment is spam, you might want to use the account that posted this comment as a feature, and new accounts are being created all the time. The same goes for new product types, new website domains, new restaurants, new companies, new IP addresses, and so on. If you work with any of them, you’ll have to deal with this problem.
  • One solution to this problem is the hashing trick, popularized by the package Vowpal Wabbit developed at Microsoft.7 The gist of this trick is that you use a hash function to generate a hashed value of each category. The hashed value will become the index of that category. Because you can specify the hash space, you can fix the number of encoded values for a feature in advance, without having to know how many categories there will be.
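    The hashing trick can be sketched with a stable hash from the standard library. The choice of MD5, the bucket count, and the function name are illustrative assumptions; Python's built-in `hash()` is avoided because it is randomized per process for strings.

```python
import hashlib

def hash_bucket(category, n_buckets=2**18):
    """Map any category string, seen or unseen, to an index in [0, n_buckets)."""
    digest = hashlib.md5(category.encode("utf-8")).hexdigest()
    return int(digest, 16) % n_buckets

known_brand = hash_bucket("Acme", n_buckets=1000)
unseen_brand = hash_bucket("A brand created yesterday", n_buckets=1000)
```

    The price of a fixed hash space is collisions: two different categories can land in the same bucket, and a larger `n_buckets` makes this less likely.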
  • Because feature crossing helps model nonlinear relationships between variables, it’s essential for models that can’t learn or are bad at learning nonlinear relationships, such as linear regression, logistic regression, and tree-based models. It’s less important in neural networks, but it can still be useful because explicit feature crossing occasionally helps neural networks learn nonlinear relationships faster.
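    Crossing two categorical features amounts to creating a new feature whose value is the combination of the originals. The feature names, sample rows, and separator below are hypothetical.

```python
def cross(*values, sep="_x_"):
    """Combine several feature values into one crossed category."""
    return sep.join(str(v) for v in values)

rows = [
    {"marital_status": "single", "children": 0},
    {"marital_status": "married", "children": 2},
    {"marital_status": "single", "children": 2},
]
for row in rows:
    row["marital_x_children"] = cross(row["marital_status"], row["children"])
```

    The crossed feature lets a linear model assign a separate weight to each combination. Its number of possible values is the product of the originals' cardinalities, so it can grow quickly.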
  • Discrete and Continuous Positional Embeddings First introduced to the deep learning community in the paper “Attention Is All You Need” (Vaswani et al. 2017), positional embedding has become a standard data engineering technique for many applications in both computer vision and NLP. We’ll walk through an example to show why positional embedding is necessary and how to do it.
  • If we use a recurrent neural network, it will process words in sequential order, which means the order of words is implicitly inputted. However, if we use a model like a transformer, words are processed in parallel, so words’ positions need to be explicitly inputted so that our model knows the order of these words (“a dog bites a child” is very different from “a child bites a dog”). We don’t want to input the absolute positions, 0, 1, 2, …, 7, into our model because empirically, neural networks don’t work well with inputs that aren’t unit-variance (that’s why we scale our features, as discussed previously in the section “Scaling”). If we rescale the positions to between 0 and 1, so 0, 1, 2, …, 7 become 0, 0.143, 0.286, …, 1, the differences between the two positions will be too small for neural networks to learn to differentiate. A way to handle position embeddings is to treat them the way we’d treat word …
  • Data Leakage In July 2021, MIT Technology Review ran a provocative article titled “Hundreds of AI Tools Have Been Built to Catch Covid. None of Them Helped.” These models were trained to predict COVID-19 risks from medical scans. The article listed multiple examples where ML models that performed well during evaluation failed to be usable in actual production settings. In one example, researchers trained their model on a mix of scans taken when patients were lying down and standing up. “Because patients scanned while lying down were more likely to be seriously ill, the model learned to predict serious covid risk from a person’s position.” In some other cases, models were “found to be picking up on the text font that certain hospitals used to label the scans. As a result, fonts from hospitals with more serious caseloads became predictors of covid risk.”12
  • Common Causes for Data Leakage: In this section, we’ll go over some common causes for data leakage and how to avoid them. The first is splitting time-correlated data randomly instead of by time.
  • To oversimplify it, the prices of similar stocks tend to move together. If 90% of the tech stocks go down today, it’s very likely the other 10% of the tech stocks go down too. When building models to predict the future stock prices, you want to split your training data by time, such as training your model on data from the first six days and evaluating it on data from the seventh day. If you randomly split your data, prices from the seventh day will be included in your train split and leak into your model the condition of the market on that day.
  • Consider predicting whether someone will click on a song recommendation. Whether someone will listen to a song depends not only on their music taste but also on the general music trend that day. If an artist passes away one day, people will be much more likely to listen to that artist. By including samples from a certain day in the train split, information about the music trend that day will be passed into your model, making it easier for it to make predictions on other samples on that same day.
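    Splitting by time instead of randomly can be sketched as a cutoff on a timestamp field. The `day` field, the sample structure, and the function name are hypothetical.

```python
def split_by_time(samples, cutoff, key="day"):
    """Everything up to and including the cutoff trains; everything after evaluates."""
    train = [s for s in samples if s[key] <= cutoff]
    test = [s for s in samples if s[key] > cutoff]
    return train, test

# Hypothetical week of data: train on the first six days, evaluate on the seventh.
samples = [{"day": day, "price": 100 + day} for day in range(1, 8)]
train, test = split_by_time(samples, cutoff=6)
```

    Because no sample from the evaluation day appears in the train split, the model cannot pick up that day's market condition or music trend.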
  • Scaling before splitting
  • Poor handling of data duplication before splitting
  • Detecting Data Leakage
  • Measure the predictive power of each feature or a set of features with respect to the target variable (label). If a feature has unusually high correlation, investigate how this feature is generated and whether the correlation makes sense.
  • Be very careful every time you look at the test split. If you use the test split in any way other than to report a model’s final performance, whether to come up with ideas for new features or to tune hyperparameters, you risk leaking information from the future into your training process.
  • In theory, if a feature doesn’t help a model make good predictions, regularization techniques like L1 regularization should reduce that feature’s weight to 0. However, in practice, it might help models learn faster if the features that are no longer useful (and even possibly harmful) are removed, prioritizing good features.
  • Feature Importance: There are many different methods for measuring a feature’s importance. If you use a classical ML algorithm like gradient-boosted trees, the easiest way to measure the importance of your features is to use built-in feature importance functions implemented by XGBoost.17 For more model-agnostic methods, you might want to look into SHAP (SHapley Additive exPlanations).18 InterpretML is a great open source package that leverages feature importance to help you understand how your model makes predictions. The exact algorithm for feature importance measurement is complex, but intuitively, a feature’s importance to a model is measured by how much that model’s performance deteriorates if that feature or a set of features containing that feature is removed from the model. SHAP is great because it not only measures a feature’s importance to an entire model, it also measures each feature’s contribution to a model’s specific prediction.
  • There are two aspects you might want to consider with regard to generalization: feature coverage and distribution of feature values. Coverage is the percentage of the samples that have values for this feature in the data, so the fewer values that are missing, the higher the coverage. A rough rule of thumb is that if this feature appears in a very small percentage of your data, it’s not going to be very generalizable.
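    Feature coverage is straightforward to compute. This sketch assumes missing values are represented as `None`; the function name and sample values are hypothetical.

```python
def coverage(values):
    """Percentage of samples that have a (non-missing) value for the feature."""
    present = sum(value is not None for value in values)
    return 100 * present / len(values)

marital_status = ["single", None, "married", None]
```

    Comparing a feature's coverage across the train and test splits is a cheap additional check: a large gap suggests the two splits do not come from the same distribution.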
  • Here is a summary of best practices for feature engineering:
    • Split data by time into train/valid/test splits instead of doing it randomly.
    • If you oversample your data, do it after splitting.
    • Scale and normalize your data after splitting to avoid data leakage.
    • Use statistics from only the train split, instead of the entire data, to scale your features and handle missing values.
    • Understand how your data is generated, collected, and processed. Involve domain experts if possible.
    • Keep track of your data’s lineage.
    • Understand feature importance to your model.
    • Use features that generalize well.
    • Remove no longer useful features from your models.
  • However, even though deep learning is finding more use cases in production, classical ML algorithms are not going away. Many recommender systems still rely on collaborative filtering and matrix factorization. Tree-based algorithms, including gradient-boosted trees, still power many classification tasks with strict latency requirements.
  • if your boss tells you to build a system to detect toxic tweets, you know that this is a text classification problem—given a piece of text, classify whether it’s toxic or not—and common models for text classification include naive Bayes, logistic regression, recurrent neural networks, and transformer-based models such as BERT, GPT, and their variants. If your client wants you to build a system to detect fraudulent transactions, you know that this is the classic abnormality detection problem—fraudulent transactions are abnormalities that you want to detect—and common algorithms for this problem are many, including k-nearest neighbors, isolation forest, clustering, and neural networks.
  • When considering what model to use, it’s important to consider not only the model’s performance, measured by metrics such as accuracy, F1 score, and log loss, but also its other properties, such as how much data, compute, and time it needs to train, its inference latency, and its interpretability. For example, a simple logistic regression model might have lower accuracy than a complex neural network, but it requires less labeled data to start, it’s much faster to train, it’s much easier to deploy, and it’s also much easier to explain why it’s making certain predictions.
  • To keep up to date with so many new ML techniques and models, I find it helpful to monitor trends at major ML conferences such as NeurIPS, ICLR, and ICML, as well as following researchers whose work has a high signal-to-noise ratio on Twitter.
  • Simplicity serves three purposes. First, simpler models are easier to deploy, and deploying your model early allows you to validate that your prediction pipeline is consistent with your training pipeline. Second, starting with something simple and adding more complex components step-by-step makes it easier to understand your model and debug it. Third, the simplest model serves as a baseline to which you can compare your more complex models.
  • Because the performance of a model architecture depends heavily on the context it’s evaluated in (e.g., the task, the training data, the test data, the hyperparameters, etc.), it’s extremely difficult to make claims that a model architecture is better than another architecture. The claim might be true in one context, but it is unlikely to be true for all possible contexts.
  • the simple neural network can update itself with each incoming example, whereas the collaborative filtering has to look at all the data to update its underlying matrix.
  • Ensembling methods are less favored in production because ensembles are more complex to deploy and harder to maintain. However, they are still common for tasks where a small performance boost can lead to a huge financial gain, such as predicting click-through rate for ads.
  • Imagine you have three email spam classifiers, each with an accuracy of 70%. Assuming that each classifier has an equal probability of making a correct prediction for each email, and that these three classifiers are not correlated, we’ll show that by taking the majority vote of these three classifiers, we can get an accuracy of 78.4%.
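The 78.4% figure can be checked by enumerating every combination of correct and incorrect votes. The sketch below assumes, as the text does, that the classifiers are independent and equally accurate. The function name majority_vote_accuracy is illustrative.

```python
from itertools import product


def majority_vote_accuracy(p, n=3):
    """Probability that more than half of n independent classifiers,
    each correct with probability p, are correct on a given example."""
    total = 0.0
    for outcome in product([True, False], repeat=n):
        if sum(outcome) > n / 2:
            prob = 1.0
            for is_correct in outcome:
                prob *= p if is_correct else 1 - p
            total += prob
    return total


acc = majority_vote_accuracy(0.7)
print(round(acc, 3))  # 0.784
```

The result decomposes as 0.7^3 (all three correct) plus 3 x 0.7^2 x 0.3 (exactly two correct), that is, 0.343 + 0.441 = 0.784.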
  • Bagging Bagging, shortened from bootstrap aggregating, is designed to improve both the training stability and accuracy of ML algorithms.4 It reduces variance and helps to avoid overfitting. Given a dataset, instead of training one classifier on the entire dataset, you sample with replacement to create different datasets, called bootstraps, and train a classification or regression model on each of these bootstraps. Sampling with replacement ensures that each bootstrap is created independently from its peers.
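The sampling-with-replacement step can be sketched as below. The base learner here is a deliberately trivial stand-in (it always predicts the mean target of its bootstrap), and all names and the toy dataset are assumptions made for illustration.

```python
import random
from statistics import mean


def make_bootstraps(dataset, n_bootstraps, seed=0):
    """Sample with replacement to build n_bootstraps datasets of the original size."""
    rng = random.Random(seed)
    return [[rng.choice(dataset) for _ in dataset] for _ in range(n_bootstraps)]


def train_mean_regressor(samples):
    """Toy base learner: always predicts the mean target of its bootstrap."""
    target_mean = mean(y for _, y in samples)
    return lambda x: target_mean


def bagged_predict(models, x):
    """Aggregate a regression ensemble by averaging the base predictions."""
    return mean(model(x) for model in models)


# Hypothetical (x, y) pairs.
dataset = [(x, 2.0 * x) for x in range(20)]
bootstraps = make_bootstraps(dataset, n_bootstraps=5)
models = [train_mean_regressor(b) for b in bootstraps]
prediction = bagged_predict(models, x=3)
```

For classification, the averaging step would be replaced by a majority vote over the base learners' predicted classes.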
  • A random forest is an example of bagging. A random forest is a collection of decision trees constructed by both bagging and feature randomness, where each tree can pick only from a random subset of features to use.
  • Boosting Boosting is a family of iterative ensemble algorithms that convert weak learners to strong ones. Each learner in this ensemble is trained on the same set of samples, but the samples are weighted differently among iterations. As a result, future weak learners focus more on the examples that previous weak learners misclassified.
  • Bagging generally improves unstable methods, such as neural networks, classification and regression trees, and subset selection in linear regression. However, it can mildly degrade the performance of stable methods such as k-nearest neighbors.5
  • An example of a boosting algorithm is a gradient boosting machine (GBM), which produces a prediction model typically from weak decision trees. It builds the model in a stage-wise fashion like other boosting methods do, and it generalizes them by allowing optimization of an arbitrary differentiable loss function. XGBoost, a variant of GBM, used to be the algorithm of choice for many winning teams of ML competitions.6 It’s been used in a wide range of tasks from classification, ranking, to the discovery of the Higgs Boson.7 However, many teams have been opting for LightGBM, a distributed gradient boosting framework that allows parallel learning, which generally allows faster training on large datasets.
  • Stacking Stacking means that you train base learners from the training data then create a meta-learner that combines the outputs of the base learners to output final predictions, as shown in Figure 6-5. The meta-learner can be as simple as a heuristic: you take the majority vote (for classification tasks) or the average vote (for regression tasks) from all base learners. It can be another model, such as a logistic regression model or a linear regression model.
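The heuristic meta-learner described above can be sketched in a few lines. The base learners here are hand-written keyword rules standing in for trained models, which is an assumption made to keep the example self-contained. All names are illustrative.

```python
from collections import Counter


def majority_vote(predictions):
    """Heuristic meta-learner for classification: the most common prediction."""
    return Counter(predictions).most_common(1)[0][0]


def average_vote(predictions):
    """Heuristic meta-learner for regression: the mean prediction."""
    return sum(predictions) / len(predictions)


# Hypothetical base learners; in practice these would be trained models.
base_learners = [
    lambda text: "spam" if "free" in text.lower() else "ham",
    lambda text: "spam" if "winner" in text.lower() else "ham",
    lambda text: "spam" if text.isupper() else "ham",
]


def stacked_predict(text):
    """Combine the base learners' outputs with the majority-vote meta-learner."""
    return majority_vote([learner(text) for learner in base_learners])


print(stacked_predict("FREE PRIZE FOR THE WINNER"))  # spam
print(stacked_predict("Lunch at noon?"))  # ham
```

Replacing majority_vote with a trained model, such as a logistic regression over the base learners' outputs, gives the learned variant of the meta-learner.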
  • An artifact is a file generated during an experiment—examples of artifacts can be files that show the loss curve, evaluation loss graph, logs, or intermediate results of a model throughout a training process.
  • Experiment tracking A large part of training an ML model is babysitting the learning processes. Many problems can arise during the training process, including loss not decreasing, overfitting, underfitting, fluctuating weight values, dead neurons, and running out of memory. It’s important to track what’s going on during training not only to detect and address these issues but also to evaluate whether your model is learning anything useful.
  • Following is just a short list of things you might want to consider tracking for each experiment during its training process: The loss curve corresponding to the train split and each of the eval splits. The model performance metrics that you care about on all nontest splits, such as accuracy, F1, perplexity. The log of corresponding sample, prediction, and ground truth label. This comes in handy for ad hoc analytics and sanity check. The speed of your model, evaluated by the number of steps per second or, if your data is text, the number of tokens processed per second. System performance metrics such as memory usage and CPU/GPU utilization. They’re important to identify bottlenecks and avoid wasting system resources. The values over time of any parameter and hyperparameter whose changes can affect your model’s performance, such as the learning rate if you use a learning rate schedule; gradient norms (both globally and per layer), especially if you’re clipping your gradient norms; and weight norm, especially if you’re doing weight decay.
  • Here are some of the things that might cause an ML model to fail: Theoretical constraints
  • Poor implementation of model
  • Poor choice of hyperparameters
  • Data problems
  • Poor choice of features
  • The following are three of them. Readers interested in learning more might want to check out Andrej Karpathy’s awesome post “A Recipe for Training Neural Networks”. Start simple and gradually add more components Start with the simplest model and then slowly add more components to see if it helps or hurts the performance. For example, if you want to build a recurrent neural network (RNN), start with just one level of RNN cell before stacking multiple together or adding more regularization. If you want to use a BERT-like model (Devlin et al. 2018), which uses both a masked language model (MLM) and next sentence prediction (NSP) loss, you might want to use only the MLM loss before adding NSP loss.
  • Overfit a single batch After you have a simple implementation of your model, try to overfit a small amount of training data and run evaluation on the same data to make sure that it gets to the smallest possible loss. If it’s for image recognition, overfit on 10 images and see if you can get the accuracy to be 100%, or if it’s for machine translation, overfit on 100 sentence pairs and see if you can get to a BLEU score of near 100.
  • Set a random seed
  • It’s common to train a model using data that doesn’t fit into memory. It’s especially common when dealing with medical data such as CT scans or genome sequences. It can also happen with text data if you work for teams that train large language models (cue OpenAI, Google, NVIDIA, Cohere).
  • When a sample of your data is large, e.g., when one machine can handle only a few samples at a time, you might only be able to work with a small batch size, which leads to instability for gradient descent-based optimization.
  • Data parallelism It’s now the norm to train ML models on multiple machines. The most common parallelization method supported by modern ML frameworks is data parallelism: you split your data on multiple machines, train your model on all of them, and accumulate gradients. This gives rise to a couple of issues. A challenging problem is how to accurately and effectively accumulate gradients from different machines. As each machine produces its own gradient, if your model waits for all of them to finish a run—synchronous stochastic gradient descent (SGD)—stragglers will cause the entire system to slow down, wasting time and resources.
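A synchronous update can be sketched on a single machine by treating each data shard as a worker. This is a simplified illustration, assuming a one-parameter linear model and a mean squared error loss. The function names and toy data are assumptions, and a real system would compute the per-shard gradients on separate machines.

```python
def local_gradient(w, shard):
    """Gradient of the mean squared error of y ~ w * x on one worker's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def synchronous_sgd_step(w, shards, lr):
    """Wait for every worker's gradient, average them, then update once."""
    grads = [local_gradient(w, shard) for shard in shards]
    avg_grad = sum(grads) / len(grads)
    return w - lr * avg_grad


# Hypothetical data generated by y = 3x, split across two workers.
data = [(float(x), 3.0 * x) for x in range(1, 9)]
shards = [data[0::2], data[1::2]]

w = 0.0
for _ in range(200):
    w = synchronous_sgd_step(w, shards, lr=0.01)
print(round(w, 4))  # 3.0
```

The list comprehension in synchronous_sgd_step is where the straggler problem appears: the update cannot proceed until the slowest worker has returned its gradient.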
  • Despite knowing its importance, many still ignore systematic approaches to hyperparameter tuning in favor of a manual, gut-feeling approach. The most popular is arguably graduate student descent (GSD), a technique in which a graduate student fiddles around with the hyperparameters until the model works.
  • Popular methods for hyperparameter tuning include random search,26 grid search, and Bayesian optimization.27 The book AutoML: Methods, Systems, Challenges by the AutoML group at the University of Freiburg dedicates its first chapter (which you can read online for free) to hyperparameter optimization. When tuning hyperparameters, keep in mind that a model’s performance might be more sensitive to the change in one hyperparameter than another, and therefore sensitive hyperparameters should be more carefully tuned.
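Random search, the first method listed above, can be sketched as follows. The objective is a made-up stand-in for a validation score that is far more sensitive to the learning rate than to dropout, and the search space, names, and trial count are assumptions for illustration.

```python
import random


def random_search(objective, space, n_trials, seed=0):
    """Sample each hyperparameter uniformly from its range and keep the best trial."""
    rng = random.Random(seed)
    best_params, best_score = None, float("-inf")
    for _ in range(n_trials):
        params = {name: rng.uniform(low, high) for name, (low, high) in space.items()}
        score = objective(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_objective(params):
    """Hypothetical validation score; in practice, train and evaluate a model here."""
    return -100 * (params["lr"] - 0.1) ** 2 - (params["dropout"] - 0.3) ** 2


space = {"lr": (0.0, 1.0), "dropout": (0.0, 0.9)}
best_params, best_score = random_search(toy_objective, space, n_trials=200)
```

Because the toy score penalizes learning-rate error 100 times more heavily than dropout error, it mirrors the point about sensitive hyperparameters deserving more careful tuning.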
  • Baselines Someone once told me that her new generative model achieved the FID score of 10.3 on ImageNet.38 I had no idea what this number meant or whether her model would be useful for my problem. Another time, I helped a company implement a classification model where the positive class appears 90% of the time. An ML engineer on the team told me, all excited, that their initial model achieved an F1 score of 0.90. I asked him how it was compared to random. He had no idea. It turned out that because for his task the POSITIVE class accounts for 90% of the labels, if his model randomly outputs the positive class 90% of the time, its F1 score would also be around 0.90.39
  • Random baseline
  • Simple heuristic
  • Human baseline
  • Existing solutions
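The random-baseline arithmetic in the F1 anecdote above can be verified directly. The sketch assumes the random predictions are independent of the true labels and works with expected rates rather than a simulated dataset. The function name is illustrative.

```python
def expected_f1_random(positive_rate, predict_positive_rate):
    """F1 of a model that outputs the positive class at a fixed rate,
    independently of the true label, using expected rates."""
    tp = positive_rate * predict_positive_rate
    fp = (1 - positive_rate) * predict_positive_rate
    fn = positive_rate * (1 - predict_positive_rate)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


f1 = expected_f1_random(positive_rate=0.9, predict_positive_rate=0.9)
print(round(f1, 2))  # 0.9
```

Precision and recall both come out to 0.9 here, so an F1 of 0.90 on this task is no better than the random baseline.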
  • When evaluating a model, it’s important to differentiate between “a good system” and “a useful system.” A good system isn’t necessarily useful, and a bad system isn’t necessarily useless. A self-driving vehicle might be good if it’s a significant improvement from previous self-driving systems, but it might not be useful if it doesn’t perform at least as well as human drivers. In some cases, even if an ML system drives better than an average human, people might still not trust it, which renders it not useful. On the other hand, a system that predicts what word a user will type next on their phone might be considered bad if it’s much worse than a native speaker. However, it might still be useful if its predictions can help users type faster some of the time.
  • If you want to deploy a model for your friends to play with, all you have to do is to wrap your predict function in a POST request endpoint using Flask or FastAPI, put the dependencies this predict function needs to run in a container,2 and push your model and its associated container to a cloud service like AWS or GCP to expose the endpoint:

        # Example of how to use Flask to turn your predict function
        # into a POST endpoint
        import json
        from flask import Flask, request

        app = Flask(__name__)

        @app.route('/predict', methods=['POST'])
        def predict():
            X = request.get_json()['X']
            y = MODEL.predict(X).tolist()  # MODEL: your trained model, loaded elsewhere
            return json.dumps({'y': y}), 200

    You can use this exposed endpoint for downstream applications: e.g., when an application receives a prediction request from a user, this request is sent to the exposed endpoint, which returns a prediction.
  • The phenomenon in which a software program degrades over time even if nothing seems to have changed is known as “software rot” or “bit rot.” ML systems aren’t immune to it. On top of that, ML systems suffer from what are known as data distribution shifts, when the data distribution your model encounters in production is different from the data distribution it was trained on.10 Therefore, an ML model tends to perform best right after training and to degrade over time.
  • there are three main modes of prediction that I hope you’ll remember: Batch prediction, which uses only batch features. Online prediction that uses only batch features (e.g., precomputed embeddings). Online prediction that uses both batch features and streaming features. This is also known as streaming prediction.
  • Online prediction is when predictions are generated and returned as soon as requests for these predictions arrive. For example, you enter an English sentence into Google Translate and get back its French translation immediately. Online prediction is also known as on-demand prediction. Traditionally, when doing online prediction, requests are sent to the prediction service via RESTful APIs (e.g., HTTP requests—see “Data Passing Through Services”). When prediction requests are sent via HTTP requests, online prediction is also known as synchronous prediction: predictions are generated in synchronization with requests.
  • If the model you want to deploy takes too long to generate predictions, there are three main approaches to reduce its inference latency: make it do inference faster, make the model smaller, or make the hardware it’s deployed on run faster.
  • While there are many new techniques being developed, the four types of techniques that you might come across the most often are low-rank optimization, knowledge distillation, pruning, and quantization.
  • deploying a model isn’t the end of the process. A model’s performance degrades over time in production. Once a model has been deployed, we still have to continually monitor its performance to detect issues as well as deploy updates to fix these issues.
  • The assumption is that the unseen data comes from a stationary distribution that is the same as the training data distribution. If the unseen data comes from a different distribution, the model might not generalize well.
  • Concept drift, also known as posterior shift, is when the input distribution remains the same but the conditional distribution of the output given an input changes. You can think of this as “same input, different output.” Consider you’re in charge of a model that predicts the price of a house based on its features. Before COVID-19, a three-bedroom apartment in San Francisco could cost $2,000,000. However, at the beginning of COVID-19, many people left San Francisco, so the same apartment would cost only $1,500,000. So even though the distribution of house features remains the same, the conditional distribution of the price of a house given its features has changed.
  • mean/median/variance are useful summaries. If those metrics differ significantly, the inference distribution might have shifted from the training distribution. However, if those metrics are similar, there’s no guarantee that there’s no shift. A more sophisticated solution is to use a two-sample hypothesis test, shortened as two-sample test.
  • A basic two-sample test is the Kolmogorov–Smirnov test, also known as the K-S or KS test.32 It’s a nonparametric statistical test, which means it doesn’t require any parameters of the underlying distribution.
  • However, one major drawback of the KS test is that it can only be used for one-dimensional data.
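The statistic behind the KS test is the largest gap between the two samples' empirical cumulative distribution functions. The sketch below computes only that statistic for one-dimensional data and does not compute a p-value. The function name and toy samples are assumptions for illustration.

```python
from bisect import bisect_right


def ks_statistic(sample_a, sample_b):
    """Largest absolute gap between the empirical CDFs of two 1-D samples."""
    a, b = sorted(sample_a), sorted(sample_b)
    gap = 0.0
    for value in a + b:
        cdf_a = bisect_right(a, value) / len(a)  # share of a that is <= value
        cdf_b = bisect_right(b, value) / len(b)  # share of b that is <= value
        gap = max(gap, abs(cdf_a - cdf_b))
    return gap


# Hypothetical feature values: one training sample and one shifted sample.
train_values = [float(i) for i in range(100)]
shifted_values = [v + 50.0 for v in train_values]

same_stat = ks_statistic(train_values, train_values)
shift_stat = ks_statistic(train_values, shifted_values)
print(same_stat, round(shift_stat, 2))  # 0.0 0.5
```

A statistic near 0 means the two samples look alike, while a larger value suggests a shift. For a significance level, a library routine such as scipy.stats.ks_2samp returns both the statistic and a p-value.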
  • Alibi Detect is a great open source package with the implementations of many drift detection algorithms, as shown in Figure 8-2. Because two-sample tests often work better on low-dimensional data than on high-dimensional data, it’s highly recommended that you reduce the dimensionality of your data before performing a two-sample test on it.
  • many companies assume that data shifts are inevitable, so they periodically retrain their models—once a month, once a week, or once a day—regardless of the extent of the shift.
  • To make a model work with a new distribution in production, there are three main approaches. The first is the approach that currently dominates research: train models using massive datasets. The hope here is that if the training dataset is large enough, the model will be able to learn such a comprehensive distribution that whatever data points the model will encounter in production will likely come from this distribution. The second approach, less popular in research, is to adapt a trained model to a target distribution without requiring new labels. Zhang et al. (2013) used causal interpretations together with kernel embedding of conditional and marginal distributions to correct models’ predictions for both covariate shifts and label shifts without using labels from the target distribution.
  • The third approach is what is usually done in the industry today: retrain your model using the labeled data from the target distribution. However, retraining your model is not so straightforward. Retraining can mean retraining your model from scratch on both the old and new data or continuing training the existing model on new data. The latter approach is also called fine-tuning.
  • Monitoring is all about metrics. Because ML systems are software systems, the first class of metrics you’d need to monitor are the operational metrics. These metrics are designed to convey the health of your systems. They are generally divided into three levels: the network the system is run on, the machine the system is run on, and the application that the system runs. Examples of these metrics are latency; throughput; the number of prediction requests your model receives in the last minute, hour, day; the percentage of requests that return with a 2xx code; CPU/GPU utilization; memory utilization; etc. No matter how good your ML model is, if the system is down, you’re not going to benefit from it.
  • ML-Specific Metrics Within ML-specific metrics, there are generally four artifacts to monitor: a model’s accuracy-related metrics, predictions, features, and raw inputs. These are artifacts generated at four different stages of an ML system pipeline, as shown in Figure 8-5. The deeper into the pipeline an artifact is, the more transformations it has gone through, which makes a change in that artifact more likely to be caused by errors in one of those transformations. However, the more transformations an artifact has gone through, the more structured it’s become and the closer it is to the metrics you actually care about, which makes it easier to monitor.
  • Monitoring features
  • Monitoring predictions
  • Monitoring accuracy-related metrics
  • In theory, a small distribution shift can cause catastrophic failure, but in practice, an individual feature’s minor changes might not harm the model’s performance at all. Feature distributions shift all the time, and most of these changes are benign.48 If you want to be alerted whenever a feature seems to have drifted, you might soon be overwhelmed by alerts and realize that most of these alerts are false positives. This can cause a phenomenon called “alert fatigue” where the monitoring team stops paying attention to the alerts because they are so frequent. The problem of feature monitoring becomes the problem of trying to decide which feature shifts are critical and which are not.
  • Analyzing billions of logged events manually is futile, so many companies use ML to analyze logs. An example use case of ML in log analysis is anomaly detection: to detect abnormal events in your system.
  • Another use case of ML in log analysis is that when a service fails, it might be helpful to know the probability of related services being affected. This could be especially useful when the system is under cyberattack.
  • Alert fatigue is a real phenomenon, as discussed previously in this chapter. Alert fatigue can be demoralizing—nobody likes to be awakened in the middle of the night for something outside of their responsibilities. It’s also dangerous—being exposed to trivial alerts can desensitize people to critical alerts. It’s important to set meaningful conditions so that only critical alerts are sent out.
  • Observability is a term used to address this challenge. It’s a concept drawn from control theory, and it refers to bringing “better visibility into understanding the complex behavior of software using [outputs] collected from the system at run time.”54
  • Observability is about instrumenting your system in a way to ensure that sufficient information about a system’s runtime is collected and analyzed.
  • We discussed three major causes of ML-specific failures: production data differing from training data, edge cases, and degenerate feedback loops. The first two causes are related to data, whereas the last cause is related to system design because it happens when the system’s outputs influence the same system’s input.
  • We looked into three types of shifts: covariate shift, label shift, and concept drift. Even though studying distribution shifts is a growing subfield of ML research, the research community hasn’t yet found a standard narrative.
  • Monitoring is all about metrics. We discussed different metrics we need to monitor: operational metrics—the metrics that should be monitored with any software systems such as latency, throughput, and CPU utilization—and ML-specific metrics. Monitoring can be applied to accuracy-related metrics, predictions, features, and/or raw inputs.
  • if your model is a neural network, learning with every incoming sample makes it susceptible to catastrophic forgetting. Catastrophic forgetting refers to the tendency of a neural network to completely and abruptly forget previously learned information upon learning new information.
  • you shouldn’t make changes to the existing model directly. Instead, you create a replica of the existing model and update this replica on new data, and only replace the existing model with the updated replica if the updated replica proves to be better. The existing model is called the champion model, and the updated replica, the challenger.
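The champion/challenger rule can be written as a small selection function. This is a sketch under assumptions: the evaluation function, the optional min_gain margin, and the dict-based toy models are all hypothetical.

```python
def select_model(champion, challenger, eval_fn, min_gain=0.0):
    """Return the challenger only if it beats the champion by more than min_gain."""
    champion_score = eval_fn(champion)
    challenger_score = eval_fn(challenger)
    if challenger_score > champion_score + min_gain:
        return challenger
    return champion


# Hypothetical models with precomputed evaluation scores.
champion = {"name": "v1", "score": 0.81}
better_challenger = {"name": "v2", "score": 0.84}
worse_challenger = {"name": "v3", "score": 0.79}


def eval_fn(model):
    """Stand-in for running the model on a fixed evaluation set."""
    return model["score"]


serving = select_model(champion, better_challenger, eval_fn)
print(serving["name"])  # v2
```

Requiring a margin through min_gain is one way to avoid swapping models over differences that are within evaluation noise.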
  • Most companies do stateless retraining—the model is trained from scratch each time. Continual learning means also allowing stateful training—the model continues training on new data.2 Stateful training is also known as fine-tuning or incremental learning.
  • In the stateful training paradigm, each model update is trained using only the fresh data, so a data sample is used only once for training, as shown in Figure 9-2. This means that it’s possible to train your model without having to store data in permanent storage, which helps eliminate many concerns about data privacy. However, this is overlooked because today’s let’s-keep-track-of-everything practice still makes many companies reluctant to throw away data.
  • The first use case of continual learning is to combat data distribution shifts, especially when the shifts happen suddenly. Imagine you’re building a model to determine the prices for a ride-sharing service like Lyft.6 Historically, the ride demand on a Thursday evening in this particular neighborhood is slow, so the model predicts low ride prices,
  • Another use case of continual learning is to adapt to rare events. Imagine you work for an ecommerce website like Amazon. Black Friday is an important shopping event that happens only once a year. There’s no way you will be able to gather enough historical data for your model to be able to make accurate predictions on how your customers will behave throughout Black Friday this year. To improve performance, your model should learn throughout the day with fresh data.
  • A huge challenge for ML production today that continual learning can help overcome is the continuous cold start problem. The cold start problem arises when your model has to make predictions for a new user without any historical data.
  • Fresh data access challenge The first challenge is getting fresh data. If you want to update your model every hour, you need new data every hour. Currently, many companies pull new training data from their data warehouses. The speed at which you can pull data from your data warehouses depends on the speed at which this data is deposited into your data warehouses. The speed can be slow, especially if data comes from multiple sources. An alternative is to pull data before it’s deposited into data warehouses, e.g., directly from real-time transports such as Kafka and Kinesis that transport data from applications to data warehouses,12
  • The best candidates for continual learning are tasks where you can get natural labels with short feedback loops. Examples of these tasks are dynamic pricing (based on estimated demand and availability), estimating time of arrival, stock price prediction, ads click-through prediction, and recommender systems for online content like tweets, songs, short videos, articles, etc.
  • If you run an ecommerce website, your application might register that at 10:33 p.m., user A clicks on the product with the ID of 32345. Your system needs to look back into the logs to see if this product ID was ever recommended to this user, and if yes, then what query prompted this recommendation, so that your system can match this query to this recommendation and label this recommendation as a good recommendation, as shown in Figure 9-4 (a simplification of the process of extracting labels from user feedback). The process of looking back into the logs to extract labels is called label computation. It can be quite costly if the number of logs is large. Label computation can be done with batch processing: e.g., waiting for logs to be deposited into data warehouses first before running a batch job to extract all labels from logs at once.
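Label computation as described above amounts to joining click events against the recommendation log. The following batch-style sketch assumes both logs are lists of dicts. The field names, the log contents, and the function name are hypothetical.

```python
def compute_labels(click_log, recommendation_log):
    """For each click, look back for a matching recommendation and label it good."""
    # Index recommendations by (user, product) so each click is a single lookup.
    recs_by_key = {}
    for rec in recommendation_log:
        recs_by_key.setdefault((rec["user"], rec["product_id"]), []).append(rec)

    labels = []
    for click in click_log:
        for rec in recs_by_key.get((click["user"], click["product_id"]), []):
            labels.append(
                {"query": rec["query"], "product_id": rec["product_id"], "label": "good"}
            )
    return labels


# Hypothetical logs; the second click matches no earlier recommendation.
recommendation_log = [
    {"user": "A", "query": "running shoes", "product_id": 32345},
    {"user": "B", "query": "rain jacket", "product_id": 11802},
]
click_log = [
    {"user": "A", "product_id": 32345, "time": "22:33"},
    {"user": "A", "product_id": 99999, "time": "22:35"},
]
labels = compute_labels(click_log, recommendation_log)
```

Building the index first keeps the cost roughly linear in the number of log entries, which matters because the text notes that label computation gets costly as logs grow.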
  • Evaluation challenge
  • The risks for catastrophic failures amplify with continual learning. First, the more frequently you update your models, the more opportunities there are for updates to fail. Second, continual learning makes your models more susceptible to coordinated manipulation and adversarial attack. Because your models learn online from real-world data, it makes it easier for users to input malicious data to trick models into learning wrong things.
  • Algorithm challenge
  • To illustrate this point, consider two different models: a neural network and a matrix-based model, such as a collaborative filtering model. The collaborative filtering model uses a user-item matrix and a dimension reduction technique. You can update the neural network model with a data batch of any size. You can even perform the update step with just one data sample. However, if you want to update the collaborative filtering model, you first need to use the entire dataset to build the user-item matrix before performing dimensionality reduction on it. Of course, you can apply dimensionality reduction to your matrix each time you update the matrix with a new data sample, but if your matrix is large, the dimensionality reduction step would be too slow and expensive to perform frequently. Therefore, this model is less suitable for learning with a partial dataset than the preceding neural network model.
  • Four Stages of Continual Learning
  • Stage 1: Manual, stateless retraining
  • Stage 2: Automated retraining
  • Stage 3: Automated, stateful training
  • How Often to Update Your Models
  • To understand why offline evaluation isn’t enough, let’s go over two major test types for offline evaluation: test splits and backtests.
  • There are techniques to help you evaluate your models in production (mostly) safely. In this section, we’ll cover the following techniques: shadow deployment, A/B testing, canary analysis, interleaving experiments, and bandits.
  • A/B Testing A/B testing is a way to compare two variants of an object, typically by testing responses to these two variants, and determining which of the two variants is more effective.
  • However, there are cases where one model’s predictions might affect another model’s predictions. For example, in ride-sharing’s dynamic pricing, a model’s predicted prices might influence the number of available drivers and riders, which, in turn, influences the other model’s predictions. In those cases, you might have to run your variants alternately, e.g., serve model A one day and then serve model B the next day.
  • First, A/B testing consists of a randomized experiment: the traffic routed to each model has to be truly random.
  • For readers interested in learning more about A/B testing and other statistical concepts important in ML, I recommend Ron Kohavi’s book Trustworthy Online Controlled Experiments (A Practical Guide to A/B Testing) (Cambridge University Press) and Michael Barber’s great introduction to statistics for data science (much shorter).
  • Canary Release Canary release is a technique to reduce the risk of introducing a new software version in production by slowly rolling out the change to a small subset of users before rolling it out to the entire infrastructure and making it available to everybody.
  • You don’t know which slot machine gives the highest payout. You can experiment over time to find out which slot machine is the best while maximizing your payout. Multi-armed bandits are algorithms that allow you to balance between exploitation (choosing the slot machine that has paid the most in the past) and exploration (choosing other slot machines that may pay off even more).
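One simple way to balance exploration and exploitation is the epsilon-greedy strategy, sketched below. It is only one of several bandit algorithms and is chosen here for brevity. The payout probabilities, epsilon value, and function name are assumptions for illustration.

```python
import random


def epsilon_greedy(payout_probs, n_rounds, epsilon=0.1, seed=0):
    """Simulate pulling slot machines with an epsilon-greedy policy."""
    rng = random.Random(seed)
    n_arms = len(payout_probs)
    pulls = [0] * n_arms
    wins = [0] * n_arms
    for _ in range(n_rounds):
        if rng.random() < epsilon:
            # Explore: try a machine chosen uniformly at random.
            arm = rng.randrange(n_arms)
        else:
            # Exploit: pick the machine with the best observed payout rate so far.
            arm = max(
                range(n_arms),
                key=lambda i: wins[i] / pulls[i] if pulls[i] else 0.0,
            )
        pulls[arm] += 1
        wins[arm] += rng.random() < payout_probs[arm]  # True counts as 1
    return pulls, wins


# Hypothetical machines; the true payout rates are unknown to the policy.
pulls, wins = epsilon_greedy([0.2, 0.5, 0.8], n_rounds=5000)
```

With epsilon set to 0.1, roughly one pull in ten is spent exploring, and the rest go to whichever machine currently looks best.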
  • However, the compute layer doesn’t always use threads or cores as compute units. There are compute layers that abstract away the notions of cores and use other units of computation. For example, computation engines like Spark and Ray use “job” as their unit, and Kubernetes uses “pod,” a wrapper around containers, as its smallest deployable unit. While you can have multiple containers in a pod, you can’t independently start or stop different containers in the same pod.
  • Dev Environment Setup
  • Docker Compose is a lightweight container orchestrator that can manage containers on a single host. However, each of your containers might run on its own host, and this is where Docker Compose is at its limits. Kubernetes (K8s) is a tool for exactly that. K8s creates a network for containers to communicate and share resources. It can help you spin up containers on more instances when you need more compute/memory as well as shutting down containers when you no longer need them, and it helps maintain high availability for your system. K8s was one of the fastest-growing technologies in the 2010s.
  • Cron, Schedulers, and Orchestrators: There are two key characteristics of ML workflows that influence their resource management: repetitiveness and dependencies.
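The "dependencies" characteristic can be sketched as a small DAG in which each step runs only after the steps it depends on. The example uses Python's standard-library `graphlib` (Python 3.9+); the step names and the `run_order` helper are hypothetical and are not taken from any particular workflow tool.

```python
from graphlib import TopologicalSorter


def run_order(dependencies):
    """Return one valid execution order for a workflow.

    `dependencies` maps each step to the set of steps it depends on.
    TopologicalSorter raises graphlib.CycleError if the steps form a cycle.
    """
    return list(TopologicalSorter(dependencies).static_order())


# Hypothetical ML workflow: each key lists the steps that must finish first.
pipeline = {
    "ingest": set(),
    "featurize": {"ingest"},
    "train": {"featurize"},
    "evaluate": {"train"},
    "report": {"evaluate", "featurize"},
}

order = run_order(pipeline)
print(order)
```

Repetitiveness is the other half: a scheduler would rerun this same ordered workflow on a fixed cadence, for example once a day.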
  • If schedulers are concerned with when to run jobs and what resources are needed to run those jobs, orchestrators are concerned with where to get those resources. Schedulers deal with job-type abstractions such as DAGs, priority queues, user-level quotas (i.e., the maximum number of instances a user can use at a given time), etc. Orchestrators deal with lower-level abstractions like machines, instances, clusters, service-level grouping, replication, etc. If the orchestrator notices that there are more jobs than the pool of available instances, it can increase the number of instances in the available instance pool.
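The division of labor can be sketched in a few lines: a scheduler that admits jobs by priority under a per-user quota, and an orchestrator that grows the instance pool when admitted jobs outnumber instances. Everything here is a toy model under stated assumptions: one instance per job, lower numbers mean higher priority, and the names `Job`, `InstancePool`, `schedule`, and `scale_up_if_needed` are hypothetical.

```python
import heapq
from dataclasses import dataclass, field


@dataclass(order=True)
class Job:
    priority: int                     # lower number = runs first (assumption)
    name: str = field(compare=False)
    user: str = field(compare=False)


@dataclass
class InstancePool:
    size: int       # instances currently available
    max_size: int   # hypothetical cap on how far the pool may grow


def schedule(jobs, user_quota):
    """Scheduler side: admit jobs by priority, enforcing a per-user quota."""
    heap = list(jobs)
    heapq.heapify(heap)
    running_per_user = {}
    admitted = []
    while heap:
        job = heapq.heappop(heap)
        if running_per_user.get(job.user, 0) < user_quota:
            running_per_user[job.user] = running_per_user.get(job.user, 0) + 1
            admitted.append(job)
    return admitted


def scale_up_if_needed(pool, n_jobs):
    """Orchestrator side: add instances when jobs outnumber the pool."""
    shortfall = n_jobs - pool.size
    if shortfall <= 0:
        return 0
    added = min(shortfall, pool.max_size - pool.size)
    pool.size += added
    return added


jobs = [
    Job(2, "train", "alice"),
    Job(1, "eval", "alice"),
    Job(3, "etl", "alice"),
    Job(1, "report", "bob"),
]
admitted = schedule(jobs, user_quota=2)
pool = InstancePool(size=2, max_size=10)
added = scale_up_if_needed(pool, len(admitted))
print([j.name for j in admitted], "instances added:", added)
```

In this toy model, jobs that exceed the quota are simply dropped from the round; a real scheduler would more likely keep them queued for a later one.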
  • Because ML platforms are relatively new, what exactly constitutes an ML platform varies from company to company. Even within the same company, it’s an ongoing discussion. Here, I’ll focus on the components that I most often see in ML platforms, which include model development, model store, and feature store.
  • While it’s usually straightforward to do online prediction at a smaller scale with most deployment services, doing batch prediction is usually trickier. Some tools allow you to batch requests together for online prediction, which is different from batch prediction. Many companies have separate deployment pipelines for online prediction and batch prediction. For example, they might use Seldon for online prediction but leverage Databricks for batch prediction. An open problem with model deployment is how to ensure the quality of a model before it’s deployed.
  • At its core, a feature store can help address three main problems: feature management, feature transformation, and feature consistency. A feature store solution might address one or a combination of these problems.
  • “If it’s something we want to be really good at, we’ll manage that in-house. If not, we’ll use a vendor.”
