133 lines
6.2 KiB
Java
133 lines
6.2 KiB
Java
/*
|
|
* Copyright (C) 2024 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package android.adservices.ondevicepersonalization;
|
|
|
|
import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
|
|
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
|
|
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
|
|
import android.annotation.CallbackExecutor;
|
|
import android.annotation.FlaggedApi;
|
|
import android.annotation.NonNull;
|
|
import android.annotation.WorkerThread;
|
|
import android.os.Bundle;
|
|
import android.os.OutcomeReceiver;
|
|
import android.os.RemoteException;
|
|
|
|
import com.android.adservices.ondevicepersonalization.flags.Flags;
|
|
import com.android.ondevicepersonalization.internal.util.LoggerFactory;
|
|
|
|
import java.util.Objects;
|
|
import java.util.concurrent.Executor;
|
|
|
|
/**
|
|
* Handles model inference and only support TFLite model inference now. See {@link
|
|
* IsolatedService#getModelManager}.
|
|
*/
|
|
@FlaggedApi(Flags.FLAG_ON_DEVICE_PERSONALIZATION_APIS_ENABLED)
|
|
public class ModelManager {
|
|
private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
|
|
private static final String TAG = ModelManager.class.getSimpleName();
|
|
@NonNull private final IDataAccessService mDataService;
|
|
|
|
@NonNull private final IIsolatedModelService mModelService;
|
|
|
|
/** @hide */
|
|
public ModelManager(
|
|
@NonNull IDataAccessService dataService, @NonNull IIsolatedModelService modelService) {
|
|
mDataService = dataService;
|
|
mModelService = modelService;
|
|
}
|
|
|
|
/**
|
|
* Run a single model inference. Only supports TFLite model inference now.
|
|
*
|
|
* @param input contains all the information needed for a run of model inference.
|
|
* @param executor the {@link Executor} on which to invoke the callback.
|
|
* @param receiver this returns a {@link InferenceOutput} which contains model inference result
|
|
* or {@link Exception} on failure.
|
|
*/
|
|
@WorkerThread
|
|
public void run(
|
|
@NonNull InferenceInput input,
|
|
@NonNull @CallbackExecutor Executor executor,
|
|
@NonNull OutcomeReceiver<InferenceOutput, Exception> receiver) {
|
|
final long startTimeMillis = System.currentTimeMillis();
|
|
Objects.requireNonNull(input);
|
|
Bundle bundle = new Bundle();
|
|
bundle.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, mDataService.asBinder());
|
|
bundle.putParcelable(Constants.EXTRA_INFERENCE_INPUT, new InferenceInputParcel(input));
|
|
try {
|
|
mModelService.runInference(
|
|
bundle,
|
|
new IIsolatedModelServiceCallback.Stub() {
|
|
@Override
|
|
public void onSuccess(Bundle result) {
|
|
executor.execute(
|
|
() -> {
|
|
int responseCode = Constants.STATUS_SUCCESS;
|
|
long endTimeMillis = System.currentTimeMillis();
|
|
try {
|
|
InferenceOutputParcel outputParcel =
|
|
Objects.requireNonNull(
|
|
result.getParcelable(
|
|
Constants.EXTRA_RESULT,
|
|
InferenceOutputParcel.class));
|
|
InferenceOutput output =
|
|
new InferenceOutput(outputParcel.getData());
|
|
endTimeMillis = System.currentTimeMillis();
|
|
receiver.onResult(output);
|
|
} catch (Exception e) {
|
|
endTimeMillis = System.currentTimeMillis();
|
|
responseCode = Constants.STATUS_INTERNAL_ERROR;
|
|
receiver.onError(e);
|
|
} finally {
|
|
logApiCallStats(
|
|
Constants.API_NAME_MODEL_MANAGER_RUN,
|
|
endTimeMillis - startTimeMillis,
|
|
responseCode);
|
|
}
|
|
});
|
|
}
|
|
|
|
@Override
|
|
public void onError(int errorCode) {
|
|
executor.execute(
|
|
() -> {
|
|
long endTimeMillis = System.currentTimeMillis();
|
|
receiver.onError(
|
|
new IllegalStateException("Error: " + errorCode));
|
|
logApiCallStats(
|
|
Constants.API_NAME_MODEL_MANAGER_RUN,
|
|
endTimeMillis - startTimeMillis,
|
|
Constants.STATUS_INTERNAL_ERROR);
|
|
});
|
|
}
|
|
});
|
|
} catch (RemoteException e) {
|
|
receiver.onError(new IllegalStateException(e));
|
|
}
|
|
}
|
|
|
|
private void logApiCallStats(int apiName, long duration, int responseCode) {
|
|
try {
|
|
mDataService.logApiCallStats(apiName, duration, responseCode);
|
|
} catch (Exception e) {
|
|
sLogger.d(e, TAG + ": failed to log metrics");
|
|
}
|
|
}
|
|
}
|